1 // 2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 // List of Layer Support Rules common to TOSA backends only, for use with CheckSupportRule() 9 10 struct TosaOperatorAttributeOfAny : public Rule 11 { 12 template<typename Container> TosaOperatorAttributeOfAnyTosaOperatorAttributeOfAny13 explicit TosaOperatorAttributeOfAny(TosaSerializationOperator* op, const Container& c) 14 { 15 m_Res = std::any_of(c.begin(), c.end(), [&op](Attribute attribute) 16 { 17 return attribute == op->GetAttributeType(); 18 }); 19 } 20 }; 21 22 struct TosaTypeAnyOf : public Rule 23 { 24 template<typename Container> TosaTypeAnyOfTosaTypeAnyOf25 TosaTypeAnyOf(TosaSerializationTensor* tensor, const Container& c) 26 { 27 m_Res = std::any_of(c.begin(), c.end(), [&tensor](DType dt) 28 { 29 return dt == tensor->GetDtype(); 30 }); 31 } 32 }; 33 34 struct TosaTensorNumDimensionsWithinBounds : public Rule 35 { TosaTensorNumDimensionsWithinBoundsTosaTensorNumDimensionsWithinBounds36 explicit TosaTensorNumDimensionsWithinBounds(TosaSerializationTensor* tensor) 37 { 38 m_Res = (tensor->GetShape().size() <= MaxNumOfTensorDimensions) || (!tensor->GetShape().empty()); 39 } 40 }; 41 42 struct TosaAssertSize : public Rule 43 { 44 template<typename Container> TosaAssertSizeTosaAssertSize45 explicit TosaAssertSize(const Container& c1, const Container& c2) 46 { 47 m_Res = (c1.size() == c2.size()); 48 } 49 }; 50 51 struct TosaContainerContainsTwoTypes : public Rule 52 { TosaContainerContainsTwoTypesTosaContainerContainsTwoTypes53 explicit TosaContainerContainsTwoTypes(std::tuple<DType, DType>& check, 54 const std::vector<std::tuple<DType, DType>>& c) 55 { 56 for (auto item: c) 57 { 58 if (std::get<0>(check) == std::get<0>(item) && 59 std::get<1>(check) == std::get<1>(item)) 60 { 61 m_Res = true; 62 return; 63 } 64 } 65 m_Res = false; 66 } 67 }; 68 69 struct TosaContainerContainsThreeTypes : public Rule 70 { TosaContainerContainsThreeTypesTosaContainerContainsThreeTypes71 explicit TosaContainerContainsThreeTypes(std::tuple<DType, DType, DType>& check, 72 const std::vector<std::tuple<DType, DType, DType>>& c) 73 { 74 for (auto item: c) 75 { 76 if (std::get<0>(check) == std::get<0>(item) && 77 std::get<1>(check) == std::get<1>(item) && 78 std::get<2>(check) == std::get<2>(item)) 79 { 80 m_Res = true; 81 return; 82 } 83 } 84 m_Res = false; 85 } 86 }; 87