xref: /aosp_15_r20/external/ComputeLibrary/src/dynamic_fusion/sketch/ArgumentPack.h (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef SRC_DYNAMIC_FUSION_SKETCH_ARGUMENTPACK
25 #define SRC_DYNAMIC_FUSION_SKETCH_ARGUMENTPACK
26 
#include "arm_compute/core/experimental/Types.h"

#include <cstddef>
#include <initializer_list>
#include <unordered_map>
#include <vector>
30 
31 namespace arm_compute
32 {
33 namespace experimental
34 {
35 namespace dynamic_fusion
36 {
37 /** This is a generic class that packs the arguments of an operator. For now, it is only used for tensor-related types
38  * Examples of "tensor-related types": @ref ITensorInfo, @ref ITensor, @ref ICLTensor
39  *
40  * The argument id is the position of the argument within the pack, and is represented by @ref TensorType
41  *
42  * @tparam T Tensor-related type
43  */
44 template <typename T>
45 class ArgumentPack
46 {
47 public:
48     /** @ref TensorType encodes the position of a tensor argument within the pack */
49     using Id = TensorType;
50     /** A single argument element within the pack
51      * It contains either a const pointer or a non-const pointer to the Tensor-related type T, but never at the same time
52      */
53     struct PackElement
54     {
55         PackElement()                        = default;
56         PackElement(const PackElement &elem) = default;
57         PackElement &operator=(const PackElement &elem) = default;
58         PackElement(PackElement &&elem)                 = default;
59         PackElement &operator=(PackElement &&elem) = default;
PackElementPackElement60         PackElement(Id id, T *tensor)
61             : id(id), tensor(tensor), ctensor(nullptr)
62         {
63         }
PackElementPackElement64         PackElement(Id id, const T *ctensor)
65             : id(id), tensor(nullptr), ctensor(ctensor)
66         {
67         }
68 
69         Id       id{ ACL_UNKNOWN }; /**< Argument id within the pack */
70         T       *tensor{ nullptr }; /**< Non-const pointer to tensor-related object */
71         const T *ctensor
72         {
73             nullptr
74         }; /**< Const pointer to tensor-related object */
75     };
76 
77 public:
78     /** Default constructor */
79     ArgumentPack() = default;
80     /** Destructor */
81     ~ArgumentPack() = default;
82     /** Allow instances of this class to be copy constructed */
83     ArgumentPack<T>(const ArgumentPack<T> &other) = default;
84     /** Allow instances of this class to be copied */
85     ArgumentPack<T> &operator=(const ArgumentPack<T> &other) = default;
86     /** Allow instances of this class to be move constructed */
87     ArgumentPack<T>(ArgumentPack<T> &&other) = default;
88     /** Allow instances of this class to be moved */
89     ArgumentPack<T> &operator=(ArgumentPack<T> &&other) = default;
90     /** Initializer list Constructor */
ArgumentPack(const std::initializer_list<PackElement> & l)91     ArgumentPack(const std::initializer_list<PackElement> &l)
92         : _pack{}
93     {
94         for(const auto &e : l)
95         {
96             _pack[e.id] = e;
97         }
98     }
99     /** Add tensor to the pack
100      *
101      * @param[in] id     ID of the tensor to add
102      * @param[in] tensor Tensor to add
103      */
add_tensor(Id id,T * tensor)104     void add_tensor(Id id, T *tensor)
105     {
106         _pack[id] = PackElement(id, tensor);
107     }
108     /** Add const tensor to the pack
109      *
110      * @param[in] id     ID of the tensor to add
111      * @param[in] tensor Tensor to add
112      */
add_const_tensor(Id id,const T * tensor)113     void add_const_tensor(Id id, const T *tensor)
114     {
115         _pack[id] = PackElement(id, tensor);
116     }
117     /** Get tensor of a given id from the pack
118      *
119      * @param[in] id ID of tensor to extract
120      *
121      * @return The pointer to the tensor if exist and is non-const else nullptr
122      */
get_tensor(Id id)123     T *get_tensor(Id id)
124     {
125         auto it = _pack.find(id);
126         return it != _pack.end() ? it->second.tensor : nullptr;
127     }
128     /** Get constant tensor of a given id
129      *
130      * @param[in] id ID of tensor to extract
131      *
132      * @return The pointer to the tensor (const or not) if exist else nullptr
133      */
get_const_tensor(Id id)134     const T *get_const_tensor(Id id) const
135     {
136         auto it = _pack.find(id);
137         if(it != _pack.end())
138         {
139             return it->second.ctensor != nullptr ? it->second.ctensor : it->second.tensor;
140         }
141         return nullptr;
142     }
143     /** Remove the tensor stored with the given id
144      *
145      * @param[in] id ID of tensor to remove
146      */
remove_tensor(Id id)147     void remove_tensor(Id id)
148     {
149         _pack.erase(id);
150     }
151     /** Pack size accessor
152      *
153      * @return Number of tensors registered to the pack
154      */
size()155     size_t size() const
156     {
157         return _pack.size();
158     }
159     /** Checks if pack is empty
160      *
161      * @return True if empty else false
162      */
empty()163     bool empty() const
164     {
165         return _pack.empty();
166     }
167     /** Get the ACL_SRC_* tensors
168      *
169      * @return std::vector<T *>
170      */
get_src_tensors()171     std::vector<T *> get_src_tensors()
172     {
173         std::vector<T *> src_tensors{};
174         for(int id = static_cast<int>(TensorType::ACL_SRC); id <= static_cast<int>(TensorType::ACL_SRC_END); ++id)
175         {
176             auto tensor = get_tensor(static_cast<TensorType>(id));
177             if(tensor != nullptr)
178             {
179                 src_tensors.push_back(tensor);
180             }
181         }
182         return src_tensors;
183     }
184     /** Get the const ACL_SRC_* tensors
185      *
186      * @return std::vector<const T *>
187      */
get_const_src_tensors()188     std::vector<const T *> get_const_src_tensors() const
189     {
190         std::vector<const T *> src_tensors{};
191         for(int id = static_cast<int>(TensorType::ACL_SRC); id <= static_cast<int>(TensorType::ACL_SRC_END); ++id)
192         {
193             auto tensor = get_const_tensor(static_cast<TensorType>(id));
194             if(tensor != nullptr)
195             {
196                 src_tensors.push_back(tensor);
197             }
198         }
199         return src_tensors;
200     }
201     /** Get the ACL_DST_* tensors
202      *
203      * @return std::vector<T *>
204      */
get_dst_tensors()205     std::vector<T *> get_dst_tensors()
206     {
207         std::vector<T *> dst_tensors{};
208         for(int id = static_cast<int>(TensorType::ACL_DST); id <= static_cast<int>(TensorType::ACL_DST_END); ++id)
209         {
210             auto tensor = get_tensor(static_cast<TensorType>(id));
211             if(tensor != nullptr)
212             {
213                 dst_tensors.push_back(tensor);
214             }
215         }
216         return dst_tensors;
217     }
218     /** Get the const ACL_DST_* tensors
219      *
220      * @return std::vector<const T *>
221      */
get_const_dst_tensors()222     std::vector<const T *> get_const_dst_tensors() const
223     {
224         std::vector<const T *> dst_tensors{};
225         for(int id = static_cast<int>(TensorType::ACL_DST); id <= static_cast<int>(TensorType::ACL_DST_END); ++id)
226         {
227             auto tensor = get_const_tensor(static_cast<TensorType>(id));
228             if(tensor != nullptr)
229             {
230                 dst_tensors.push_back(tensor);
231             }
232         }
233         return dst_tensors;
234     }
235 
236 private:
237     std::unordered_map<int, PackElement> _pack{}; /**< Container with the packed tensors */
238 };
239 } // namespace dynamic_fusion
240 } // namespace experimental
241 } // namespace arm_compute
242 #endif /* SRC_DYNAMIC_FUSION_SKETCH_ARGUMENTPACK */
243