// xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/shim/tensor_view.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_SHIM_TENSOR_VIEW_H_
#define TENSORFLOW_LITE_KERNELS_SHIM_TENSOR_VIEW_H_

#include <array>
#include <cstddef>
#include <cstdint>
#include <type_traits>

#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tstring.h"

namespace tflite {
namespace shim {

// A type deduction template which is specialized for TF and TFLite.
// That is, it maps
//   ::tensorflow::Tensor -> tflite::shim::TfTensorView
//   ::TfLiteTensor -> tflite::shim::TfLiteTensorView
//
// The primary template is intentionally empty: it exposes no nested `Type`,
// so TensorView::New<W>() (whose return type names
// `TensorViewSubType<W>::Type`) is only usable for wrapped tensor types W that
// provide a specialization.
template <typename W>
struct TensorViewSubType {};

35 // Common denominator for ::tflite::TfLiteTensor and ::tensorflow::Tensor.
36 // It is a "view" over the underlying tensor without taking ownership.
37 // Objects of this class can also mutate the underlying tensor depending on
38 // whether the underlying tensor is "const" qualified or not.
39 //
40 // Movable and copyable.
41 // It can be instantiated with the New() factory function. eg.
42 //   TfTensorView t           = TensorView::New(&tf_tensor);
43 //   const TfTensorView t     = TensorView::New(&const_tf_tensor);
44 //   TfLiteTensorView t       = TensorView::New(&tflite_tensor);
45 //   const TfLiteTensorView t = TensorView::New(&const_tflite_tensor);
46 class TensorView {
47  protected:
48   // Union over all data types
49   using DataVariantType =
50       absl::variant<absl::Span<bool>, absl::Span<uint8_t>, absl::Span<uint64_t>,
51                     absl::Span<int8_t>, absl::Span<int16_t>,
52                     absl::Span<int32_t>, absl::Span<int64_t>, absl::Span<float>,
53                     absl::Span<double>, absl::Span<::tensorflow::tstring>>;
54 
55   // An interface while provides convenient row-major indexing over the
56   // underlying tensor.
57   // Example usage:
58   //
59   //   // A scalar view
60   //   const TensorView t_float
61   //   float val = t_float.AsScalar<float>();
62   //
63   //   // A vector view
64   //   const TensorView t_int;
65   //   auto t_int_vec = t_int.As<int32_t, /*RANK=*/ 1>();
66   //   int sum = t_int_vec(0) + t_int_vec(1);
67   //
68   //   // A matrix view
69   //   TensorView t_str;
70   //   auto t_str_mat = t_str.As<tensorflow::tstring, /*RANK=*/ 2>();
71   //   t_str_mat(0, 0) = "abc";
72   //   t_str_mat(2, 3) = "def";
73   template <typename DType, int RANK>
74   class Tensor {
75    public:
Tensor(TensorView * t)76     explicit Tensor(TensorView *t)
77         : data_(t->Data<DType>()), shape_(t->Shape()) {
78       DCHECK_EQ(RANK, shape_.size());
79       ComputeRowSizes();
80     }
81 
Tensor(const TensorView * t)82     explicit Tensor(const TensorView *t)
83         : data_(t->Data<DType>()), shape_(t->Shape()) {
84       DCHECK_EQ(RANK, shape_.size());
85       ComputeRowSizes();
86     }
87 
88     // indexing operator
89     template <typename... IndexTypes>
operator()90     inline DType &operator()(IndexTypes... indices) {
91       const auto idx = RowMajorIndex(std::array<int, RANK>{{indices...}});
92       return data_[idx];
93     }
94 
95     // const indexing operator
96     template <typename... IndexTypes>
operator()97     inline const DType &operator()(IndexTypes... indices) const {
98       const auto idx = RowMajorIndex(std::array<int, RANK>{{indices...}});
99       return data_.at(idx);
100     }
101 
102     // Pointer accessor
Ptr()103     typename absl::Span<DType>::pointer Ptr() { return data_.data(); }
Ptr()104     constexpr typename absl::Span<DType>::const_pointer Ptr() const {
105       return data_.data();
106     }
107 
108     // Size of the given dimension
Dim(int dim_i)109     inline int Dim(int dim_i) const {
110       DCHECK(RANK > 0 && dim_i < RANK) << "dim: " << dim_i << " rank:" << RANK;
111       // Handle negative indices
112       if (dim_i < 0) dim_i = ((dim_i % RANK) + RANK) % RANK;
113       return shape_[dim_i];
114     }
115 
116     // The tensor's rank: number of dimensions
Rank()117     /*[[nodiscard]]*/ constexpr std::size_t Rank() const { return RANK; }
118 
119    private:
120     // Computes the row-major index
RowMajorIndex(const std::array<int,RANK> & indices)121     inline std::size_t RowMajorIndex(
122         const std::array<int, RANK> &indices) const {
123       std::size_t ret = 0;
124       for (int i = 0; i < RANK; ++i) ret += indices[i] * row_sizes_[i];
125       return ret;
126     }
127 
128     // Pre computes row sizes to convert multi dim indices into a row major
129     // index
ComputeRowSizes()130     void ComputeRowSizes() {
131       // Precompute row sizes for row major index computation
132       if (RANK > 0) {
133         row_sizes_[RANK - 1] = 1;
134         for (int i = RANK - 2; i >= 0; --i) {
135           row_sizes_[i] = row_sizes_[i + 1] * shape_[i + 1];
136         }
137       }
138     }
139 
140     absl::Span<DType> data_;
141     const absl::Span<int> shape_;
142     std::size_t row_sizes_[RANK]{};
143   };
144 
145  public:
146   // Factory which gets specialized for different wrapped tensor types.
147   template <typename W>
148   static absl::StatusOr<typename TensorViewSubType<W>::Type> New(
149       W *wrapped_tensor);
150 
151  protected:
152   // Move constructor
153   TensorView(TensorView &&o) = default;
154   // Copy constructor
155   TensorView(const TensorView &o) = default;
156   // Move assignment operator
157   TensorView &operator=(TensorView &&o) = default;
158   // Copy assignment operator
159   TensorView &operator=(const TensorView &) = default;
160 
161  public:
162   // Dtor
163   virtual ~TensorView() = default;
164 
165   // Accessors
166 
167   // Shape
Shape()168   absl::Span<int> Shape() { return shape_; }
Shape()169   /*[[nodiscard]]*/ const absl::Span<int> Shape() const { return shape_; }
170 
171   // Data
172   template <typename DType>
Data()173   absl::Span<DType> &Data() {
174     return absl::get<absl::Span<DType>>(data_);
175   }
176   template <typename DType>
Data()177   constexpr absl::Span<DType> Data() const {
178     return absl::get<absl::Span<DType>>(data_);
179   }
180 
181   // Reads the tensor given the dtype and its rank and provides an indexing
182   // operator.
183   template <typename DType, int RANK>
As()184   Tensor<DType, RANK> As() {
185     return Tensor<DType, RANK>(this);
186   }
187 
188   // Const version of As()
189   template <typename DType, int RANK>
As()190   const Tensor<DType, RANK> As() const {
191     return Tensor<DType, RANK>(this);
192   }
193 
194   // Read the given tensor as a scalar or return error if it isn't
195   template <typename DType>
196   DType &AsScalar();
197 
198   template <typename DType>
199   const DType &AsScalar() const;
200 
201  protected:
202   // Templated constructor. Since it's not possible to specify the template
203   // argument directly we place a dummy argument of that type so compiler
204   // can deduce the right template parameter
205   template <typename DType>
TensorView(const absl::Span<int> shape,void * data,const std::size_t data_size,const DType &)206   TensorView(const absl::Span<int> shape, void *data,
207              const std::size_t data_size, const DType &)
208       : shape_(shape),
209         data_(absl::Span<DType>(reinterpret_cast<DType *>(data),
210                                 data_size / sizeof(DType))) {}
211 
212   // Return the total number of elements given the shape.
NumElements(const absl::Span<int> shape)213   static constexpr std::size_t NumElements(const absl::Span<int> shape) {
214     std::size_t ret = 1;
215     for (const auto dim : shape) ret *= dim;
216     return ret;
217   }
218 
219   // Tensor shape
220   // Note: using int rather than size_t to avoid conversion to from TfLite shape
221   absl::Span<int> shape_;
222   // Tensor data
223   DataVariantType data_;
224 };
225 
// Makes the const qualification of O follow that of I: `Type` is `const O`
// when I is const-qualified, and O with any top-level const removed
// otherwise.
// For example
//   MatchConstNess<const TfLiteTensor, TensorView>::Type == const TensorView
//   MatchConstNess<TfLiteTensor, TensorView>::Type == TensorView
//   MatchConstNess<TfLiteTensor, const TensorView>::Type == TensorView
template <typename I, typename O>
struct MatchConstNess {
  using Type = std::conditional_t<std::is_const<I>::value,
                                  std::add_const_t<O>, std::remove_const_t<O>>;
};

///////////////////////////// Implementation

239 template <typename DType>
AsScalar()240 DType &TensorView::AsScalar() {
241   DCHECK_EQ(shape_.size(), 0) << "Tensor is not a scalar";
242   return Data<DType>()[0];
243 }
244 
245 template <typename DType>
AsScalar()246 const DType &TensorView::AsScalar() const {
247   DCHECK_EQ(shape_.size(), 0) << "Tensor is not a scalar";
248   return Data<DType>().at(0);
249 }
250 
}  // namespace shim
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_SHIM_TENSOR_VIEW_H_