xref: /aosp_15_r20/external/armnn/src/backends/tosaCommon/TosaLayerSupportRules.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 // List of Layer Support Rules common to TOSA backends only, for use with CheckSupportRule()
9 
10 struct TosaOperatorAttributeOfAny : public Rule
11 {
12     template<typename Container>
TosaOperatorAttributeOfAnyTosaOperatorAttributeOfAny13     explicit TosaOperatorAttributeOfAny(TosaSerializationOperator* op, const Container& c)
14     {
15         m_Res = std::any_of(c.begin(), c.end(), [&op](Attribute attribute)
16         {
17             return attribute == op->GetAttributeType();
18         });
19     }
20 };
21 
22 struct TosaTypeAnyOf : public Rule
23 {
24     template<typename Container>
TosaTypeAnyOfTosaTypeAnyOf25     TosaTypeAnyOf(TosaSerializationTensor* tensor, const Container& c)
26     {
27         m_Res = std::any_of(c.begin(), c.end(), [&tensor](DType dt)
28         {
29             return dt == tensor->GetDtype();
30         });
31     }
32 };
33 
34 struct TosaTensorNumDimensionsWithinBounds : public Rule
35 {
TosaTensorNumDimensionsWithinBoundsTosaTensorNumDimensionsWithinBounds36     explicit TosaTensorNumDimensionsWithinBounds(TosaSerializationTensor* tensor)
37     {
38         m_Res = (tensor->GetShape().size() <= MaxNumOfTensorDimensions) || (!tensor->GetShape().empty());
39     }
40 };
41 
42 struct TosaAssertSize : public Rule
43 {
44     template<typename Container>
TosaAssertSizeTosaAssertSize45     explicit TosaAssertSize(const Container& c1, const Container& c2)
46     {
47         m_Res = (c1.size() == c2.size());
48     }
49 };
50 
51 struct TosaContainerContainsTwoTypes : public Rule
52 {
TosaContainerContainsTwoTypesTosaContainerContainsTwoTypes53     explicit TosaContainerContainsTwoTypes(std::tuple<DType, DType>& check,
54                                            const std::vector<std::tuple<DType, DType>>& c)
55     {
56         for (auto item: c)
57         {
58             if (std::get<0>(check) == std::get<0>(item) &&
59                 std::get<1>(check) == std::get<1>(item))
60             {
61                 m_Res = true;
62                 return;
63             }
64         }
65         m_Res = false;
66     }
67 };
68 
69 struct TosaContainerContainsThreeTypes : public Rule
70 {
TosaContainerContainsThreeTypesTosaContainerContainsThreeTypes71     explicit TosaContainerContainsThreeTypes(std::tuple<DType, DType, DType>& check,
72                                              const std::vector<std::tuple<DType, DType, DType>>& c)
73     {
74         for (auto item: c)
75         {
76             if (std::get<0>(check) == std::get<0>(item) &&
77                 std::get<1>(check) == std::get<1>(item) &&
78                 std::get<2>(check) == std::get<2>(item))
79             {
80                 m_Res = true;
81                 return;
82             }
83         }
84         m_Res = false;
85     }
86 };
87