1 //
2 // Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include <armnn/backends/TensorHandle.hpp>
7 #include <armnn/backends/WorkloadData.hpp>
8 #include <armnn/backends/WorkloadInfo.hpp>
9 #include <armnnUtils/DataLayoutIndexed.hpp>
10 #include <armnnUtils/TensorUtils.hpp>
11 #include <armnnUtils/Permute.hpp>
12 #include <armnn/utility/NumericCast.hpp>
13 #include <armnn/Logging.hpp>
14
15 #include <algorithm>
16 #include <iomanip>
17 #include <string>
18 #include <sstream>
19
20 #include <fmt/format.h>
21
22 using namespace armnnUtils;
23
24 namespace armnn
25 {
26
27 //---------------------------------------------------------------
GetBiasDataType(DataType inputDataType)28 DataType GetBiasDataType(DataType inputDataType)
29 {
30 switch (inputDataType)
31 {
32 case DataType::Float16:
33 return DataType::Float16;
34 case DataType::BFloat16:
35 case DataType::Float32:
36 return DataType::Float32;
37 case DataType::QAsymmS8:
38 case DataType::QAsymmU8:
39 case DataType::QSymmS8:
40 case DataType::QSymmS16:
41 return DataType::Signed32;
42 default:
43 ARMNN_ASSERT_MSG(false, "Invalid input data type");
44 return DataType::Float32;
45 }
46 }
47
48 namespace
49 {
50
51 //---------------------------------------------------------------
52 //android ndk does not support std::to_string function.
template <typename T>
std::string to_string(T value)
{
    // Stream-based conversion: works for any streamable type and avoids
    // std::to_string, which the Android NDK does not provide.
    std::ostringstream stream;
    stream << value;
    return stream.str();
}
60
61 //---------------------------------------------------------------
ValidatePointer(const void * ptr,std::string const & descName,std::string const & paramName)62 void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
63 {
64 if (!ptr)
65 {
66 throw InvalidArgumentException(descName + ": Invalid null pointer. The " +
67 paramName + " parameter must be set.");
68 }
69 }
70
71 //---------------------------------------------------------------
ValidateTensorShapesMatch(const TensorInfo & first,const TensorInfo & second,std::string const & descName,std::string const & firstName,std::string const & secondName)72 void ValidateTensorShapesMatch(const TensorInfo& first,
73 const TensorInfo& second,
74 std::string const& descName,
75 std::string const& firstName,
76 std::string const& secondName)
77 {
78 if (first.GetShape() != second.GetShape())
79 {
80 throw InvalidArgumentException(descName + ": "
81 + firstName + " & " + secondName + " must have identical shapes");
82 }
83 }
84
85 //---------------------------------------------------------------
ValidateNumInputs(const WorkloadInfo & workloadInfo,std::string const & descName,const unsigned int expectedSize)86 void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
87 {
88 if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
89 {
90 throw InvalidArgumentException(descName +
91 ": Requires exactly " + to_string(expectedSize) + "input(s). " +
92 to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
93 }
94 }
95
96 //---------------------------------------------------------------
ValidateNumOutputs(const WorkloadInfo & workloadInfo,std::string const & descName,const unsigned int expectedSize)97 void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
98 {
99 if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
100 {
101 throw InvalidArgumentException(descName +
102 ": Requires exactly " + to_string(expectedSize) + " output(s). " +
103 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
104 }
105 }
106
107 //---------------------------------------------------------------
108
109 //---------------------------------------------------------------
ValidateTensorNumElements(const TensorInfo & tensor,std::string const & descName,unsigned int numElements,std::string const & tensorName)110 void ValidateTensorNumElements(const TensorInfo& tensor,
111 std::string const& descName,
112 unsigned int numElements,
113 std::string const& tensorName)
114 {
115 if (tensor.GetNumElements() != numElements)
116 {
117 throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
118 to_string(tensor.GetNumElements()) + " elements for " +
119 tensorName + " tensor.");
120 }
121 }
122
123 //---------------------------------------------------------------
ValidateTensorDataType(const TensorInfo & tensor,DataType dataType,const std::string & descName,std::string const & tensorName)124 void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
125 const std::string& descName, std::string const& tensorName)
126 {
127 if (tensor.GetDataType() != dataType)
128 {
129 throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
130 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
131 }
132 }
133
ValidPerAxisQuantizedDataType(const TensorInfo & tensor,const std::string & descName,const std::string & tensorName)134 void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
135 {
136 if (tensor.GetDataType() != DataType::QSymmS8)
137 {
138 throw InvalidArgumentException(descName +
139 ": Expected data type which supports per-axis quantization scheme but got " +
140 GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
141 }
142 }
143
144 //---------------------------------------------------------------
ValidateTensorQuantizationSpace(const TensorInfo & first,const TensorInfo & second,const std::string & descName,std::string const & firstName,std::string const & secondName)145 void ValidateTensorQuantizationSpace(const TensorInfo& first,
146 const TensorInfo& second,
147 const std::string& descName,
148 std::string const& firstName,
149 std::string const& secondName)
150 {
151 if (!first.IsQuantized() ||
152 !second.IsQuantized())
153 {
154 // Not a quantized type, ignore the validation
155 return;
156 }
157
158 DataType firstDataType = first.GetDataType();
159 DataType secondDataType = second.GetDataType();
160
161 if (firstDataType != secondDataType)
162 {
163 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
164 " must be of the same quantized type, " +
165 firstName + " is " + GetDataTypeName(firstDataType) + ", " +
166 secondName + " is " + GetDataTypeName(secondDataType));
167 }
168
169 if (!first.IsTypeSpaceMatch(second))
170 {
171 throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
172 " must have the same quantization space, " +
173 firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
174 " and scale " + to_string(first.GetQuantizationScale()) + ", " +
175 secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
176 " and scale " + to_string(second.GetQuantizationScale()));
177 }
178 }
179
180 //---------------------------------------------------------------
// Checks a bias tensor's quantization parameters against those of the input
// and weight tensors: the bias offset must be zero (fatal), and each bias
// scale should equal inputScale * weightScale (a mismatch only logs a
// warning and the provided scale is used as-is). Handles both per-tensor
// and per-axis quantized weights/biases.
void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
                                    const TensorInfo& inputTensorInfo,
                                    const TensorInfo& weightsTensorInfo,
                                    const std::string& descName)
{
    // Helper lambda function to validate a single bias quantization scale value
    auto VerifyBiasQuantizationScale = [&descName](float biasScale, float expectedScale) -> void
    {
        constexpr float tolerance = 0.0001f;
        if (std::abs(biasScale - expectedScale) > tolerance)
        {
            // Print the float values with extra precision to see very small differences
            ARMNN_LOG(warning) << std::setprecision(6) << descName << ": Expected " << expectedScale <<
                " for bias quantization scale (product of input and weight scales), but got " <<
                biasScale << ". Using scale provided.";
        }
    };

    // A non-zero bias offset is always an error, regardless of scheme.
    if (biasTensor.GetQuantizationOffset() != 0)
    {
        throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
                                       to_string(biasTensor.GetQuantizationOffset()));
    }

    if (biasTensor.HasMultipleQuantizationScales() || weightsTensorInfo.HasMultipleQuantizationScales())
    {
        // Validate per-axis quantization scales
        const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
        const std::vector<float>& biasScales = biasTensor.GetQuantizationScales();

        // Weights and bias must provide one scale per axis, in equal number.
        if (weightScales.size() != biasScales.size())
        {
            std::stringstream msg;
            msg << descName << ": Expected matching number of per-axis quantization scales for weights and bias, "
                << "but got different values. This is currently unsupported: weights=" << weightScales.size()
                << ", biases=" << biasScales.size();
            throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
        }

        // Each axis' bias scale is checked against input scale * that axis' weight scale.
        for (size_t i = 0ul; i < biasScales.size(); ++i)
        {
            const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightScales[i];
            VerifyBiasQuantizationScale(biasScales[i], expectedScale);
        }
    }
    else
    {
        // Validate per-tensor quantization scale
        const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
        VerifyBiasQuantizationScale(biasTensor.GetQuantizationScale(), expectedScale);
    }
}
233
234 //---------------------------------------------------------------
ValidateTensors(const std::vector<ITensorHandle * > & vec,unsigned int numExpected,const std::string & descName,const std::string & varName)235 void ValidateTensors(const std::vector<ITensorHandle*>& vec,
236 unsigned int numExpected,
237 const std::string& descName,
238 const std::string& varName)
239 {
240 if (vec.empty() && numExpected > 0)
241 {
242 throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
243 }
244
245 for (unsigned int i = 0; i < numExpected; ++i)
246 {
247 if (!vec[i])
248 {
249 throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
250 }
251 }
252 }
253
254 //---------------------------------------------------------------
ValidateBroadcastTensorShapesMatch(const TensorInfo & first,const TensorInfo & second,const TensorInfo & output,std::string const & descName,std::string const & firstName,std::string const & secondName)255 void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
256 const TensorInfo& second,
257 const TensorInfo& output,
258 std::string const& descName,
259 std::string const& firstName,
260 std::string const& secondName)
261 {
262 // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
263 // broadcasted.
264 if (first.GetNumDimensions() != second.GetNumDimensions())
265 {
266 throw InvalidArgumentException(descName + ": Tensors "
267 + firstName + " & " + secondName
268 + " must have the same number of dimensions in order to be broadcasted");
269 }
270 uint32_t numDims = first.GetNumDimensions();
271 std::vector<uint32_t> outputDims(numDims, 0u);
272 for (uint32_t i = 0; i < numDims; i++)
273 {
274 const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
275 const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
276 if (dimsNotEqual && dimsNotOne)
277 {
278 throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
279 }
280 outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
281 }
282 TensorShape broadcastShape = TensorShape(armnn::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
283 if (broadcastShape != output.GetShape())
284 {
285 throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
286 + firstName + " & " + secondName
287 + " does not match the output shape");
288 }
289 }
290
291 //---------------------------------------------------------------
ValidateDataTypes(const TensorInfo & info,const std::vector<armnn::DataType> & supportedTypes,std::string const & descName)292 void ValidateDataTypes(const TensorInfo& info,
293 const std::vector<armnn::DataType>& supportedTypes,
294 std::string const& descName)
295 {
296 auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
297 if (iterator == supportedTypes.end())
298 {
299 throw InvalidArgumentException(descName + ": " + " Tensor type is not supported.");
300 }
301 }
302
303 //---------------------------------------------------------------
ValidateTensorDataTypesMatch(const TensorInfo & first,const TensorInfo & second,std::string const & descName,std::string const & firstName,std::string const & secondName)304 void ValidateTensorDataTypesMatch(const TensorInfo& first,
305 const TensorInfo& second,
306 std::string const& descName,
307 std::string const& firstName,
308 std::string const& secondName)
309 {
310 if (first.GetDataType() != second.GetDataType())
311 {
312 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
313 " must have identical data types.");
314 }
315 }
316
317 //---------------------------------------------------------------
ValidateTensorNumElementsMatch(const TensorInfo & first,const TensorInfo & second,std::string const & descName,std::string const & firstName,std::string const & secondName)318 void ValidateTensorNumElementsMatch(const TensorInfo& first,
319 const TensorInfo& second,
320 std::string const& descName,
321 std::string const& firstName,
322 std::string const& secondName)
323 {
324 if (first.GetNumElements() != second.GetNumElements())
325 {
326 throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
327 " must have the same number of elements.");
328 }
329 }
330
ValidateWeightDataType(const TensorInfo & inputInfo,const TensorInfo & weightInfo,const std::string & descName)331 void ValidateWeightDataType(const TensorInfo& inputInfo,
332 const TensorInfo& weightInfo,
333 const std::string& descName)
334 {
335 const DataType inputType = inputInfo.GetDataType();
336 if (IsQuantized8BitType(inputType))
337 {
338 const std::vector<DataType> validTypes =
339 {
340 DataType::QAsymmS8,
341 DataType::QAsymmU8,
342 DataType::QSymmS8
343 };
344
345 ValidateDataTypes(weightInfo, validTypes, descName);
346 }
347 else
348 {
349 ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
350 }
351 }
352
ValidatePerAxisQuantizationDimension(const TensorInfo & tensorInfo,const std::string & descName,const std::string & tensorName)353 void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
354 const std::string& descName,
355 const std::string& tensorName)
356 {
357 const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
358 if (!quantizationDim.has_value())
359 {
360 throw InvalidArgumentException(fmt::format("{0}: Quantization dimension for per-axis quantization "
361 "not set on tensor {1}.", descName, tensorName));
362 }
363 }
364
ValidatePerAxisQuantizationOffset(const TensorInfo & tensorInfo,const std::string & descName,const std::string & tensorName)365 void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
366 const std::string& descName,
367 const std::string& tensorName)
368 {
369 int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
370 if (quantizationOffset != 0)
371 {
372 throw InvalidArgumentException(fmt::format(
373 "{0}: Quantization offset for per-axis quantization expected to be 0 on tensor {1}, but got: {2}",
374 descName, tensorName, quantizationOffset));
375 }
376 }
377
// Validates the per-axis quantization setup of a weight tensor and, when
// supplied, its bias: per-axis is only allowed for 8-bit quantized layers
// whose input and output types match; the weight must be QSymmS8 with a
// quantization dimension and zero offset; a bias must be Signed32 and
// per-axis quantized with the same constraints. No-op when the weight is
// not per-axis quantized.
void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
                                 const TensorInfo& outputInfo,
                                 const TensorInfo& weightInfo,
                                 const Optional<TensorInfo>& optionalBiasInfo,
                                 const std::string& descName)
{
    if (weightInfo.HasPerAxisQuantization())
    {
        const DataType inputDataType  = inputInfo.GetDataType();
        const DataType outputDataType = outputInfo.GetDataType();

        // Per-axis parameters only make sense for 8-bit quantized input with a
        // matching output type.
        const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;

        if (!canHavePerAxisQuantization)
        {
            throw InvalidArgumentException(fmt::format(
                "{0}: Per-axis quantization parameters set on tensor {1}, but data type does not support "
                "per-axis quantization.", descName, "weight"));
        }

        // Weight tensor: must be QSymmS8, declare its quantization dimension,
        // and have a zero offset.
        ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
        ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
        ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");

        if (optionalBiasInfo.has_value())
        {
            // A bias accompanying per-axis weights must itself be per-axis
            // quantized.
            const TensorInfo& biasInfo = optionalBiasInfo.value();
            if (!biasInfo.HasPerAxisQuantization())
            {
                throw InvalidArgumentException(fmt::format(
                    "{}: Per-axis quantization parameters not set on bias tensor, "
                    "despite being set on weight tensor.", descName));
            }

            ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
            ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
            ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
        }
    }
}
419
420 } // anonymous namespace
421
422 //---------------------------------------------------------------
// Checks the dimension count of `tensor`. With m_AllowExpandedDims set,
// `numDimensions` is a lower bound on the raw dimension count and an upper
// bound on the number of non-1 ("squeezed") dimensions; otherwise the raw
// dimension count must match exactly.
void QueueDescriptor::ValidateTensorNumDimensions(const TensorInfo& tensor,
                                                  std::string const& descName,
                                                  unsigned int numDimensions,
                                                  std::string const& tensorName) const
{
    // If we're allowing expanded dimensions then numDimensions becomes the minimum number of Dimensions we can allow.
    // Throw an Exception if the tensors has fewer than numDimensions or if the squeezed dimensions are greater than
    // numDimensions.
    if (m_AllowExpandedDims)
    {
        // Count the dimensions whose extent is not 1.
        unsigned int squeezedDims = 0;

        for (unsigned int i = 0; i < tensor.GetNumDimensions(); ++i)
        {
            if (tensor.GetShape()[i] != 1)
            {
                ++squeezedDims;
            }
        }
        if (tensor.GetNumDimensions() < numDimensions || squeezedDims > numDimensions)
        {
            throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " or less but got " +
                                           to_string(tensor.GetNumDimensions()) + " dimensions for " +
                                           tensorName + " tensor.");
        }
    }
    else
    {
        // Strict mode: exact dimension count required.
        if (tensor.GetNumDimensions() != numDimensions)
        {
            throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
                                           to_string(tensor.GetNumDimensions()) + " dimensions for " +
                                           tensorName + " tensor.");
        }
    }
}
459
460 //---------------------------------------------------------------
ValidateTensorNumDimNumElem(const TensorInfo & tensorInfo,unsigned int numDimension,unsigned int numElements,std::string const & tensorName) const461 void QueueDescriptor::ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
462 unsigned int numDimension,
463 unsigned int numElements,
464 std::string const& tensorName) const
465 {
466 const std::string functionName{"ValidateTensorNumDimNumElem"};
467 ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
468 ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
469 }
470
471 //---------------------------------------------------------------
//---------------------------------------------------------------
// Checks that the first `numExpectedIn` input handles and the first
// `numExpectedOut` output handles exist and are non-null (inputs are
// validated before outputs).
void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
                                            unsigned int numExpectedIn, unsigned int numExpectedOut) const
{
    ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
    ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
}
478
479 //---------------------------------------------------------------
Validate(const WorkloadInfo & workloadInfo) const480 void MapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
481 {
482 const std::string descriptorName{"MapQueueDescriptor"};
483
484 ValidateNumInputs(workloadInfo, descriptorName, 1);
485 ValidateNumOutputs(workloadInfo, descriptorName, 0);
486
487 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
488 {
489 if (!m_Inputs[i])
490 {
491 throw InvalidArgumentException(
492 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
493 }
494 }
495 }
496
497 //---------------------------------------------------------------
Validate(const WorkloadInfo & workloadInfo) const498 void UnmapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
499 {
500 const std::string descriptorName{"UnmapQueueDescriptor"};
501
502 ValidateNumInputs(workloadInfo, descriptorName, 1);
503 ValidateNumOutputs(workloadInfo, descriptorName, 0);
504
505 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
506 {
507 if (!m_Inputs[i])
508 {
509 throw InvalidArgumentException(
510 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
511 }
512 }
513 }
514
515 //---------------------------------------------------------------
Validate(const WorkloadInfo & workloadInfo) const516 void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
517 {
518 const std::string descriptorName{"MemCopyQueueDescriptor"};
519
520 ValidateNumInputs(workloadInfo, descriptorName, 1);
521 ValidateNumOutputs(workloadInfo, descriptorName , 1);
522
523 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
524 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
525
526 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
527 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
528
529 if (m_Inputs.size() != m_Outputs.size())
530 {
531 throw InvalidArgumentException(fmt::format(
532 "{0}: Number of inputs ({1}) does not match the number of outputs ({2}).",
533 descriptorName, m_Inputs.size(), m_Outputs.size()));
534 }
535
536 for (unsigned int i = 0; i < m_Inputs.size(); ++i)
537 {
538 if (!m_Inputs[i])
539 {
540 throw InvalidArgumentException(fmt::format(
541 "{0}: Invalid NULL input {1}.", descriptorName, i));
542 }
543
544 if (!m_Outputs[i])
545 {
546 throw InvalidArgumentException(fmt::format("{0}: Invalid NULL output {1}", descriptorName, i));
547 }
548 }
549 }
550
551 //---------------------------------------------------------------
//---------------------------------------------------------------
// MemImport workloads: exactly one input info and handle, equally many
// input and output infos/handles, matching element counts per pair, and no
// null handles.
void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
    ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);

    if (workloadInfo.m_InputTensorInfos.size() != 1)
    {
        throw InvalidArgumentException(fmt::format("Number of input infos ({}) is not 1.",
                                                   workloadInfo.m_InputTensorInfos.size()));

    }

    if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
    {
        throw InvalidArgumentException(fmt::format(
            "Number of input infos ({0}) does not match the number of output infos ({1})",
            workloadInfo.m_InputTensorInfos.size(), workloadInfo.m_OutputTensorInfos.size()));
    }

    // Each output must describe exactly as many elements as its
    // corresponding input.
    for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
    {
        if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
            workloadInfo.m_OutputTensorInfos[i].GetNumElements())
        {
            throw InvalidArgumentException(fmt::format(
                "Number of elements for tensor input and output {} does not match", i ));
        }
    }

    if (m_Inputs.size() != 1)
    {
        throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
    }

    if (m_Inputs.size() != m_Outputs.size())
    {
        throw InvalidArgumentException(fmt::format(
            "Number of inputs ({0}) does not match the number of outputs ({1})",
            m_Inputs.size(), m_Outputs.size()));
    }

    // All tensor handles must be non-null.
    for (unsigned int i = 0; i < m_Inputs.size(); ++i)
    {
        if (!m_Inputs[i])
        {
            throw InvalidArgumentException(fmt::format("Invalid null input {}", i));
        }

        if (!m_Outputs[i])
        {
            throw InvalidArgumentException(fmt::format("Invalid null output {}", i));
        }
    }
}
606
607 //---------------------------------------------------------------
Validate(const WorkloadInfo & workloadInfo) const608 void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
609 {
610 ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
611
612 if (m_Inputs.size() != 1)
613 {
614 throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
615 }
616
617 if (m_Outputs.size() != 0)
618 {
619 throw InvalidArgumentException(fmt::format("Number of outputs ({}) is not 0.", m_Outputs.size()));
620 }
621
622 if (!m_Inputs[0])
623 {
624 throw InvalidArgumentException(fmt::format("Invalid null input 0"));
625 }
626 }
627
628 //---------------------------------------------------------------
Validate(const WorkloadInfo & workloadInfo) const629 void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
630 {
631 const std::string descriptorName{"ActivationQueueDescriptor"};
632
633 ValidateNumInputs(workloadInfo, descriptorName, 1);
634 ValidateNumOutputs(workloadInfo, descriptorName, 1);
635
636 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
637 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
638
639 std::vector<DataType> supportedTypes =
640 {
641 DataType::BFloat16,
642 DataType::Float16,
643 DataType::Float32,
644 DataType::QAsymmS8,
645 DataType::QAsymmU8,
646 DataType::QSymmS16
647 };
648
649 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
650 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
651 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
652 }
653
// ArgMinMax workloads: one input of a supported type, a Signed32/Signed64
// output (indices), and an output shape equal to the input shape with the
// reduced axis removed (a 1D input reduces to a scalar-like output).
void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    const std::string descriptorName{"ArgMinMaxQueueDescriptor"};

    ValidateNumInputs(workloadInfo, descriptorName, 1);
    ValidateNumOutputs(workloadInfo, descriptorName, 1);

    const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
    const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];

    // The output holds element indices, so only integral types are accepted.
    if (outputTensorInfo.GetDataType() != DataType::Signed32 &&
        outputTensorInfo.GetDataType() != DataType::Signed64)
    {
        throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32 or Int64.");
    }

    std::vector<DataType> supportedInputTypes =
    {
        DataType::BFloat16,
        DataType::Float16,
        DataType::Float32,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::Signed32,
        DataType::Signed64
    };

    ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);

    auto inputShape = inputTensorInfo.GetShape();
    auto outputShape = outputTensorInfo.GetShape();

    auto inputNumDimensions = inputShape.GetNumDimensions();
    // Normalises a (possibly negative) axis into [0, inputNumDimensions).
    auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);

    const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};

    // 1D input shape results in scalar output shape
    if (inputShape.GetNumDimensions() == 1)
    {
        // NOTE(review): because this uses &&, a multi-dimensional output is
        // accepted whenever its first extent is 1 — confirm whether || was
        // intended here.
        if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
        {
            throw InvalidArgumentException(descriptorName + outputShapeError);
        }
    }
    else
    {
        // Dimensions before the reduced axis must match the input exactly...
        for (unsigned int i = 0; i < unsignedAxis; ++i)
        {
            if (outputShape[i] != inputShape[i])
            {
                throw InvalidArgumentException(descriptorName + outputShapeError);
            }
        }

        // ...and dimensions after the axis must match, shifted down by one
        // (the axis dimension is dropped from the output).
        for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
        {
            if (outputShape[i - 1] != inputShape[i])
            {
                throw InvalidArgumentException(descriptorName + outputShapeError);
            }
        }
    }
}
719
Validate(const WorkloadInfo & workloadInfo) const720 void CastQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
721 {
722 const std::string descriptorName{"CastQueueDescriptor"};
723
724 ValidateNumInputs(workloadInfo, descriptorName, 1);
725 ValidateNumOutputs(workloadInfo, descriptorName, 1);
726
727 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
728 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
729
730 std::vector<DataType> supportedTypes =
731 {
732 DataType::BFloat16,
733 DataType::Float16,
734 DataType::Float32,
735 DataType::QAsymmS8,
736 DataType::QAsymmU8,
737 DataType::QSymmS8,
738 DataType::QSymmS16,
739 DataType::Signed32,
740 DataType::Signed64
741 };
742
743 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
744 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
745 }
746
Validate(const WorkloadInfo & workloadInfo) const747 void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
748 {
749 const std::string descriptorName{"SoftmaxQueueDescriptor"};
750
751 ValidateNumInputs(workloadInfo, descriptorName, 1);
752 ValidateNumOutputs(workloadInfo, descriptorName, 1);
753
754 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
755 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
756
757 std::vector<DataType> supportedTypes =
758 {
759 DataType::BFloat16,
760 DataType::Float16,
761 DataType::Float32,
762 DataType::QAsymmS8,
763 DataType::QAsymmU8,
764 DataType::QSymmS16
765 };
766
767 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
768 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
769 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
770 }
771
// Splitter workloads: one input, at least one output of a supported type
// matching the input's data type, one view origin per output, and every
// output view (origin + extent) fitting inside the input tensor.
void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    const std::string descriptorName{"SplitterQueueDescriptor"};

    ValidateNumInputs(workloadInfo, descriptorName, 1);

    // Check the supported data types
    std::vector<DataType> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::Boolean,
        DataType::Signed32,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    // Every output must be a supported type and match the input's data type.
    const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
    for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
    {
        const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
        ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);

        const std::string outputName = "output_" + std::to_string(i);
        ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
    }

    if (workloadInfo.m_OutputTensorInfos.size() <= 0)
    {
        throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
    }

    // Each output view needs exactly one window origin locating it in the input.
    if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
    {
        throw InvalidArgumentException(
            descriptorName + ": Number of split windows "
            "has to match number of workloadInfo.m_OutputTensorInfos. "
            "Number of windows: " +
            to_string(m_ViewOrigins.size()) +
            ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
    }

    //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
    std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
    for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
    {
        //Checks that the dimensionality of input is same as the split windows.
        ViewOrigin const& e = m_ViewOrigins[w];
        if (e.m_Origin.size() != inputDims)
        {
            throw InvalidArgumentException(descName + ": Window origin have to "
                                           "have the same dimensionality as the input tensor. "
                                           "Window origin (index: " +
                                           to_string(w) + ") has " + to_string(e.m_Origin.size()) +
                                           " dimensions, the input "
                                           "tensor has " +
                                           to_string(inputDims) + " dimensions.");
        }
        // Per dimension, origin + output extent must not exceed the input extent.
        for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
        {
            if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
                workloadInfo.m_InputTensorInfos[0].GetShape()[i])
            {
                throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
                                               "be smaller or equal than the size of the input in that coord.");
            }
        }
    }
}
843
Validate(const WorkloadInfo & workloadInfo) const844 void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
845 {
846 const std::string descriptorName{"ConcatQueueDescriptor"};
847
848 ValidateNumOutputs(workloadInfo, descriptorName, 1);
849
850 if (m_Inputs.size() <= 0)
851 {
852 throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
853 }
854 if (m_Outputs.size() <= 0)
855 {
856 throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
857 }
858
859 if (workloadInfo.m_InputTensorInfos.size() <= 0)
860 {
861 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
862 }
863 if (workloadInfo.m_OutputTensorInfos.size() <= 0)
864 {
865 throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
866 }
867
868 if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
869 {
870 throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
871 }
872
873 if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
874 {
875 return;
876 }
877
878 if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
879 {
880 throw InvalidArgumentException(
881 descriptorName + ": Number of split windows "
882 "has to match number of workloadInfo.m_InputTensorInfos. "
883 "Number of windows: " +
884 to_string(m_ViewOrigins.size()) +
885 ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
886 }
887
888 //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
889 std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
890 for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
891 {
892 //Checks that the dimensionality of output is same as the split windows.
893 ViewOrigin const& e = m_ViewOrigins[w];
894 if (e.m_Origin.size() != outputDims)
895 {
896 throw InvalidArgumentException(descriptorName + ": Window origin have to "
897 "have the same dimensionality as the output tensor. "
898 "Window origin (index: " +
899 to_string(w) + ") has " + to_string(e.m_Origin.size()) +
900 " dimensions, the output "
901 "tensor has " +
902 to_string(outputDims) + " dimensions.");
903 }
904 //Checks that the merge windows are within the output tensor.
905 for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
906 {
907 if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
908 > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
909 {
910 throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
911 "be smaller or equal than the size of the output in that coord.");
912 }
913 }
914 }
915
916 // Check the supported data types
917 std::vector<DataType> supportedTypes =
918 {
919 DataType::BFloat16,
920 DataType::Float32,
921 DataType::Float16,
922 DataType::Boolean,
923 DataType::Signed32,
924 DataType::QAsymmS8,
925 DataType::QAsymmU8,
926 DataType::QSymmS16
927 };
928
929 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
930 for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
931 {
932 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
933 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
934
935 const std::string inputName = "input_" + std::to_string(i);
936 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
937 }
938 }
939
// Validates a Stack workload: m_NumInputs tensors of identical shape
// (m_InputShape) are stacked along m_Axis, producing one output tensor with
// one extra dimension. Throws InvalidArgumentException on any violation.
void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    const std::string descriptorName{"StackQueueDescriptor"};

    ValidateNumOutputs(workloadInfo, descriptorName, 1);

    // The descriptor declares how many inputs there must be.
    if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
    {
        throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
    }

    // All inputs must have the same shape, which is defined in parameters
    const TensorShape& inputShape = m_Parameters.m_InputShape;
    for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
    {
        if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
        {
            throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
        }
    }

    // Inputs are capped at 4 dimensions, so the output (one extra dimension)
    // is capped at 5 - see the matching check at the bottom.
    if (inputShape.GetNumDimensions() > 4)
    {
        throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
    }

    // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
    // since the output tensor has an additional dimension.
    if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
    {
        throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
                                       "than the number of input dimensions.");
    }

    // Output shape must be as inferred from the input shape
    const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
    // Dimensions before the stack axis are carried over from the input unchanged.
    for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
    {
        if (outputShape[i] != inputShape[i])
        {
            throw InvalidArgumentException(descriptorName + ": Output tensor must "
                                           "match shape inferred from input tensor.");
        }
    }

    // The stack axis itself must have extent equal to the number of inputs.
    if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
    {
        throw InvalidArgumentException(descriptorName + ": Output tensor must "
                                       "match shape inferred from input tensor.");
    }

    // Dimensions after the stack axis are the remaining input dimensions,
    // shifted up by one.
    for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
    {
        if (outputShape[i] != inputShape[i-1])
        {
            throw InvalidArgumentException(descriptorName + ": Output tensor must "
                                           "match shape inferred from input tensor.");
        }
    }

    if (outputShape.GetNumDimensions() > 5)
    {
        throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
    }

    // Check the supported data types
    std::vector<DataType> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::Boolean,
        DataType::Signed32,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);

    // All inputs must share the data type of input_0 ...
    for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
    {
        ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
                                     workloadInfo.m_InputTensorInfos[i],
                                     descriptorName,
                                     "input_0",
                                     "input_" + std::to_string(i));
    }

    // ... and so must the output.
    ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
                                 workloadInfo.m_OutputTensorInfos[0],
                                 descriptorName,
                                 "input_0",
                                 "output");
}
1035
Validate(const WorkloadInfo & workloadInfo) const1036 void FillQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1037 {
1038 const std::string descriptorName{"FillQueueDescriptor"};
1039
1040 ValidateNumInputs(workloadInfo, descriptorName, 1);
1041 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1042
1043 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1044 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1045
1046 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 1, "input");
1047
1048 std::vector<DataType> supportedTypes =
1049 {
1050 DataType::BFloat16,
1051 DataType::Float32,
1052 DataType::Float16,
1053 DataType::Signed32
1054 };
1055
1056 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1057 }
1058
Validate(const WorkloadInfo & workloadInfo) const1059 void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1060 {
1061 const std::string descriptorName{"FullyConnectedQueueDescriptor"};
1062
1063 uint32_t numInputs = 2;
1064 if (m_Parameters.m_BiasEnabled)
1065 {
1066 numInputs = 3;
1067 }
1068
1069 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
1070 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1071
1072 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1073 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1074
1075 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1076
1077 if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
1078 {
1079 throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
1080 }
1081
1082 TensorInfo weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
1083 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
1084
1085 if (m_Parameters.m_BiasEnabled)
1086 {
1087 TensorInfo biasTensorInfo = workloadInfo.m_InputTensorInfos[2];
1088 // Validates type and quantization values.
1089 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1090 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1091 ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
1092 }
1093
1094 // Check the supported data types
1095 std::vector<DataType> supportedTypes =
1096 {
1097 DataType::BFloat16,
1098 DataType::Float32,
1099 DataType::Float16,
1100 DataType::QAsymmS8,
1101 DataType::QAsymmU8,
1102 DataType::QSymmS16
1103 };
1104
1105 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1106
1107 // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1108 if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1109 {
1110 if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1111 {
1112 throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
1113 "for BFloat16 input.");
1114 }
1115 }
1116 else
1117 {
1118 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1119 }
1120 }
1121
Validate(const WorkloadInfo & workloadInfo) const1122 void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1123 {
1124 const std::string descriptorName{"NormalizationQueueDescriptor"};
1125
1126 ValidateNumInputs(workloadInfo, descriptorName, 1);
1127 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1128
1129 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1130 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1131
1132 // Check the supported data types
1133 std::vector<DataType> supportedTypes =
1134 {
1135 DataType::BFloat16,
1136 DataType::Float16,
1137 DataType::Float32,
1138 DataType::QAsymmS8,
1139 DataType::QAsymmU8,
1140 DataType::QSymmS16
1141 };
1142
1143 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1144
1145 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1146
1147 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1148 }
1149
Validate(const WorkloadInfo & workloadInfo) const1150 void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1151 {
1152 const std::string descriptorName{"AdditionQueueDescriptor"};
1153
1154 ValidateNumInputs(workloadInfo, descriptorName, 2);
1155 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1156
1157 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1158 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1159 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1160
1161 std::vector<DataType> supportedTypes =
1162 {
1163 DataType::BFloat16,
1164 DataType::Float32,
1165 DataType::Float16,
1166 DataType::QAsymmS8,
1167 DataType::QAsymmU8,
1168 DataType::QSymmS16,
1169 DataType::Signed32
1170 };
1171
1172 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1173 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1174 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1175
1176 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1177 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
1178
1179 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1180 inputTensorInfo1,
1181 outputTensorInfo,
1182 descriptorName,
1183 "input_0",
1184 "input_1");
1185 }
1186
Validate(const WorkloadInfo & workloadInfo) const1187 void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1188 {
1189 const std::string descriptorName{"MultiplicationQueueDescriptor"};
1190
1191 ValidateNumInputs(workloadInfo, descriptorName, 2);
1192 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1193
1194 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1195 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1196 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1197
1198 std::vector<DataType> supportedTypes =
1199 {
1200 DataType::BFloat16,
1201 DataType::Float16,
1202 DataType::Float32,
1203 DataType::QAsymmS8,
1204 DataType::QAsymmU8,
1205 DataType::QSymmS16,
1206 DataType::Signed32
1207 };
1208
1209 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1210 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1211 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1212
1213 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1214 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
1215
1216 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1217 inputTensorInfo1,
1218 outputTensorInfo,
1219 descriptorName,
1220 "input_0",
1221 "input_1");
1222 }
1223
Validate(const WorkloadInfo & workloadInfo) const1224 void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1225 {
1226 const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
1227
1228 ValidateNumInputs(workloadInfo, descriptorName, 1);
1229 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1230
1231 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1232 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1233
1234 std::vector<DataType> supportedTypes =
1235 {
1236 DataType::BFloat16,
1237 DataType::Float16,
1238 DataType::Float32,
1239 DataType::QAsymmS8,
1240 DataType::QAsymmU8,
1241 DataType::QSymmS16
1242 };
1243
1244 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1245 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1246
1247 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1248 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1249
1250 ValidatePointer(m_Mean, descriptorName, "mean");
1251 ValidatePointer(m_Variance, descriptorName, "variance");
1252 ValidatePointer(m_Beta, descriptorName, "beta");
1253 ValidatePointer(m_Gamma, descriptorName, "gamma");
1254
1255 const TensorInfo& mean = m_Mean->GetTensorInfo();
1256 const TensorInfo& variance = m_Variance->GetTensorInfo();
1257 const TensorInfo& beta = m_Beta->GetTensorInfo();
1258 const TensorInfo& gamma = m_Gamma->GetTensorInfo();
1259
1260 ValidateTensorNumDimensions(mean, descriptorName, 1, "mean");
1261 ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1262 ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
1263 ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
1264
1265 ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1266 ValidateTensorShapesMatch(mean, beta, descriptorName, "mean", "beta");
1267 ValidateTensorShapesMatch(mean, gamma, descriptorName, "mean", "gamma");
1268 }
1269
// Validates a Convolution2d workload: input [0], 4-D weights [1] and an
// optional bias [2]. Checks stride parameters, per-axis quantization, and a
// BFloat16->Float32 output special case. Throws InvalidArgumentException on
// any violation.
void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    const std::string descriptorName{"Convolution2dQueueDescriptor"};

    // Inputs: [0] data, [1] weights, and [2] bias when m_BiasEnabled.
    uint32_t numInputs = 2;
    if (m_Parameters.m_BiasEnabled)
    {
        numInputs = 3;
    }

    ValidateNumInputs(workloadInfo, descriptorName, numInputs);
    ValidateNumOutputs(workloadInfo, descriptorName, 1);

    const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
    const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];

    ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
    ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");

    const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];

    ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");

    ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);

    // The bias is kept in an Optional so it can also be fed to the per-axis
    // quantization check below even when disabled.
    Optional<TensorInfo> optionalBiasTensorInfo;
    if (m_Parameters.m_BiasEnabled)
    {
        optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
        const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();

        ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
        ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
    }

    // Strides must be strictly positive.
    if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
    {
        throw InvalidArgumentException(
            fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
                        "cannot be either negative or 0.",
                        descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
    }

    ValidatePerAxisQuantization(inputTensorInfo,
                                outputTensorInfo,
                                weightTensorInfo,
                                optionalBiasTensorInfo,
                                descriptorName);

    std::vector<DataType> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float16,
        DataType::Float32,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::QSymmS8
    };

    ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);

    // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
    if (inputTensorInfo.GetDataType() == DataType::BFloat16)
    {
        if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
        {
            throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
                                           "for BFloat16 input.");
        }
    }
    else
    {
        ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
    }
}
1346
Validate(const WorkloadInfo & workloadInfo) const1347 void Convolution3dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1348 {
1349 const std::string descriptorName{"Convolution3dQueueDescriptor"};
1350
1351 uint32_t numInputs = 2;
1352 if (m_Parameters.m_BiasEnabled)
1353 {
1354 numInputs = 3;
1355 }
1356 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
1357 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1358
1359 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1360 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1361
1362 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 5, "input");
1363 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 5, "output");
1364
1365 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
1366 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 5, "weight");
1367
1368 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
1369
1370 Optional<TensorInfo> optionalBiasTensorInfo;
1371 if (m_Parameters.m_BiasEnabled)
1372 {
1373 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
1374 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
1375
1376 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1377 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1378 }
1379
1380 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 || m_Parameters.m_StrideZ <= 0 )
1381 {
1382 throw InvalidArgumentException(
1383 fmt::format("{}: strideX (provided {}), strideY (provided {}) or strideZ (provided {})"
1384 "cannot be either negative or 0.",
1385 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY, m_Parameters.m_StrideZ));
1386 }
1387
1388 ValidatePerAxisQuantization(inputTensorInfo,
1389 outputTensorInfo,
1390 weightTensorInfo,
1391 optionalBiasTensorInfo,
1392 descriptorName);
1393
1394 std::vector<DataType> supportedTypes =
1395 {
1396 DataType::BFloat16,
1397 DataType::Float16,
1398 DataType::Float32,
1399 DataType::QAsymmS8,
1400 DataType::QAsymmU8,
1401 DataType::QSymmS16,
1402 DataType::QSymmS8
1403 };
1404
1405 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1406 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1407 }
1408
Validate(const WorkloadInfo & workloadInfo) const1409 void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1410 {
1411 const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1412
1413 uint32_t numInputs = 2;
1414 if (m_Parameters.m_BiasEnabled)
1415 {
1416 numInputs = 3;
1417 }
1418
1419 ValidateNumInputs(workloadInfo, descriptorName, numInputs);
1420 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1421
1422 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1423 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1424
1425 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1426 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1427
1428 const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
1429 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1430
1431 if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1432 {
1433 throw InvalidArgumentException(
1434 fmt::format("{}: dilationX (provided {}) and dilationY (provided {}) "
1435 "cannot be smaller than 1.",
1436 descriptorName, m_Parameters.m_DilationX, m_Parameters.m_DilationX));
1437 }
1438
1439 if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 )
1440 {
1441 throw InvalidArgumentException(
1442 fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1443 "cannot be either negative or 0.",
1444 descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1445 }
1446
1447 if (weightTensorInfo.GetShape()[0] != 1)
1448 {
1449 throw InvalidArgumentException(fmt::format(
1450 "{0}: The weight format in armnn is expected to be [1, H, W, Cout]."
1451 "But first dimension is not equal to 1. Provided weight shape: [{1}, {2}, {3}, {4}]",
1452 descriptorName,
1453 weightTensorInfo.GetShape()[0],
1454 weightTensorInfo.GetShape()[1],
1455 weightTensorInfo.GetShape()[2],
1456 weightTensorInfo.GetShape()[3]));
1457 }
1458
1459 const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1460 const unsigned int numWeightOutputChannelsRefFormat = weightTensorInfo.GetShape()[3];
1461 const unsigned int numWeightOutputChannelsAclFormat = weightTensorInfo.GetShape()[1];
1462 const unsigned int numOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1463
1464 // Weights format has two valid options: [1, H, W, Cout] (CpuRef) or [1, Cout, H, W] (CpuAcc/GpuAcc).
1465 bool validRefFormat = (numWeightOutputChannelsRefFormat == numOutputChannels);
1466 bool validAclFormat = (numWeightOutputChannelsAclFormat == numOutputChannels);
1467
1468 if (!(validRefFormat || validAclFormat))
1469 {
1470 throw InvalidArgumentException(fmt::format(
1471 "{0}: The weight format in armnn is expected to be [1, H, W, Cout] (CpuRef) or [1, Cout, H, W] "
1472 "(CpuAcc/GpuAcc). But neither the 4th (CpuRef) or 2nd (CpuAcc/GpuAcc) dimension is equal to Cout."
1473 "Cout = {1} Provided weight shape: [{2}, {3}, {4}, {5}]",
1474 descriptorName,
1475 numOutputChannels,
1476 weightTensorInfo.GetShape()[0],
1477 weightTensorInfo.GetShape()[1],
1478 weightTensorInfo.GetShape()[2],
1479 weightTensorInfo.GetShape()[3]));
1480 }
1481
1482 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
1483
1484 Optional<TensorInfo> optionalBiasTensorInfo;
1485 if (m_Parameters.m_BiasEnabled)
1486 {
1487 optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
1488 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
1489
1490 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1491 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1492 }
1493 ValidatePerAxisQuantization(inputTensorInfo,
1494 outputTensorInfo,
1495 weightTensorInfo,
1496 optionalBiasTensorInfo,
1497 descriptorName);
1498
1499 std::vector<DataType> supportedTypes =
1500 {
1501 DataType::BFloat16,
1502 DataType::Float16,
1503 DataType::Float32,
1504 DataType::QAsymmS8,
1505 DataType::QAsymmU8,
1506 DataType::QSymmS16
1507 };
1508
1509 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1510 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1511 }
1512
Validate(const WorkloadInfo & workloadInfo) const1513 void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1514 {
1515 const std::string descriptorName{"PermuteQueueDescriptor"};
1516
1517 ValidateNumInputs(workloadInfo, descriptorName, 1);
1518 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1519
1520 const PermutationVector& mapping = m_Parameters.m_DimMappings;
1521
1522 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1523 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1524
1525 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
1526 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
1527
1528 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
1529 {
1530 if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
1531 {
1532 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1533 " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1534 "must match dst dimension " + to_string(mapping[i]) +
1535 " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
1536 }
1537 }
1538
1539 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1540 }
1541
Validate(const WorkloadInfo & workloadInfo) const1542 void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1543 {
1544 const std::string descriptorName{"Pooling2dQueueDescriptor"};
1545
1546 ValidateNumInputs(workloadInfo, descriptorName, 1);
1547 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1548
1549 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1550 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1551
1552 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1553 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1554
1555 std::vector<DataType> supportedTypes =
1556 {
1557 DataType::BFloat16,
1558 DataType::Float32,
1559 DataType::Float16,
1560 DataType::QAsymmS8,
1561 DataType::QAsymmU8,
1562 DataType::QSymmS16
1563 };
1564
1565 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1566 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1567 }
1568
Validate(const WorkloadInfo & workloadInfo) const1569 void Pooling3dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1570 {
1571 const std::string descriptorName{"Pooling3dQueueDescriptor"};
1572
1573 ValidateNumInputs(workloadInfo, descriptorName, 1);
1574 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1575
1576 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1577 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1578
1579 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 5, "input");
1580 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 5, "output");
1581
1582 std::vector<DataType> supportedTypes =
1583 {
1584 DataType::BFloat16,
1585 DataType::Float32,
1586 DataType::Float16,
1587 DataType::QAsymmS8,
1588 DataType::QAsymmU8,
1589 DataType::QSymmS16
1590 };
1591
1592 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1593 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1594 }
1595
// Validates a Resize workload: 4-D input and output of matching data type.
// Only the spatial dimensions may change; batch and channel counts must be
// preserved. Throws InvalidArgumentException on any violation.
void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    const std::string descriptorName{"ResizeQueueDescriptor"};

    ValidateNumInputs(workloadInfo, descriptorName, 1);
    ValidateNumOutputs(workloadInfo, descriptorName, 1);

    const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
    const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];

    ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
    ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");

    std::vector<DataType> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float16,
        DataType::Float32,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
    ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");

    // Resize only changes width and height: batch and channel count must match.
    const unsigned int inputBatchSize = inputTensorInfo.GetShape()[0];
    const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
    if (inputBatchSize != outputBatchSize)
    {
        throw InvalidArgumentException(
            fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
                        descriptorName, inputBatchSize, outputBatchSize));
    }

    // The channel index depends on the configured data layout (NCHW vs NHWC).
    DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
    const unsigned int inputChannelCount = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
    const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
    if (inputChannelCount != outputChannelCount)
    {
        throw InvalidArgumentException(
            fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
                        descriptorName, inputChannelCount, outputChannelCount));
    }
}
1642
Validate(const WorkloadInfo & workloadInfo) const1643 void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1644 {
1645 const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
1646
1647 ValidateNumInputs(workloadInfo, descriptorName, 1);
1648 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1649
1650 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1651 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1652
1653 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 2, "input");
1654 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1655
1656 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1657
1658 if (m_Parameters.m_Min > m_Parameters.m_Max)
1659 {
1660 throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
1661 }
1662 }
1663
Validate(const WorkloadInfo & workloadInfo) const1664 void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1665 {
1666 const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1667
1668 ValidateNumInputs(workloadInfo, descriptorName, 1);
1669 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1670
1671 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1672 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1673
1674 if (inputTensorInfo.GetNumDimensions() > 4)
1675 {
1676 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1677 }
1678
1679 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1680
1681 // Check the supported data types
1682 std::vector<DataType> supportedTypes =
1683 {
1684 DataType::BFloat16,
1685 DataType::Float32,
1686 DataType::Float16
1687 };
1688
1689 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1690 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1691 }
1692
Validate(const WorkloadInfo & workloadInfo) const1693 void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1694 {
1695 const std::string descriptorName{"L2NormalizationQueueDescriptor"};
1696
1697 ValidateNumInputs(workloadInfo, descriptorName, 1);
1698 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1699
1700 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1701 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1702
1703 if (inputTensorInfo.GetNumDimensions() > 4)
1704 {
1705 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1706 }
1707
1708 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1709
1710 // Check the supported data types
1711 std::vector<DataType> supportedTypes =
1712 {
1713 DataType::BFloat16,
1714 DataType::Float32,
1715 DataType::Float16,
1716 DataType::QAsymmS8,
1717 DataType::QAsymmU8,
1718 DataType::QSymmS16
1719 };
1720
1721 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1722 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1723 }
1724
Validate(const WorkloadInfo & workloadInfo) const1725 void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1726 {
1727 const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1728
1729 ValidateNumInputs(workloadInfo, descriptorName, 1);
1730 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1731
1732 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1733 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1734
1735 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1736
1737 std::vector<DataType> supportedTypes =
1738 {
1739 DataType::BFloat16,
1740 DataType::Float32,
1741 DataType::Float16,
1742 };
1743
1744 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1745 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1746 }
1747
Validate(const WorkloadInfo & workloadInfo) const1748 void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1749 {
1750 const std::string descriptorName{"ConstantQueueDescriptor"};
1751
1752 ValidateNumInputs(workloadInfo, descriptorName, 0);
1753 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1754
1755 if (!m_LayerOutput)
1756 {
1757 throw InvalidArgumentException(descriptorName + ": No const input specified.");
1758 }
1759
1760 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1761 ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
1762
1763 // Check the supported data types
1764 std::vector<DataType> supportedTypes =
1765 {
1766 DataType::BFloat16,
1767 DataType::Float32,
1768 DataType::Float16,
1769 DataType::QAsymmS8,
1770 DataType::QAsymmU8,
1771 DataType::QSymmS8,
1772 DataType::QSymmS16,
1773 DataType::Signed32
1774 };
1775
1776 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1777 }
1778
Validate(const WorkloadInfo & workloadInfo) const1779 void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1780 {
1781 const std::string descriptorName{"ReshapeQueueDescriptor"};
1782
1783 ValidateNumInputs(workloadInfo, descriptorName, 1);
1784 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1785
1786 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1787 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1788
1789 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1790
1791 // Check the supported data types
1792 std::vector<DataType> supportedTypes =
1793 {
1794 DataType::BFloat16,
1795 DataType::Float32,
1796 DataType::Float16,
1797 DataType::QAsymmS8,
1798 DataType::QAsymmU8,
1799 DataType::QSymmS16,
1800 DataType::Signed32,
1801 DataType::Boolean
1802 };
1803
1804 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1805 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1806 }
1807
Validate(const WorkloadInfo & workloadInfo) const1808 void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1809 {
1810 const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
1811
1812 ValidateNumInputs(workloadInfo, descriptorName, 1);
1813 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1814
1815 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1816 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1817
1818 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1819 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1820
1821 if (m_Parameters.m_BlockShape.size() != 2)
1822 {
1823 throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
1824 }
1825
1826 if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1827 {
1828 throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1829 "dimensions as Block Shape.");
1830 }
1831
1832 const TensorShape& inputShape = inputTensorInfo.GetShape();
1833
1834 std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
1835 std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
1836
1837 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1838
1839 const unsigned int inputWidth = inputShape[dimensionIndices.GetWidthIndex()] +
1840 widthPad.first + widthPad.second;
1841 const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1842 heightPad.first + heightPad.second;
1843
1844 const unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth *
1845 inputShape[dimensionIndices.GetChannelsIndex()];
1846 const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
1847
1848 if (numOutputElements != numInputElements)
1849 {
1850 throw InvalidArgumentException(descriptorName + ": Input tensor has " +
1851 to_string(numInputElements) + " after padding but output tensor has " +
1852 to_string(numOutputElements) + " elements.");
1853 }
1854
1855 if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
1856 {
1857 throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1858 "divisible by Block Shape in all spatial dimensions");
1859 }
1860
1861 std::vector<DataType> supportedTypes =
1862 {
1863 DataType::BFloat16,
1864 DataType::Float16,
1865 DataType::Float32,
1866 DataType::QAsymmS8,
1867 DataType::QAsymmU8,
1868 DataType::QSymmS16
1869 };
1870
1871 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1872 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1873 }
1874
Validate(const WorkloadInfo & workloadInfo) const1875 void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1876 {
1877 const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
1878
1879 ValidateNumInputs(workloadInfo, descriptorName, 1);
1880 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1881
1882 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1883 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1884
1885 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
1886 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1887
1888 std::vector<DataType> supportedTypes =
1889 {
1890 DataType::BFloat16,
1891 DataType::Float32,
1892 DataType::Float16,
1893 DataType::QAsymmS8,
1894 DataType::QAsymmU8,
1895 DataType::QSymmS16
1896 };
1897
1898 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1899 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1900
1901 ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1902
1903 if (m_Parameters.m_BlockSize == 0)
1904 {
1905 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1906 }
1907
1908 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1909 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1910 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1911 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
1912
1913 const TensorShape& inputShape = inputTensorInfo.GetShape();
1914 if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
1915 {
1916 throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1917 "by block size in all spatial dimensions");
1918 }
1919
1920 const TensorShape& outputShape = outputTensorInfo.GetShape();
1921 if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1922 {
1923 throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1924 "must be divisible by the square of block size." );
1925 }
1926 }
1927
Validate(const WorkloadInfo & workloadInfo) const1928 void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1929 {
1930 const std::string descriptorName{"FloorQueueDescriptor"};
1931
1932 ValidateNumInputs(workloadInfo, descriptorName, 1);
1933 ValidateNumOutputs(workloadInfo, descriptorName, 1);
1934
1935 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1936 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1937
1938 std::vector<DataType> supportedTypes =
1939 {
1940 DataType::BFloat16,
1941 DataType::Float32,
1942 DataType::Float16,
1943 DataType::QSymmS16
1944 };
1945
1946 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1947 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1948 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1949 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1950 }
1951
Validate(const WorkloadInfo & workloadInfo) const1952 void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1953 {
1954 // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1955
1956 const std::string descriptorName{"LstmQueueDescriptor"};
1957
1958 // check dimensions of all inputs and outputs
1959 if (workloadInfo.m_InputTensorInfos.size() != 3)
1960 {
1961 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1962 }
1963 if (workloadInfo.m_OutputTensorInfos.size() != 4)
1964 {
1965 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1966 }
1967
1968 std::vector<DataType> supportedTypes =
1969 {
1970 DataType::BFloat16,
1971 DataType::Float16,
1972 DataType::Float32,
1973 DataType::QSymmS16
1974 };
1975
1976 // check for supported type of one input and match them with all the other input and output
1977 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1978
1979 // type matches all other inputs
1980 for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
1981 {
1982 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1983 workloadInfo.m_InputTensorInfos[i],
1984 descriptorName,
1985 "input_0",
1986 "input_" + std::to_string(i));
1987 }
1988 // type matches all other outputs
1989 for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
1990 {
1991 ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1992 workloadInfo.m_OutputTensorInfos[i],
1993 "LstmQueueDescriptor",
1994 "input_0",
1995 "output_" + std::to_string(i));
1996 }
1997
1998 // Making sure clipping parameters have valid values.
1999 // == 0 means no clipping
2000 // > 0 means clipping
2001 if (m_Parameters.m_ClippingThresCell < 0.0f)
2002 {
2003 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
2004 }
2005 if (m_Parameters.m_ClippingThresProj < 0.0f)
2006 {
2007 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
2008 }
2009
2010 // Inferring batch size, number of outputs and number of cells from the inputs.
2011 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
2012 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
2013 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
2014 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
2015 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
2016 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
2017
2018 // input tensor
2019 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
2020 descriptorName + " input_0");
2021 // outputStateInTensor
2022 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
2023 descriptorName + " input_1");
2024 // outputStateInTensor
2025 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
2026 descriptorName + " input_2");
2027 // scratchBufferTensor
2028 unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
2029 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
2030 descriptorName + " output_0");
2031 // outputStateOutTensor
2032 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
2033 descriptorName + " output_1");
2034 // cellStateOutTensor
2035 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
2036 descriptorName + " output_2");
2037 // outputTensor
2038 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
2039 descriptorName + " output_3");
2040
2041 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
2042 if ( m_InputToInputWeights )
2043 {
2044 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
2045 (n_cell * n_input), "InputLayerNormWeights");
2046 }
2047
2048 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
2049 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
2050 (n_cell * n_input), "InputToForgetWeights");
2051
2052 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
2053 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
2054 (n_cell * n_input), "InputToCellWeights");
2055
2056 if ( m_RecurrentToInputWeights )
2057 {
2058 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
2059 (n_cell * n_output), "RecurrentToInputWeights");
2060 }
2061
2062 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
2063 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
2064 (n_cell * n_output), "RecurrentToForgetWeights");
2065
2066 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
2067 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
2068 (n_cell * n_output), "RecurrentToCellWeights");
2069
2070 // Make sure the input-gate's parameters are either both present (regular
2071 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
2072 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
2073 !m_Parameters.m_CifgEnabled) ||
2074 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
2075 m_Parameters.m_CifgEnabled));
2076 if (!cifg_weights_all_or_none)
2077 {
2078 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
2079 "RecurrentToInputWeights must either both be present (regular LSTM) "
2080 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
2081 "accordingly.");
2082 }
2083
2084 if ( m_CellToInputWeights )
2085 {
2086 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
2087 n_cell, "CellToInputWeights");
2088 }
2089 if ( m_CellToForgetWeights )
2090 {
2091 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
2092 n_cell, "CellToForgetWeights");
2093 }
2094 if ( m_CellToOutputWeights )
2095 {
2096 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
2097 n_cell, "CellToOutputWeights");
2098 }
2099
2100 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
2101 bool peephole_weights_all_or_none =
2102 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
2103 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
2104 || ( !m_CellToInputWeights && !m_CellToForgetWeights
2105 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
2106 if (!peephole_weights_all_or_none)
2107 {
2108 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
2109 }
2110
2111 // Make sure the input gate bias is present only when not a CIFG-LSTM.
2112 if (m_Parameters.m_CifgEnabled)
2113 {
2114 if (m_InputGateBias)
2115 {
2116 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
2117 }
2118 }
2119 else
2120 {
2121 if (!m_InputGateBias)
2122 {
2123 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
2124 "must be present.");
2125 }
2126 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
2127 n_cell, "InputGateBias");
2128 }
2129
2130 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
2131 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
2132
2133 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
2134 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
2135
2136 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
2137 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
2138
2139 if (m_ProjectionWeights)
2140 {
2141 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
2142 (n_cell * n_output), "ProjectionWeights");
2143 }
2144 if (m_ProjectionBias)
2145 {
2146 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
2147 }
2148
2149 // Making sure the projection tensors are consistent:
2150 // 1) If projection weight is not present, then projection bias should not be
2151 // present.
2152 // 2) If projection weight is present, then projection bias is optional.
2153 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
2154 !m_Parameters.m_ProjectionEnabled)
2155 || (m_ProjectionWeights && !m_ProjectionBias &&
2156 m_Parameters.m_ProjectionEnabled)
2157 || (m_ProjectionWeights && m_ProjectionBias &&
2158 m_Parameters.m_ProjectionEnabled));
2159 if (!projecton_tensors_consistent)
2160 {
2161 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
2162 }
2163
2164 // The four layer normalization weights either all have values or none of them have values. Additionally, if
2165 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
2166 // either all have values or none of them have values. Layer normalization is used when the values of all the
2167 // layer normalization weights are present
2168 if (m_InputLayerNormWeights)
2169 {
2170 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
2171 }
2172 if (m_ForgetLayerNormWeights)
2173 {
2174 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2175 }
2176 if (m_CellLayerNormWeights)
2177 {
2178 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2179 }
2180 if (m_OutputLayerNormWeights)
2181 {
2182 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2183 }
2184
2185 if (m_Parameters.m_LayerNormEnabled)
2186 {
2187 if (!m_Parameters.m_CifgEnabled)
2188 {
2189 if (!m_InputLayerNormWeights)
2190 {
2191 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
2192 "disabled but InputLayerNormWeights are not present");
2193 }
2194 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
2195 1, n_cell, "InputLayerNormWeights");
2196 }
2197 else if (m_InputLayerNormWeights)
2198 {
2199 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
2200 "enabled");
2201 }
2202
2203 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
2204 "ForgetLayerNormWeights");
2205 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2206
2207 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
2208 "OutputLayerNormWeights");
2209 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2210
2211 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
2212 "CellLayerNormWeights");
2213 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2214 }
2215 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
2216 {
2217 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
2218 "normalisation weights are present.");
2219 }
2220 }
2221
Validate(const WorkloadInfo & workloadInfo) const2222 void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2223 {
2224 const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
2225
2226 ValidateNumInputs(workloadInfo, descriptorName, 1);
2227 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2228
2229 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2230 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2231
2232 if (inputTensorInfo.GetDataType() != DataType::Float32)
2233 {
2234 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
2235 }
2236
2237 if (outputTensorInfo.GetDataType() != DataType::Float16)
2238 {
2239 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
2240 }
2241
2242 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2243 }
2244
Validate(const WorkloadInfo & workloadInfo) const2245 void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2246 {
2247 const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
2248
2249 ValidateNumInputs(workloadInfo, descriptorName, 1);
2250 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2251
2252 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2253 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2254
2255 if (inputTensorInfo.GetDataType() != DataType::Float16)
2256 {
2257 throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
2258 }
2259
2260 if (outputTensorInfo.GetDataType() != DataType::Float32)
2261 {
2262 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2263 }
2264
2265 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2266 }
2267
Validate(const WorkloadInfo & workloadInfo) const2268 void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2269 {
2270 const std::string descriptorName{"DivisionQueueDescriptor"};
2271
2272 ValidateNumInputs(workloadInfo, descriptorName, 2);
2273 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2274
2275 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2276 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2277 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2278
2279 std::vector<DataType> supportedTypes =
2280 {
2281 DataType::BFloat16,
2282 DataType::Float16,
2283 DataType::Float32,
2284 DataType::QAsymmS8,
2285 DataType::QAsymmU8,
2286 DataType::QSymmS16,
2287 DataType::Signed32
2288 };
2289
2290 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2291 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2292 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
2293
2294 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2295 inputTensorInfo1,
2296 outputTensorInfo,
2297 descriptorName,
2298 "input_0",
2299 "input_1");
2300 }
2301
Validate(const WorkloadInfo & workloadInfo) const2302 void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2303 {
2304 const std::string descriptorName{"SubtractionQueueDescriptor"};
2305
2306 ValidateNumInputs(workloadInfo, descriptorName, 2);
2307 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2308
2309 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2310 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2311 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2312
2313 std::vector<DataType> supportedTypes =
2314 {
2315 DataType::BFloat16,
2316 DataType::Float16,
2317 DataType::Float32,
2318 DataType::QAsymmS8,
2319 DataType::QAsymmU8,
2320 DataType::QSymmS16,
2321 DataType::Signed32,
2322 };
2323
2324 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2325 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2326 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
2327
2328 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2329 inputTensorInfo1,
2330 outputTensorInfo,
2331 descriptorName,
2332 "input_0",
2333 "input_1");
2334 }
2335
Validate(const WorkloadInfo & workloadInfo) const2336 void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2337 {
2338 const std::string descriptorName{"MaximumQueueDescriptor"};
2339
2340 ValidateNumInputs(workloadInfo, descriptorName, 2);
2341 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2342
2343 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2344 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2345 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2346
2347 std::vector<DataType> supportedTypes =
2348 {
2349 DataType::BFloat16,
2350 DataType::Float16,
2351 DataType::Float32,
2352 DataType::QAsymmS8,
2353 DataType::QAsymmU8,
2354 DataType::QSymmS16,
2355 DataType::Signed32
2356 };
2357
2358 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2359 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2360 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
2361
2362 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2363 inputTensorInfo1,
2364 outputTensorInfo,
2365 descriptorName,
2366 "input_0",
2367 "input_1");
2368 }
2369
Validate(const WorkloadInfo & workloadInfo) const2370 void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2371 {
2372 const std::string descriptorName{"MeanQueueDescriptor"};
2373
2374 ValidateNumInputs(workloadInfo, descriptorName, 1);
2375 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2376
2377 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2378 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2379
2380 std::vector<DataType> supportedTypes =
2381 {
2382 DataType::BFloat16,
2383 DataType::Float32,
2384 DataType::Float16,
2385 DataType::QAsymmS8,
2386 DataType::QAsymmU8,
2387 DataType::QSymmS16
2388 };
2389
2390 // First check if input tensor data type is supported, then
2391 // check if this data type matches the output tensor data type
2392 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2393 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2394
2395 if (m_Parameters.m_KeepDims)
2396 {
2397 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2398 }
2399 else if (m_Parameters.m_Axis.empty())
2400 {
2401 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
2402 }
2403 else
2404 {
2405 unsigned int outputDim =
2406 inputTensorInfo.GetNumDimensions() - armnn::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
2407 ValidateTensorNumDimensions(outputTensorInfo,
2408 descriptorName,
2409 outputDim > 0 ? outputDim : 1,
2410 "output");
2411 }
2412 }
2413
Validate(const WorkloadInfo & workloadInfo) const2414 void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2415 {
2416 const std::string descriptorName{"PadQueueDescriptor"};
2417
2418 ValidateNumInputs(workloadInfo, descriptorName, 1);
2419 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2420
2421 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2422 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2423
2424 // input and output should have the same number of dimensions
2425 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2426
2427 // there should be entry in the pad list for each dimension in the input tensor
2428 if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2429 throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2430 "as there are dimensions in the input tensor that is " +
2431 std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2432 " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
2433 }
2434 }
2435
Validate(const WorkloadInfo & workloadInfo) const2436 void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2437 {
2438 const std::string descriptorName{"QuantizeQueueDescriptor"};
2439
2440 ValidateNumInputs(workloadInfo, descriptorName, 1);
2441 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2442
2443 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2444 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2445
2446 std::vector<DataType> supportedTypes =
2447 {
2448 DataType::BFloat16,
2449 DataType::Float32,
2450 DataType::Float16,
2451 DataType::QSymmS8,
2452 DataType::QAsymmS8,
2453 DataType::QAsymmU8,
2454 DataType::QSymmS16
2455 };
2456
2457 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2458
2459 if (!IsQuantizedType(outputTensorInfo.GetDataType()))
2460 {
2461 throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
2462 }
2463 }
2464
Validate(const WorkloadInfo & workloadInfo) const2465 void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2466 {
2467 const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
2468
2469 ValidateNumInputs(workloadInfo, descriptorName, 1);
2470 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2471
2472 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2473 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2474
2475 std::vector<DataType> supportedTypes =
2476 {
2477 DataType::BFloat16,
2478 DataType::Float32,
2479 DataType::Float16,
2480 DataType::QAsymmS8,
2481 DataType::QAsymmU8,
2482 DataType::QSymmS16
2483 };
2484
2485 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2486 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2487 }
2488
Validate(const WorkloadInfo & workloadInfo) const2489 void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2490 {
2491 const std::string descriptorName{"StridedSliceQueueDescriptor"};
2492
2493 ValidateNumInputs(workloadInfo, descriptorName, 1);
2494 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2495
2496 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2497 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2498
2499 std::vector<DataType> supportedTypes =
2500 {
2501 DataType::BFloat16,
2502 DataType::Float16,
2503 DataType::Float32,
2504 DataType::QAsymmS8,
2505 DataType::QAsymmU8,
2506 DataType::QSymmS16
2507 };
2508
2509 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2510 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2511
2512 ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2513
2514 const uint32_t rank = inputTensorInfo.GetNumDimensions();
2515 if (rank > 4)
2516 {
2517 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
2518 }
2519
2520 // Begin, End & Stride length must be of rank(input0)
2521 if (m_Parameters.m_Begin.size() != rank)
2522 {
2523 throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
2524 }
2525
2526 if (m_Parameters.m_End.size() != rank)
2527 {
2528 throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
2529 }
2530
2531 if (m_Parameters.m_Stride.size() != rank)
2532 {
2533 throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
2534 }
2535
2536 // Stride entries must be non-zero
2537 for (auto& stride : m_Parameters.m_Stride)
2538 {
2539 if (stride == 0)
2540 {
2541 throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
2542 }
2543 }
2544 }
2545
Validate(const WorkloadInfo & workloadInfo) const2546 void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2547 {
2548 const std::string descriptorName{"MinimumQueueDescriptor"};
2549
2550 ValidateNumInputs(workloadInfo, descriptorName, 2);
2551 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2552
2553 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2554 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2555 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2556
2557 std::vector<DataType> supportedTypes =
2558 {
2559 DataType::BFloat16,
2560 DataType::Float16,
2561 DataType::Float32,
2562 DataType::QAsymmS8,
2563 DataType::QAsymmU8,
2564 DataType::QSymmS16,
2565 DataType::Signed32
2566 };
2567
2568 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2569 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2570 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
2571
2572 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2573 inputTensorInfo1,
2574 outputTensorInfo,
2575 descriptorName,
2576 "input_0",
2577 "input_1");
2578 }
2579
Validate(const WorkloadInfo & workloadInfo) const2580 void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2581 {
2582 const std::string descriptorName{"DebugQueueDescriptor"};
2583
2584 ValidateNumInputs(workloadInfo, descriptorName, 1);
2585 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2586 }
2587
Validate(const WorkloadInfo & workloadInfo) const2588 void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2589 {
2590 const std::string descriptorName{"EqualQueueDescriptor"};
2591
2592 ValidateNumInputs(workloadInfo, descriptorName, 2);
2593 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2594
2595 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2596 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2597 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2598
2599 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2600 inputTensorInfo1,
2601 outputTensorInfo,
2602 descriptorName,
2603 "input_0",
2604 "input_1");
2605
2606 if (outputTensorInfo.GetDataType() != DataType::Boolean)
2607 {
2608 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
2609 }
2610 }
2611
Validate(const WorkloadInfo & workloadInfo) const2612 void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2613 {
2614 const std::string descriptorName{"GreaterQueueDescriptor"};
2615
2616 ValidateNumInputs(workloadInfo, descriptorName, 2);
2617 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2618
2619 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2620 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2621 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2622
2623 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2624 inputTensorInfo1,
2625 outputTensorInfo,
2626 descriptorName,
2627 "input_0",
2628 "input_1");
2629
2630 if (outputTensorInfo.GetDataType() != DataType::Boolean)
2631 {
2632 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
2633 }
2634 }
2635
Validate(const WorkloadInfo & workloadInfo) const2636 void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2637 {
2638 const std::string descriptorName{"RsqrtQueueDescriptor"};
2639
2640 ValidateNumInputs(workloadInfo, descriptorName, 1);
2641 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2642
2643 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2644 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2645
2646 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2647
2648 std::vector<DataType> supportedTypes =
2649 {
2650 DataType::BFloat16,
2651 DataType::Float16,
2652 DataType::Float32,
2653 DataType::QAsymmS8,
2654 DataType::QAsymmU8,
2655 DataType::QSymmS16
2656 };
2657
2658 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2659 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2660 }
2661
Validate(const WorkloadInfo & workloadInfo) const2662 void GatherNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2663 {
2664 const std::string descriptorName{"GatherNdQueueDescriptor"};
2665
2666 ValidateNumInputs(workloadInfo, descriptorName, 2);
2667 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2668
2669 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2670 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
2671 {
2672 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
2673 }
2674
2675 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2676 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2677
2678 std::vector<DataType> supportedTypes =
2679 {
2680 DataType::BFloat16,
2681 DataType::Float16,
2682 DataType::Float32,
2683 DataType::QAsymmS8,
2684 DataType::QAsymmU8,
2685 DataType::QSymmS16,
2686 DataType::Signed32,
2687 };
2688
2689 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2690
2691 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2692
2693 unsigned int outputDim = outputTensorInfo.GetNumDimensions();
2694 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
2695 }
2696
Validate(const WorkloadInfo & workloadInfo) const2697 void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2698 {
2699 const std::string descriptorName{"GatherQueueDescriptor"};
2700
2701 ValidateNumInputs(workloadInfo, descriptorName, 2);
2702 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2703
2704 const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2705 if (indicesTensorInfo.GetDataType() != DataType::Signed32)
2706 {
2707 throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
2708 }
2709
2710 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2711 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2712
2713 std::vector<DataType> supportedTypes =
2714 {
2715 DataType::BFloat16,
2716 DataType::Float16,
2717 DataType::Float32,
2718 DataType::QAsymmS8,
2719 DataType::QAsymmU8,
2720 DataType::QSymmS16,
2721 DataType::Signed32,
2722 };
2723
2724 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2725
2726 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2727
2728 unsigned int outputDim = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2729 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
2730 }
2731
Validate(const WorkloadInfo & workloadInfo) const2732 void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2733 {
2734 const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2735
2736 ValidateNumInputs(workloadInfo, descriptorName, 2);
2737
2738 if (workloadInfo.m_OutputTensorInfos.size() != 4)
2739 {
2740 throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
2741 to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2742 }
2743
2744 if (m_Anchors == nullptr)
2745 {
2746 throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
2747 }
2748
2749 const TensorInfo& boxEncodingsInfo = workloadInfo.m_InputTensorInfos[0];
2750 const TensorInfo& scoresInfo = workloadInfo.m_InputTensorInfos[1];
2751 const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
2752
2753 const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
2754 const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
2755 const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
2756 const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
2757
2758 ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2759 ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2760 ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
2761
2762 const std::vector<DataType> supportedInputTypes =
2763 {
2764 DataType::BFloat16,
2765 DataType::Float32,
2766 DataType::Float16,
2767 DataType::QAsymmS8,
2768 DataType::QAsymmU8,
2769 DataType::QSymmS16
2770 };
2771
2772 ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2773 ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2774 ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2775
2776 ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2777 ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2778 ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2779 ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2780
2781 // NOTE: Output is always Float32 regardless of input type
2782 ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2783 ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2784 ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2785 ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
2786
2787 if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2788 {
2789 throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
2790 "must be positive and less than or equal to 1.");
2791 }
2792
2793 if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2794 {
2795 throw InvalidArgumentException(descriptorName + ": Number of classes with background "
2796 "should be equal to number of classes + 1.");
2797 }
2798 }
2799
Validate(const WorkloadInfo & workloadInfo) const2800 void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2801 {
2802 const std::string& descriptorName{"DequantizeQueueDescriptor"};
2803
2804 ValidateNumInputs(workloadInfo, descriptorName, 1);
2805 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2806
2807 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2808 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2809
2810 std::vector<DataType> inputSupportedTypes =
2811 {
2812 DataType::QAsymmS8,
2813 DataType::QAsymmU8,
2814 DataType::QSymmS8,
2815 DataType::QSymmS16,
2816 DataType::Float16
2817 };
2818 ValidateDataTypes(inputTensorInfo, inputSupportedTypes, descriptorName);
2819
2820 std::vector<DataType> outputSupportedTypes =
2821 {
2822 DataType::BFloat16,
2823 DataType::Float32,
2824 DataType::Float16
2825 };
2826
2827 ValidateDataTypes(outputTensorInfo, outputSupportedTypes, descriptorName);
2828 }
2829
Validate(const WorkloadInfo & workloadInfo) const2830 void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2831 {
2832 const std::string& descriptorName{"MergeQueueDescriptor"};
2833
2834 ValidateNumInputs(workloadInfo, descriptorName, 2);
2835 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2836
2837 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2838 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2839 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2840
2841 ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2842 ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2843
2844 ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2845 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2846 }
2847
Validate(const WorkloadInfo & workloadInfo) const2848 void ShapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2849 {
2850 const std::string& descriptorName{"ShapeQueueDescriptor"};
2851
2852 ValidateNumInputs(workloadInfo, descriptorName, 1);
2853 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2854
2855 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2856 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2857
2858 std::vector<DataType> supportedTypes =
2859 {
2860 DataType::BFloat16,
2861 DataType::Float16,
2862 DataType::Float32,
2863 DataType::QAsymmS8,
2864 DataType::QAsymmU8,
2865 DataType::QAsymmS8,
2866 DataType::QSymmS8,
2867 DataType::QSymmS16,
2868 DataType::Signed32
2869 };
2870
2871 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2872 ValidateDataTypes(outputTensorInfo, {DataType::Signed32}, descriptorName);
2873 }
2874
Validate(const WorkloadInfo & workloadInfo) const2875 void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2876 {
2877 const std::string& descriptorName{"SwitchQueueDescriptor"};
2878
2879 ValidateNumInputs(workloadInfo, descriptorName, 2);
2880 ValidateNumOutputs(workloadInfo, descriptorName, 2);
2881
2882 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2883 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2884
2885 const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2886 const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2887
2888 std::vector<DataType> supportedTypes =
2889 {
2890 DataType::BFloat16,
2891 DataType::Float32,
2892 DataType::QAsymmS8,
2893 DataType::QAsymmU8,
2894 DataType::QSymmS16
2895 };
2896
2897 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2898 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2899
2900 ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2901 ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
2902
2903 ValidateTensorShapesMatch(inputTensorInfo0,
2904 outputTensorInfo0,
2905 descriptorName,
2906 "input_0",
2907 "output_0");
2908
2909 ValidateTensorShapesMatch(inputTensorInfo0,
2910 outputTensorInfo1,
2911 descriptorName,
2912 "input_0",
2913 "output_1");
2914 }
2915
void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
{
    // Pre-compiled workloads are produced internally (e.g. by backend
    // optimization), so their tensors are trusted and no checks are performed.
    // This is internally generated so it should not need validation.
}
2920
Validate(const WorkloadInfo & workloadInfo) const2921 void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2922 {
2923 const std::string& descriptorName{"PreluQueueDescriptor"};
2924
2925 ValidateNumInputs(workloadInfo, descriptorName, 2);
2926 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2927
2928 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2929 const TensorInfo& alphaTensorInfo = workloadInfo.m_InputTensorInfos[1];
2930 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2931
2932 std::vector<DataType> supportedTypes
2933 {
2934 DataType::BFloat16,
2935 DataType::Float16,
2936 DataType::Float32,
2937 DataType::QAsymmS8,
2938 DataType::QAsymmU8,
2939 DataType::QSymmS16
2940 };
2941
2942 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2943 ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
2944
2945 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
2946
2947 ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo, descriptorName, "input", "alpha");
2948 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
2949
2950 ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2951 alphaTensorInfo,
2952 outputTensorInfo,
2953 descriptorName,
2954 "input",
2955 "alpha");
2956 }
2957
Validate(const WorkloadInfo & workloadInfo) const2958 void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2959 {
2960 const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2961
2962 ValidateNumInputs(workloadInfo, descriptorName, 1);
2963 ValidateNumOutputs(workloadInfo, descriptorName, 1);
2964
2965 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2966 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2967
2968 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
2969 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
2970
2971 ValidatePointer(m_Weight, descriptorName, "weight");
2972
2973 const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2974 ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
2975
2976 ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
2977
2978 Optional<TensorInfo> optionalBiasTensorInfo;
2979 if (m_Parameters.m_BiasEnabled)
2980 {
2981 ValidatePointer(m_Bias, descriptorName, "bias");
2982
2983 optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
2984 const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
2985
2986 ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
2987 ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
2988 }
2989
2990 ValidatePerAxisQuantization(inputTensorInfo,
2991 outputTensorInfo,
2992 weightTensorInfo,
2993 optionalBiasTensorInfo,
2994 descriptorName);
2995
2996 std::vector<DataType> supportedTypes =
2997 {
2998 DataType::BFloat16,
2999 DataType::Float32,
3000 DataType::Float16,
3001 DataType::QAsymmS8,
3002 DataType::QAsymmU8,
3003 DataType::QSymmS16
3004 };
3005
3006 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3007 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3008 }
3009
Validate(const WorkloadInfo & workloadInfo) const3010 void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3011 {
3012 const std::string descriptorName{"TransposeQueueDescriptor"};
3013
3014 ValidateNumInputs(workloadInfo, descriptorName, 1);
3015 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3016
3017 const PermutationVector& mapping = m_Parameters.m_DimMappings;
3018
3019 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3020 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3021
3022 ValidateTensorNumDimensions(inputTensorInfo, descriptorName, mapping.GetSize(), "input");
3023 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
3024
3025 for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
3026 {
3027 if (inputTensorInfo.GetShape()[mapping[i]] != outputTensorInfo.GetShape()[i])
3028 {
3029 throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(mapping[i]) +
3030 " (=" + to_string(inputTensorInfo.GetShape()[mapping[i]]) + ") " +
3031 "must match dst dimension " + to_string(i) +
3032 " (=" + to_string(outputTensorInfo.GetShape()[i]) + ")");
3033 }
3034 }
3035
3036 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3037 }
3038
Validate(const WorkloadInfo & workloadInfo) const3039 void ChannelShuffleQueueDescriptor::Validate(const WorkloadInfo &workloadInfo) const
3040 {
3041 const std::string descriptorName{"TransposeQueueDescriptor"};
3042
3043 ValidateNumInputs(workloadInfo, descriptorName, 1);
3044 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3045
3046 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3047 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3048
3049 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3050 }
3051
Validate(const WorkloadInfo & workloadInfo) const3052 void QLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3053 {
3054 const std::string descriptorName{"QLstmQueueDescriptor"};
3055
3056 // Validate number of inputs/outputs
3057 ValidateNumInputs(workloadInfo, descriptorName, 3);
3058 ValidateNumOutputs(workloadInfo, descriptorName, 3);
3059
3060 // Input/output tensor info
3061 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3062 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[1];
3063 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[2];
3064
3065 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3066 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3067 auto outputInfo = workloadInfo.m_OutputTensorInfos[2];
3068
3069 // Supported types for various tensors in QLSTM
3070 std::vector<DataType> inputOutputSupportedTypes =
3071 {
3072 DataType::QAsymmS8
3073 };
3074
3075 std::vector<DataType> cellStateSupportedTypes =
3076 {
3077 DataType::QSymmS16
3078 };
3079
3080 std::vector<DataType> weightsSupportedTypes =
3081 {
3082 DataType::QSymmS8
3083 };
3084
3085 std::vector<DataType> layerNormPeepholeWeightsSupportedTypes =
3086 {
3087 DataType::QSymmS16
3088 };
3089
3090 std::vector<DataType> biasSupportedTypes =
3091 {
3092 DataType::Signed32
3093 };
3094
3095 // Validate types of input/output tensors
3096 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3097 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3098 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3099
3100 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3101 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3102 ValidateDataTypes(outputInfo, inputOutputSupportedTypes, descriptorName);
3103
3104 // Validate matching types of input/output tensors
3105 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3106 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3107 "outputStateIn", "outputStateOut");
3108 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3109
3110 // Infer number of batches, number of units, input size and output size from tensor dimensions
3111 const uint32_t numBatches = inputInfo.GetShape()[0];
3112 const uint32_t inputSize = inputInfo.GetShape()[1];
3113 const uint32_t outputSize = outputStateInInfo.GetShape()[1];
3114 const uint32_t numUnits = cellStateInInfo.GetShape()[1];
3115
3116 // Validate number of dimensions and number of elements for input/output tensors
3117 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3118 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3119 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * numUnits), descriptorName + " cellStateIn");
3120
3121 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3122 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * numUnits), descriptorName + " cellStateOut");
3123 ValidateTensorNumDimNumElem(outputInfo, 2, (numBatches * outputSize), descriptorName + " output");
3124
3125 // Validate number of dimensions and number of elements for MANDATORY weight tensors
3126 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3127 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3128 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (numUnits * inputSize), " InputToForgetWeights");
3129
3130 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3131 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3132 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (numUnits * inputSize), " InputToCellWeights");
3133
3134 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3135 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3136 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (numUnits * inputSize), " InputToOutputWeights");
3137
3138 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3139 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3140 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (numUnits * outputSize),
3141 " RecurrentToForgetWeights");
3142
3143 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3144 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3145 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3146
3147 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3148 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3149 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3150
3151 // Validate data types for MANDATORY weights tensors (all should match each other)
3152 ValidateDataTypes(inputToForgetWeightsInfo, weightsSupportedTypes, descriptorName);
3153
3154 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToCellWeightsInfo, descriptorName,
3155 "inputToForgetWeights", "inputToCellWeights");
3156 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3157 "inputToForgetWeights", "inputToOutputWeights");
3158
3159 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3160 "inputToForgetWeights", "recurrentToForgeteights");
3161 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3162 "inputToForgetWeights", "recurrentToCellWeights");
3163 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3164 "inputToForgetWeights", "recurrentToOutputWeights");
3165
3166 // Validate number of dimensions and number of elements for MANDATORY bias tensors
3167 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3168 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3169 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, numUnits, " ForgetGateBias");
3170
3171 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3172 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3173 ValidateTensorNumDimNumElem(cellBiasInfo, 1, numUnits, " CellBias");
3174
3175 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3176 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3177 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, numUnits, " OutputGateBias");
3178
3179 // Validate data types for MANDATORY bias tensors
3180 ValidateDataTypes(forgetGateBiasInfo, biasSupportedTypes, descriptorName);
3181
3182 ValidateTensorDataTypesMatch(forgetGateBiasInfo, cellBiasInfo, descriptorName,
3183 "forgetGateBias", "cellBias");
3184 ValidateTensorDataTypesMatch(forgetGateBiasInfo, outputGateBiasInfo, descriptorName,
3185 "forgetGateBias", "outputGateBias");
3186
3187 // Validate OPTIONAL params: CIFG (inputToInputWeights, recurrentToInputWeights, inputGateBias)
3188 const bool allCifgParamsPresentOrNot = ((m_InputToInputWeights && m_RecurrentToInputWeights && m_InputGateBias &&
3189 !m_Parameters.m_CifgEnabled) ||
3190 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3191 !m_InputGateBias && m_Parameters.m_CifgEnabled));
3192
3193 if (!allCifgParamsPresentOrNot)
3194 {
3195 throw InvalidArgumentException(descriptorName +
3196 ": InputToInputWeights, RecurrentToInputWeights and InputGateBias must either all be present "
3197 "(CIFG disabled) or not be present at all (CIFG enabled). m_Parameters.m_CifgEnabled should be "
3198 "set appropriately.");
3199 }
3200
3201 if (!m_Parameters.m_CifgEnabled)
3202 {
3203 // Validate number of dimensions and number of elements
3204 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3205 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (numUnits * inputSize), " InputToInputWeights");
3206
3207 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3208 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (numUnits * outputSize),
3209 " RecurrentToInputWeights");
3210
3211 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3212 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, numUnits, " InputGateBias");
3213
3214 // Validate data types
3215 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToInputWeightsInfo, descriptorName,
3216 "inputToForgetWeights", "inputToInputWeights");
3217 ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3218 "inputToForgetWeights", "recurrentToInputWeights");
3219 ValidateTensorDataTypesMatch(forgetGateBiasInfo, inputGateBiasInfo, descriptorName,
3220 "forgetGateBias", "inputGateBias");
3221 }
3222
3223 // Validate OPTIONAL params: Peephole (cellToInputWeights, cellToForgetWeights, cellToOutputWeights)
3224 bool allPeepholeWeightsPresentOrNot =
3225 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3226 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3227 || (!m_CellToInputWeights && !m_CellToForgetWeights
3228 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3229
3230 if (!allPeepholeWeightsPresentOrNot)
3231 {
3232 throw InvalidArgumentException(descriptorName +
3233 ": CellToInputWeights, CellToForgetWeights and CellToOutputWeights should all be present (Peephole "
3234 "enabled) or not be present at all (Peephole disabled). CellToInputWeights should only be present "
3235 "when Peephole is enabled and CIFG is disabled. m_Parameters.m_PeepholeEnabled should be set "
3236 "appropriately.");
3237 }
3238
3239 if (m_Parameters.m_PeepholeEnabled)
3240 {
3241 auto cellToForgetWeightsInfo = m_CellToForgetWeights->GetTensorInfo();
3242 ValidateTensorNumDimNumElem(cellToForgetWeightsInfo, 1, numUnits, " cellToForgetWeights");
3243 ValidateDataTypes(cellToForgetWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3244
3245 auto cellToOutputWeightsInfo = m_CellToOutputWeights->GetTensorInfo();
3246 ValidateTensorNumDimNumElem(cellToOutputWeightsInfo, 1, numUnits, " cellToOutputWeights");
3247 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToOutputWeightsInfo, descriptorName,
3248 "cellToForgetWeight", "cellToOutputWeights");
3249
3250 if (!m_Parameters.m_CifgEnabled)
3251 {
3252 auto cellToInputWeightsInfo = m_CellToInputWeights->GetTensorInfo();
3253 ValidateTensorNumDimNumElem(cellToInputWeightsInfo, 1, numUnits, " cellToInputWeights");
3254 ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToInputWeightsInfo, descriptorName,
3255 "cellToForgetWeights", "cellToInputWeights");
3256 }
3257 }
3258
3259 // Validate OPTIONAL params: Layer Norm Weights
3260 bool allLayerNormWeightsPresentOrNot =
3261 (((m_InputLayerNormWeights || m_Parameters.m_CifgEnabled) && m_ForgetLayerNormWeights
3262 && m_CellLayerNormWeights && m_OutputLayerNormWeights && m_Parameters.m_LayerNormEnabled)
3263 || (!m_InputLayerNormWeights && !m_ForgetLayerNormWeights && !m_CellLayerNormWeights
3264 && !m_OutputLayerNormWeights && !m_Parameters.m_LayerNormEnabled));
3265
3266 if (!allLayerNormWeightsPresentOrNot)
3267 {
3268 throw InvalidArgumentException(descriptorName +
3269 ": InputLayerNormWeights, ForgetLayerNormWeights, m_OutputLayerNormWeights "
3270 "and CellLayerNormWeights should all be present (Layer Norm enabled) or not "
3271 "be present at all (Layer Norm disabled). InputLayerNormWeights should "
3272 "only be present when Layer Norm is enabled and CIFG is disabled. "
3273 "m_Parameters.m_LayerNormEnabled should be set appropriately.");
3274 }
3275
3276 if (m_Parameters.m_LayerNormEnabled)
3277 {
3278 auto forgetLayerNormWeightsInfo = m_ForgetLayerNormWeights->GetTensorInfo();
3279 ValidateTensorNumDimNumElem(forgetLayerNormWeightsInfo, 1, numUnits, " forgetLayerNormWeights");
3280 ValidateDataTypes(forgetLayerNormWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3281
3282 auto cellLayerNormWeightsInfo = m_CellLayerNormWeights->GetTensorInfo();
3283 ValidateTensorNumDimNumElem(cellLayerNormWeightsInfo, 1, numUnits, " cellLayerNormWeights");
3284 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, cellLayerNormWeightsInfo, descriptorName,
3285 "forgetLayerNormWeights", "cellLayerNormWeights");
3286
3287 auto outputLayerNormWeightsInfo = m_OutputLayerNormWeights->GetTensorInfo();
3288 ValidateTensorNumDimNumElem(outputLayerNormWeightsInfo, 1, numUnits, " outputLayerNormWeights");
3289 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, outputLayerNormWeightsInfo, descriptorName,
3290 "forgetLayerNormWeights", "outputLayerNormWeights");
3291
3292 if (!m_Parameters.m_CifgEnabled)
3293 {
3294 auto inputLayerNormWeightsInfo = m_InputLayerNormWeights->GetTensorInfo();
3295 ValidateTensorNumDimNumElem(inputLayerNormWeightsInfo, 1, numUnits, " inputLayerNormWeights");
3296 ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, inputLayerNormWeightsInfo, descriptorName,
3297 "forgetLayerNormWeights", "inputLayerNormWeights");
3298 }
3299 }
3300
3301 // Validate OPTIONAL params: Projection (projectionWeights, projectionBias)
3302 bool correctProjectionTensorsPresent =
3303 ((!m_ProjectionWeights && !m_ProjectionBias && !m_Parameters.m_ProjectionEnabled) ||
3304 (m_ProjectionWeights && !m_ProjectionBias && m_Parameters.m_ProjectionEnabled) ||
3305 (m_ProjectionWeights && m_ProjectionBias && m_Parameters.m_ProjectionEnabled));
3306
3307 if (!correctProjectionTensorsPresent)
3308 {
3309 throw InvalidArgumentException(descriptorName +
3310 ": If projection is enabled, ProjectionWeights should be present and "
3311 "ProjectionBias is optional. If projection is disabled, neither "
3312 "ProjectionWeights nor ProjectionBias should be present.");
3313 }
3314
3315 if (m_Parameters.m_ProjectionEnabled)
3316 {
3317 auto projectionWeightsInfo = m_ProjectionWeights->GetTensorInfo();
3318 ValidateTensorNumDimNumElem(projectionWeightsInfo, 2, (numUnits * outputSize), "ProjectionWeights");
3319 ValidateDataTypes(projectionWeightsInfo, weightsSupportedTypes, descriptorName);
3320
3321 if (m_ProjectionBias)
3322 {
3323 auto projectionBiasInfo = m_ProjectionBias->GetTensorInfo();
3324 ValidateTensorNumDimNumElem(projectionBiasInfo, 1, outputSize, "ProjectionBias");
3325 ValidateDataTypes(projectionBiasInfo, biasSupportedTypes, descriptorName);
3326 }
3327
3328 }
3329 else if ((outputInfo.GetQuantizationScale() != m_Parameters.m_HiddenStateScale) &&
3330 outputInfo.GetQuantizationOffset() != m_Parameters.m_HiddenStateZeroPoint) {
3331 throw InvalidArgumentException(descriptorName +
3332 ": If projection is disabled, output quantization info (scale, offset) "
3333 "should match HiddenStateScale and HiddenStateZeroPoint.");
3334 }
3335
3336 }
3337
Validate(const WorkloadInfo & workloadInfo) const3338 void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3339 {
3340 const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
3341
3342 // Validate number of inputs/outputs
3343 ValidateNumInputs(workloadInfo, descriptorName, 3);
3344 ValidateNumOutputs(workloadInfo, descriptorName, 2);
3345
3346 // Input/output tensor infos
3347 auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3348 auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
3349 auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
3350
3351 auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3352 auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3353
3354 std::vector<DataType> inputOutputSupportedTypes =
3355 {
3356 DataType::QAsymmU8
3357 };
3358
3359 std::vector<DataType> cellStateSupportedTypes =
3360 {
3361 DataType::QSymmS16
3362 };
3363
3364 std::vector<DataType> weightsSupportedTypes =
3365 {
3366 DataType::QAsymmU8
3367 };
3368
3369 std::vector<DataType> biasSupportedTypes =
3370 {
3371 DataType::Signed32
3372 };
3373
3374 // Validate types of input/output tensors
3375 ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3376 ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3377 ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3378
3379 ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3380 ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3381
3382 // Validate matching types of input/output tensors
3383 ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3384 ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3385 "outputStateIn", "outputStateOut");
3386 ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3387
3388 // Validate matching quantization info for input/output tensors
3389 ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3390 ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
3391 ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3392
3393 // Infer number of batches, input size and output size from tensor dimensions
3394 const uint32_t numBatches = inputInfo.GetShape()[0];
3395 const uint32_t inputSize = inputInfo.GetShape()[1];
3396 const uint32_t outputSize = cellStateInInfo.GetShape()[1];
3397
3398 // Validate number of dimensions and number of elements for input/output tensors
3399 ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3400 ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
3401 ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3402 ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
3403 ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3404
3405 // Validate number of dimensions and number of elements for weights tensors
3406 ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
3407 auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3408 ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
3409
3410 ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3411 auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3412 ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
3413
3414 ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3415 auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3416 ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
3417
3418 ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3419 auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3420 ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
3421
3422 ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
3423 auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3424 ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
3425
3426 ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3427 auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3428 ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3429 " RecurrentToForgetWeights");
3430
3431 ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3432 auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3433 ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3434
3435 ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3436 auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3437 ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3438
3439 // Validate data types for weights tensors (all should match each other)
3440 ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3441
3442 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3443 "inputToInputWeights", "inputToForgetWeights");
3444 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3445 "inputToInputWeights", "inputToCellWeights");
3446 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3447 "inputToInputWeights", "inputToOutputWeights");
3448
3449 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3450 "inputToInputWeights", "recurrentToInputWeights");
3451 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3452 "inputToInputWeights", "recurrentToForgeteights");
3453 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3454 "inputToInputWeights", "recurrentToCellWeights");
3455 ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3456 "inputToInputWeights", "recurrentToOutputWeights");
3457
3458 // Validate matching quantization info for weight tensors (all should match each other)
3459 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3460 descriptorName, "inputToInputWeights", "inputToForgetWeights");
3461 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3462 descriptorName, "inputToInputWeights", "inputToCellWeights");
3463 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3464 descriptorName, "inputToInputWeights", "inputToOutputWeights");
3465
3466 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3467 descriptorName, "inputToInputWeights", "recurrentToInputWeights");
3468 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3469 descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
3470 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3471 descriptorName, "inputToInputWeights", "recurrentToCellWeights");
3472 ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3473 descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
3474
3475 // Validate number of dimensions and number of elements in bias tensors
3476 ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
3477 auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3478 ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
3479
3480 ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3481 auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3482 ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
3483
3484 ValidatePointer(m_CellBias, descriptorName, "CellBias");
3485 auto cellBiasInfo = m_CellBias->GetTensorInfo();
3486 ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
3487
3488 ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3489 auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3490 ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
3491
3492 // Validate data types for bias tensors (all should match each other)
3493 ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3494
3495 ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3496 "inputGateBias", "forgetGateBias");
3497 ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3498 "inputGateBias", "cellBias");
3499 ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3500 "inputGateBias", "outputGateBias");
3501
3502 // Validate bias tensor quantization info
3503 ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3504 ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3505 ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3506 ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3507 }
3508
Validate(const WorkloadInfo & workloadInfo) const3509 void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3510 {
3511 const std::string descriptorName{"AbsQueueDescriptor"};
3512
3513 ValidateNumInputs(workloadInfo, descriptorName, 1);
3514 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3515
3516 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3517 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3518
3519 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3520
3521 std::vector<DataType> supportedTypes =
3522 {
3523 DataType::BFloat16,
3524 DataType::Float16,
3525 DataType::Float32,
3526 DataType::QAsymmS8,
3527 DataType::QAsymmU8,
3528 DataType::QSymmS16,
3529 DataType::Signed32
3530 };
3531
3532 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3533 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3534 }
3535
Validate(const WorkloadInfo & workloadInfo) const3536 void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3537 {
3538 const std::string descriptorName{"SliceQueueDescriptor"};
3539
3540 ValidateNumInputs(workloadInfo, descriptorName, 1);
3541 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3542
3543 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3544 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3545
3546 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3547
3548 const unsigned int rank = inputTensorInfo.GetNumDimensions();
3549 if (rank > 4)
3550 {
3551 throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
3552 }
3553
3554 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
3555
3556 // Check if m_Begin and m_Size have the expected length
3557 if (m_Parameters.m_Begin.size() != rank)
3558 {
3559 throw InvalidArgumentException(descriptorName +
3560 ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
3561 }
3562 if (m_Parameters.m_Size.size() != rank)
3563 {
3564 throw InvalidArgumentException(descriptorName +
3565 ": Length of size descriptor must equal rank " + std::to_string(rank));
3566 }
3567
3568 // Check if the shape of the output tensor matches m_Size
3569 const TensorShape& outputShape = outputTensorInfo.GetShape();
3570 for (unsigned int i = 0u; i < rank; ++i)
3571 {
3572 if (m_Parameters.m_Size[i] != outputShape[i])
3573 {
3574 throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
3575 }
3576 }
3577
3578 // Check if the sum of begin offset and size in a given dimension
3579 // does not exceed the size of corresponding input
3580 const TensorShape& inputShape = inputTensorInfo.GetShape();
3581 for(unsigned int i = 0u; i < rank; ++i)
3582 {
3583 if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
3584 {
3585 throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
3586 std::to_string(i) + " exceeds input size.");
3587 }
3588 }
3589 }
3590
Validate(const WorkloadInfo & workloadInfo) const3591 void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3592 {
3593 const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
3594
3595 ValidateNumInputs(workloadInfo, descriptorName, 1);
3596 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3597
3598 const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
3599 const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
3600
3601 ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
3602 ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
3603
3604 std::vector<DataType> supportedTypes =
3605 {
3606 DataType::BFloat16,
3607 DataType::Float32,
3608 DataType::Float16,
3609 DataType::QAsymmS8,
3610 DataType::QAsymmU8,
3611 DataType::QSymmS16
3612 };
3613
3614 ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
3615 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
3616
3617 ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
3618
3619 if (m_Parameters.m_BlockSize == 0)
3620 {
3621 throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
3622 }
3623
3624 DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
3625 const unsigned int wIndex = dimensionIndices.GetWidthIndex();
3626 const unsigned int hIndex = dimensionIndices.GetHeightIndex();
3627 const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
3628
3629 const TensorShape& outputShape = outputInfo.GetShape();
3630 if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
3631 {
3632 throw InvalidArgumentException(descriptorName + ": Output width and height shape"
3633 "must be divisible by block size.");
3634 }
3635
3636 const TensorShape& inputShape = inputInfo.GetShape();
3637 if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
3638 {
3639 throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
3640 "must be divisible by the square of block size." );
3641 }
3642 }
3643
Validate(const WorkloadInfo & workloadInfo) const3644 void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3645 {
3646 const std::string descriptorName{"ComparisonQueueDescriptor"};
3647
3648 ValidateNumInputs(workloadInfo, descriptorName, 2);
3649 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3650
3651 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3652 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3653 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3654
3655 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3656 inputTensorInfo1,
3657 outputTensorInfo,
3658 descriptorName,
3659 "input_0",
3660 "input_1");
3661
3662 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3663 {
3664 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3665 }
3666 }
3667
Validate(const WorkloadInfo & workloadInfo) const3668 void ElementwiseBinaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3669 {
3670 const std::string descriptorName{"ElementwiseBinaryQueueDescriptor"};
3671
3672 ValidateNumInputs(workloadInfo, descriptorName, 2);
3673 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3674
3675 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3676 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3677 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3678
3679 std::vector<DataType> supportedTypes =
3680 {
3681 DataType::BFloat16,
3682 DataType::Float16,
3683 DataType::Float32,
3684 DataType::QAsymmS8,
3685 DataType::QAsymmU8,
3686 DataType::QSymmS16,
3687 DataType::Signed32
3688 };
3689
3690 ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
3691 ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
3692
3693 ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input", "output");
3694 ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input", "output");
3695 }
3696
Validate(const WorkloadInfo & workloadInfo) const3697 void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3698 {
3699 const std::string descriptorName{"ElementwiseUnaryQueueDescriptor"};
3700
3701 ValidateNumInputs(workloadInfo, descriptorName, 1);
3702 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3703
3704 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3705 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3706
3707 ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3708
3709 std::vector<DataType> supportedTypes =
3710 {
3711 DataType::BFloat16,
3712 DataType::Float16,
3713 DataType::Float32,
3714 DataType::QAsymmS8,
3715 DataType::QAsymmU8,
3716 DataType::QSymmS16,
3717 DataType::Signed32
3718 };
3719
3720 std::vector<DataType> logicalSupportedTypes =
3721 {
3722 DataType::Boolean
3723 };
3724
3725 if (m_Parameters.m_Operation == UnaryOperation::LogicalNot)
3726 {
3727 ValidateDataTypes(inputTensorInfo, logicalSupportedTypes, descriptorName);
3728 }
3729 else
3730 {
3731 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3732 }
3733
3734
3735 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3736 }
3737
Validate(const WorkloadInfo & workloadInfo) const3738 void RankQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3739 {
3740 const std::string descriptorName{"RankQueueDescriptor"};
3741
3742 ValidateNumInputs(workloadInfo, descriptorName, 1);
3743 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3744
3745 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3746 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3747
3748 ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
3749 ValidateTensorNumElements(outputTensorInfo, descriptorName, 1, "output");
3750
3751 std::vector<DataType> supportedTypes =
3752 {
3753 DataType::BFloat16,
3754 DataType::Float16,
3755 DataType::Float32,
3756 DataType::QAsymmS8,
3757 DataType::QAsymmU8,
3758 DataType::QSymmS8,
3759 DataType::QSymmS16,
3760 DataType::Signed32
3761 };
3762
3763 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3764 ValidateDataTypes(outputTensorInfo, { DataType::Signed32 }, descriptorName);
3765 }
3766
Validate(const WorkloadInfo & workloadInfo) const3767 void LogicalBinaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3768 {
3769 const std::string descriptorName{"LogicalBinaryQueueDescriptor"};
3770
3771 ValidateNumInputs(workloadInfo, descriptorName, 2);
3772 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3773
3774 const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3775 const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3776 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3777
3778 ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3779 inputTensorInfo1,
3780 outputTensorInfo,
3781 descriptorName,
3782 "input_0",
3783 "input_1");
3784
3785 if (inputTensorInfo0.GetDataType() != DataType::Boolean)
3786 {
3787 throw InvalidArgumentException(descriptorName + ": Input tensor 0 type must be Boolean.");
3788 }
3789
3790 if (inputTensorInfo1.GetDataType() != DataType::Boolean)
3791 {
3792 throw InvalidArgumentException(descriptorName + ": Input tensor 1 type must be Boolean.");
3793 }
3794
3795 if (outputTensorInfo.GetDataType() != DataType::Boolean)
3796 {
3797 throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3798 }
3799 }
3800
Validate(const WorkloadInfo & workloadInfo) const3801 void ReduceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3802 {
3803 const std::string descriptorName{"ReduceQueueDescriptor"};
3804
3805 ValidateNumInputs(workloadInfo, descriptorName, 1);
3806 ValidateNumOutputs(workloadInfo, descriptorName, 1);
3807
3808 const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3809 const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3810
3811 std::vector<DataType> supportedTypes =
3812 {
3813 DataType::BFloat16,
3814 DataType::Float16,
3815 DataType::Float32,
3816 DataType::QAsymmS8,
3817 DataType::QAsymmU8,
3818 DataType::QSymmS16,
3819 DataType::Signed32
3820 };
3821
3822 ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3823 ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3824 }
3825
Validate(const WorkloadInfo & workloadInfo) const3826 void UnidirectionalSequenceLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3827 {
3828 // Modified from LstmQueueDescriptor::Validate to support UnidirectionalSequenceLstm
3829
3830 const std::string descriptorName{"UnidirectionalSequenceLstmQueueDescriptor"};
3831
3832 // check dimensions of all inputs and outputs
3833 if (workloadInfo.m_InputTensorInfos.size() != 3)
3834 {
3835 throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
3836 }
3837 if (workloadInfo.m_OutputTensorInfos.size() != 3)
3838 {
3839 throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
3840 }
3841
3842 std::vector<DataType> supportedTypes =
3843 {
3844 DataType::Float32,
3845 DataType::QAsymmS8
3846 };
3847
3848 // check for supported type of one input and match them with all the other input and output
3849 ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
3850
3851 // Making sure clipping parameters have valid values.
3852 // == 0 means no clipping
3853 // > 0 means clipping
3854 if (m_Parameters.m_ClippingThresCell < 0.0f)
3855 {
3856 throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
3857 }
3858 if (m_Parameters.m_ClippingThresProj < 0.0f)
3859 {
3860 throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
3861 }
3862
3863 unsigned int batchIndx = 0;
3864 unsigned int inputIndx = 1;
3865 uint32_t timeStep = 1;
3866 unsigned int timeIndx = 1;
3867 inputIndx = 2;
3868 if (m_Parameters.m_TimeMajor)
3869 {
3870 batchIndx = 1;
3871 timeIndx = 0;
3872
3873 }
3874 timeStep = workloadInfo.m_InputTensorInfos[0].GetShape()[timeIndx];
3875
3876 // Inferring batch size, number of outputs and number of cells from the inputs.
3877 const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[inputIndx];
3878 const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[batchIndx];
3879 ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
3880 const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
3881 ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
3882 const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
3883
3884 // input tensor
3885 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 3, (timeStep * n_batch * n_input),
3886 descriptorName + " input_0");
3887 // outputStateInTensor
3888 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
3889 descriptorName + " input_1");
3890 // outputStateInTensor
3891 ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
3892 descriptorName + " input_2");
3893
3894 // outputTensor
3895 ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 3, (timeStep * n_batch * n_output),
3896 descriptorName + " output_0");
3897
3898 // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
3899 if ( m_InputToInputWeights )
3900 {
3901 ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
3902 (n_cell * n_input), "InputLayerNormWeights");
3903 }
3904
3905 ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
3906 ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
3907 (n_cell * n_input), "InputToForgetWeights");
3908
3909 ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
3910 ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
3911 (n_cell * n_input), "InputToCellWeights");
3912
3913 if ( m_RecurrentToInputWeights )
3914 {
3915 ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
3916 (n_cell * n_output), "RecurrentToInputWeights");
3917 }
3918
3919 ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
3920 ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
3921 (n_cell * n_output), "RecurrentToForgetWeights");
3922
3923 ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
3924 ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
3925 (n_cell * n_output), "RecurrentToCellWeights");
3926
3927 // Make sure the input-gate's parameters are either both present (regular
3928 // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
3929 bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
3930 !m_Parameters.m_CifgEnabled) ||
3931 (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3932 m_Parameters.m_CifgEnabled));
3933 if (!cifg_weights_all_or_none)
3934 {
3935 throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
3936 "RecurrentToInputWeights must either both be present (regular LSTM) "
3937 "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
3938 "accordingly.");
3939 }
3940
3941 if ( m_CellToInputWeights )
3942 {
3943 ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
3944 n_cell, "CellToInputWeights");
3945 }
3946 if ( m_CellToForgetWeights )
3947 {
3948 ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
3949 n_cell, "CellToForgetWeights");
3950 }
3951 if ( m_CellToOutputWeights )
3952 {
3953 ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
3954 n_cell, "CellToOutputWeights");
3955 }
3956
3957 // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
3958 bool peephole_weights_all_or_none =
3959 (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3960 && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3961 || ( !m_CellToInputWeights && !m_CellToForgetWeights
3962 && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3963 if (!peephole_weights_all_or_none)
3964 {
3965 throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
3966 }
3967
3968 // Make sure the input gate bias is present only when not a CIFG-LSTM.
3969 if (m_Parameters.m_CifgEnabled)
3970 {
3971 if (m_InputGateBias)
3972 {
3973 throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
3974 }
3975 }
3976 else
3977 {
3978 if (!m_InputGateBias)
3979 {
3980 throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
3981 "must be present.");
3982 }
3983 ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
3984 n_cell, "InputGateBias");
3985 }
3986
3987 ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
3988 ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
3989
3990 ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
3991 ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
3992
3993 ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
3994 ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
3995
3996 if (m_ProjectionWeights)
3997 {
3998 ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
3999 (n_cell * n_output), "ProjectionWeights");
4000 }
4001 if (m_ProjectionBias)
4002 {
4003 ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
4004 }
4005
4006 // Making sure the projection tensors are consistent:
4007 // 1) If projection weight is not present, then projection bias should not be
4008 // present.
4009 // 2) If projection weight is present, then projection bias is optional.
4010 bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
4011 !m_Parameters.m_ProjectionEnabled)
4012 || (m_ProjectionWeights && !m_ProjectionBias &&
4013 m_Parameters.m_ProjectionEnabled)
4014 || (m_ProjectionWeights && m_ProjectionBias &&
4015 m_Parameters.m_ProjectionEnabled));
4016 if (!projecton_tensors_consistent)
4017 {
4018 throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
4019 }
4020
4021 // The four layer normalization weights either all have values or none of them have values. Additionally, if
4022 // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
4023 // either all have values or none of them have values. Layer normalization is used when the values of all the
4024 // layer normalization weights are present
4025 if (m_InputLayerNormWeights)
4026 {
4027 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
4028 }
4029 if (m_ForgetLayerNormWeights)
4030 {
4031 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
4032 }
4033 if (m_CellLayerNormWeights)
4034 {
4035 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
4036 }
4037 if (m_OutputLayerNormWeights)
4038 {
4039 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
4040 }
4041
4042 if (m_Parameters.m_LayerNormEnabled)
4043 {
4044 if (!m_Parameters.m_CifgEnabled)
4045 {
4046 if (!m_InputLayerNormWeights)
4047 {
4048 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
4049 "disabled but InputLayerNormWeights are not present");
4050 }
4051 ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
4052 1, n_cell, "InputLayerNormWeights");
4053 }
4054 else if (m_InputLayerNormWeights)
4055 {
4056 throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
4057 "enabled");
4058 }
4059
4060 ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
4061 "ForgetLayerNormWeights");
4062 ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
4063
4064 ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
4065 "OutputLayerNormWeights");
4066 ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
4067
4068 ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
4069 "CellLayerNormWeights");
4070 ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
4071 }
4072 else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
4073 {
4074 throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
4075 "normalisation weights are present.");
4076 }
4077 }
4078
Validate(const WorkloadInfo & workloadInfo) const4079 void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
4080 {
4081 const std::string descriptorName{"BatchMatMulDescriptor"};
4082
4083 ValidateNumInputs(workloadInfo, descriptorName, 2);
4084 ValidateNumOutputs(workloadInfo, descriptorName, 1);
4085
4086 // Inputs must be: both 2D+
4087 // For inputs X and Y whose dimensions to be multiplied are (M,N) and (I,J) respectively,
4088 // axes N and I must be the same size
4089
4090 const auto& inputXInfoBeforeParams = workloadInfo.m_InputTensorInfos[0];
4091 const auto& inputYInfoBeforeParams = workloadInfo.m_InputTensorInfos[1];
4092 const auto& outputInfo = workloadInfo.m_OutputTensorInfos[0];
4093 // Output info has already been inferred
4094
4095 std::vector<DataType> supportedTypes =
4096 {
4097 DataType::BFloat16,
4098 DataType::Float16,
4099 DataType::Float32,
4100 DataType::QAsymmS8,
4101 DataType::QAsymmU8,
4102 DataType::QSymmS16
4103 };
4104
4105 ValidateDataTypes(inputXInfoBeforeParams, supportedTypes, descriptorName);
4106 ValidateDataTypes(inputYInfoBeforeParams, supportedTypes, descriptorName);
4107 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
4108
4109 if ((inputXInfoBeforeParams.GetNumDimensions() < 2) ||
4110 (inputYInfoBeforeParams.GetNumDimensions() < 2))
4111 {
4112 throw InvalidArgumentException(descriptorName + ": Input tensors are not 2D or greater.");
4113 }
4114
4115 TensorInfo inputXInfoAfterParams;
4116 TensorInfo inputYInfoAfterParams;
4117
4118 if((m_Parameters.m_TransposeX && m_Parameters.m_AdjointX) ||
4119 (m_Parameters.m_TransposeY && m_Parameters.m_AdjointY))
4120 {
4121 throw InvalidArgumentException(descriptorName +
4122 ": Invalid descriptor parameters - Transpose and Adjoint "
4123 "cannot both be true for a given input tensor.");
4124 }
4125 if(m_Parameters.m_TransposeX)
4126 {
4127 inputXInfoAfterParams = armnnUtils::Permuted(inputXInfoBeforeParams,
4128 BatchMatMulDescriptor::GetPermuteVec(
4129 m_Parameters.m_DataLayoutX,
4130 inputXInfoBeforeParams.GetShape()));
4131 }
4132 else if(m_Parameters.m_AdjointX)
4133 {
4134 auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
4135 inputXInfoBeforeParams.GetShape());
4136 if(inputXInfoBeforeParams.GetShape()[axesToMul.first] !=
4137 inputXInfoBeforeParams.GetShape()[axesToMul.second])
4138 {
4139 throw InvalidArgumentException(descriptorName +
4140 ": Adjoint is set to true for input tensor X, but the axes to be adjointed are not square." );
4141 }
4142 // Shape remains the same as it's square
4143 inputXInfoAfterParams = inputXInfoBeforeParams;
4144 }
4145 else
4146 {
4147 inputXInfoAfterParams = inputXInfoBeforeParams;
4148 }
4149
4150 if(m_Parameters.m_TransposeY)
4151 {
4152 inputYInfoAfterParams = armnnUtils::Permuted(inputYInfoBeforeParams,
4153 BatchMatMulDescriptor::GetPermuteVec(
4154 m_Parameters.m_DataLayoutY,
4155 inputYInfoBeforeParams.GetShape()));
4156 }
4157 else if(m_Parameters.m_AdjointY)
4158 {
4159 auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
4160 inputYInfoBeforeParams.GetShape());
4161 if(inputYInfoBeforeParams.GetShape()[axesToMul.first] !=
4162 inputYInfoBeforeParams.GetShape()[axesToMul.second])
4163 {
4164 throw InvalidArgumentException(descriptorName +
4165 ": Adjoint is set to true for input tensor Y, but the axes to be adjointed are not square." );
4166 }
4167 // Shape remains the same as it's square
4168 inputYInfoAfterParams = inputYInfoBeforeParams;
4169 }
4170 else
4171 {
4172 inputYInfoAfterParams = inputYInfoBeforeParams;
4173 }
4174
4175 switch(m_Parameters.m_DataLayoutX)
4176 {
4177 case DataLayout::NCDHW:
4178 case DataLayout::NDHWC:
4179 if(inputXInfoAfterParams.GetNumDimensions() < 3)
4180 {
4181 throw InvalidArgumentException(descriptorName +
4182 ": Input tensor X does not have the correct "
4183 "number of dimensions for the Data Layout that it has been assigned.");
4184 }
4185 break;
4186 case DataLayout::NCHW:
4187 case DataLayout::NHWC:
4188 default:
4189 break;
4190 }
4191
4192 switch(m_Parameters.m_DataLayoutY)
4193 {
4194 case DataLayout::NCDHW:
4195 case DataLayout::NDHWC:
4196 if(inputYInfoAfterParams.GetNumDimensions() < 3)
4197 {
4198 throw InvalidArgumentException(descriptorName +
4199 ": Input tensor Y does not have the correct "
4200 "number of dimensions for the Data Layout that it has been assigned.");
4201 }
4202 break;
4203 case DataLayout::NCHW:
4204 case DataLayout::NHWC:
4205 default:
4206 break;
4207 }
4208
4209 auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
4210 inputXInfoAfterParams.GetShape());
4211 auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
4212 inputXInfoBeforeParams.GetShape());
4213
4214 if(inputXInfoAfterParams.GetShape()[axesXToMul.second]
4215 != inputYInfoAfterParams.GetShape()[axesYToMul.first])
4216 {
4217 throw InvalidArgumentException(descriptorName +
4218 ": The final axis of input tensor X must be the same size as "
4219 "the second last axis of input tensor Y.");
4220 }
4221
4222 { // Separate scope so we don't pollute the rest of the scope with our temp variables
4223 // e.g. NHWC isnt compatible with NCHW as of now
4224 DataLayout xLayout = m_Parameters.m_DataLayoutX;
4225 DataLayout yLayout = m_Parameters.m_DataLayoutY;
4226
4227 if(xLayout == DataLayout::NCHW || xLayout == DataLayout::NCDHW)
4228 {
4229 if(yLayout == DataLayout::NHWC || yLayout == DataLayout::NDHWC)
4230 {
4231 throw InvalidArgumentException(descriptorName +
4232 ": Invalid input tensor data layout combination.");
4233 }
4234 }
4235 if(yLayout == DataLayout::NCHW || yLayout == DataLayout::NCDHW)
4236 {
4237 if(xLayout == DataLayout::NHWC || xLayout == DataLayout::NDHWC)
4238 {
4239 throw InvalidArgumentException(descriptorName +
4240 ": Invalid input tensor data layout combination.");
4241 }
4242 }
4243 }
4244
4245 // Simulate aligning the ends of the matrix dims and prepending 1's to the beginning of the shorter one
4246 unsigned int outputTensorDimSize = std::max(inputXInfoAfterParams.GetNumDimensions(),
4247 inputYInfoAfterParams.GetNumDimensions());
4248 if(outputTensorDimSize-2 > 0)
4249 {
4250 TensorInfo tiXNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4251 DataType::Float32);
4252 TensorInfo tiYNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4253 DataType::Float32);
4254 TensorInfo tiOutNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4255 DataType::Float32);
4256
4257 auto doAxisExtension = [&](std::vector<unsigned int> axisIndices, TensorInfo& ti)
4258 {
4259 auto sizeDiff = (outputTensorDimSize-2) - axisIndices.size();
4260
4261 for(unsigned int i = 0; i < sizeDiff; i++)
4262 {
4263 axisIndices.insert(axisIndices.begin(), 1);
4264 }
4265
4266 for(unsigned int i = 0; i < ti.GetNumDimensions(); i++)
4267 {
4268 ti.GetShape()[i] = inputXInfoAfterParams.GetShape()[i];
4269 }
4270 };
4271
4272 auto axesXNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutX,
4273 inputXInfoAfterParams.GetShape());
4274 auto axesYNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutY,
4275 inputYInfoAfterParams.GetShape());
4276
4277 doAxisExtension(axesXNotMul, tiXNotMul);
4278 doAxisExtension(axesYNotMul, tiYNotMul);
4279
4280 for(unsigned int i = 0; i < tiOutNotMul.GetNumDimensions(); i++)
4281 {
4282 tiOutNotMul.GetShape()[i] = std::max(tiXNotMul.GetShape()[i],
4283 tiYNotMul.GetShape()[i]);
4284 }
4285
4286 ValidateBroadcastTensorShapesMatch(tiXNotMul,
4287 tiYNotMul,
4288 tiOutNotMul,
4289 descriptorName,
4290 "input_X",
4291 "input_Y");
4292 }
4293 }
4294
4295
4296 } // namespace armnn