xref: /aosp_15_r20/external/federated-compute/fcp/aggregation/core/input_tensor_list.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2023 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef FCP_AGGREGATION_CORE_INPUT_TENSOR_LIST_H_
17 #define FCP_AGGREGATION_CORE_INPUT_TENSOR_LIST_H_
18 
19 #include <cstddef>
20 #include <cstdint>
21 #include <initializer_list>
22 #include <vector>
23 
24 #include "fcp/aggregation/core/tensor.h"
25 
26 namespace fcp {
27 namespace aggregation {
28 
// Upper bound on the number of elements an InputTensorList keeps inline.
// A list holding kInlinedSize elements or fewer stores its pointers in inlined
// storage; a list with more than kInlinedSize elements uses allocated storage.
// TODO(team): Determine optimal size for this constant based on
// microbenchmarks.
constexpr int32_t kInlinedSize{5};
35 
36 // InputTensorList holds pointers to some number of unowned tensors to be used
37 // as input to a function.
38 //
39 // For efficiency, if there are fewer than kInlinedSize tensors, the memory to
40 // hold the pointers is inlined rather than allocated.
41 class InputTensorList final {
42  public:
43   typedef const Tensor* const* const_iterator;
44 
45   // Creates an InputTensorList with the provided elements.
46   InputTensorList(std::initializer_list<const Tensor*> list);
47 
48   // Creates an InputTensorList with a single input tensor.
InputTensorList(const Tensor & tensor)49   InputTensorList(const Tensor& tensor) : InputTensorList({&tensor}) {}
50 
51   // Creates an InputTensorList of a specific size. All elements will initially
52   // be set to nullptr.
53   explicit InputTensorList(size_t size);
54 
55   // InputTensorList class isn't copyable.
56   InputTensorList(const InputTensorList&) = delete;
57 
58   // Move constructor.
59   InputTensorList(InputTensorList&& other);
60 
61   // Move assignment.
62   InputTensorList& operator=(InputTensorList&& other);
63 
64   ~InputTensorList();
65 
begin()66   inline const_iterator begin() const { return data_ptr_; }
67 
end()68   inline const_iterator end() const { return data_ptr_ + size_; }
69 
size()70   inline size_t size() const { return size_; }
71 
72   inline const Tensor* const& operator[](size_t i) const {
73     return data_ptr_[i];
74   }
75 
76   inline const Tensor*& operator[](size_t i) { return data_ptr_[i]; }
77 
78  private:
79   union DataStorage {
DataStorage()80     constexpr DataStorage() : inlined{} {};
~DataStorage()81     ~DataStorage() {}
82     const Tensor* inlined[kInlinedSize];
83     std::vector<const Tensor*> allocated;
84   };
85 
86   void MoveData(InputTensorList&& other);
87 
88   size_t size_;
89   bool is_allocated_;
90   DataStorage data_storage_;
91   const Tensor** data_ptr_;
92 };
93 
94 }  // namespace aggregation
95 }  // namespace fcp
96 
97 #endif  // FCP_AGGREGATION_CORE_INPUT_TENSOR_LIST_H_
98