1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_SHIM_TENSOR_VIEW_H_
16 #define TENSORFLOW_LITE_KERNELS_SHIM_TENSOR_VIEW_H_
17
#include <array>
#include <cstddef>
#include <cstdint>
#include <type_traits>

#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tstring.h"
24
25 namespace tflite {
26 namespace shim {
27
// Type-level mapping from a wrapped tensor type `W` to the TensorView
// subclass that views it, exposed through a nested `Type` alias in
// specializations. This primary template is deliberately empty; the
// specializations for TF and TFLite provide the mapping:
//   ::tensorflow::Tensor -> tflite::shim::TfTensorView
//   ::TfLiteTensor       -> tflite::shim::TfLiteTensorView
template <class W>
struct TensorViewSubType {};
34
35 // Common denominator for ::tflite::TfLiteTensor and ::tensorflow::Tensor.
36 // It is a "view" over the underlying tensor without taking ownership.
37 // Objects of this class can also mutate the underlying tensor depending on
38 // whether the underlying tensor is "const" qualified or not.
39 //
40 // Movable and copyable.
41 // It can be instantiated with the New() factory function. eg.
42 // TfTensorView t = TensorView::New(&tf_tensor);
43 // const TfTensorView t = TensorView::New(&const_tf_tensor);
44 // TfLiteTensorView t = TensorView::New(&tflite_tensor);
45 // const TfLiteTensorView t = TensorView::New(&const_tflite_tensor);
46 class TensorView {
47 protected:
48 // Union over all data types
49 using DataVariantType =
50 absl::variant<absl::Span<bool>, absl::Span<uint8_t>, absl::Span<uint64_t>,
51 absl::Span<int8_t>, absl::Span<int16_t>,
52 absl::Span<int32_t>, absl::Span<int64_t>, absl::Span<float>,
53 absl::Span<double>, absl::Span<::tensorflow::tstring>>;
54
55 // An interface while provides convenient row-major indexing over the
56 // underlying tensor.
57 // Example usage:
58 //
59 // // A scalar view
60 // const TensorView t_float
61 // float val = t_float.AsScalar<float>();
62 //
63 // // A vector view
64 // const TensorView t_int;
65 // auto t_int_vec = t_int.As<int32_t, /*RANK=*/ 1>();
66 // int sum = t_int_vec(0) + t_int_vec(1);
67 //
68 // // A matrix view
69 // TensorView t_str;
70 // auto t_str_mat = t_str.As<tensorflow::tstring, /*RANK=*/ 2>();
71 // t_str_mat(0, 0) = "abc";
72 // t_str_mat(2, 3) = "def";
73 template <typename DType, int RANK>
74 class Tensor {
75 public:
Tensor(TensorView * t)76 explicit Tensor(TensorView *t)
77 : data_(t->Data<DType>()), shape_(t->Shape()) {
78 DCHECK_EQ(RANK, shape_.size());
79 ComputeRowSizes();
80 }
81
Tensor(const TensorView * t)82 explicit Tensor(const TensorView *t)
83 : data_(t->Data<DType>()), shape_(t->Shape()) {
84 DCHECK_EQ(RANK, shape_.size());
85 ComputeRowSizes();
86 }
87
88 // indexing operator
89 template <typename... IndexTypes>
operator()90 inline DType &operator()(IndexTypes... indices) {
91 const auto idx = RowMajorIndex(std::array<int, RANK>{{indices...}});
92 return data_[idx];
93 }
94
95 // const indexing operator
96 template <typename... IndexTypes>
operator()97 inline const DType &operator()(IndexTypes... indices) const {
98 const auto idx = RowMajorIndex(std::array<int, RANK>{{indices...}});
99 return data_.at(idx);
100 }
101
102 // Pointer accessor
Ptr()103 typename absl::Span<DType>::pointer Ptr() { return data_.data(); }
Ptr()104 constexpr typename absl::Span<DType>::const_pointer Ptr() const {
105 return data_.data();
106 }
107
108 // Size of the given dimension
Dim(int dim_i)109 inline int Dim(int dim_i) const {
110 DCHECK(RANK > 0 && dim_i < RANK) << "dim: " << dim_i << " rank:" << RANK;
111 // Handle negative indices
112 if (dim_i < 0) dim_i = ((dim_i % RANK) + RANK) % RANK;
113 return shape_[dim_i];
114 }
115
116 // The tensor's rank: number of dimensions
Rank()117 /*[[nodiscard]]*/ constexpr std::size_t Rank() const { return RANK; }
118
119 private:
120 // Computes the row-major index
RowMajorIndex(const std::array<int,RANK> & indices)121 inline std::size_t RowMajorIndex(
122 const std::array<int, RANK> &indices) const {
123 std::size_t ret = 0;
124 for (int i = 0; i < RANK; ++i) ret += indices[i] * row_sizes_[i];
125 return ret;
126 }
127
128 // Pre computes row sizes to convert multi dim indices into a row major
129 // index
ComputeRowSizes()130 void ComputeRowSizes() {
131 // Precompute row sizes for row major index computation
132 if (RANK > 0) {
133 row_sizes_[RANK - 1] = 1;
134 for (int i = RANK - 2; i >= 0; --i) {
135 row_sizes_[i] = row_sizes_[i + 1] * shape_[i + 1];
136 }
137 }
138 }
139
140 absl::Span<DType> data_;
141 const absl::Span<int> shape_;
142 std::size_t row_sizes_[RANK]{};
143 };
144
145 public:
146 // Factory which gets specialized for different wrapped tensor types.
147 template <typename W>
148 static absl::StatusOr<typename TensorViewSubType<W>::Type> New(
149 W *wrapped_tensor);
150
151 protected:
152 // Move constructor
153 TensorView(TensorView &&o) = default;
154 // Copy constructor
155 TensorView(const TensorView &o) = default;
156 // Move assignment operator
157 TensorView &operator=(TensorView &&o) = default;
158 // Copy assignment operator
159 TensorView &operator=(const TensorView &) = default;
160
161 public:
162 // Dtor
163 virtual ~TensorView() = default;
164
165 // Accessors
166
167 // Shape
Shape()168 absl::Span<int> Shape() { return shape_; }
Shape()169 /*[[nodiscard]]*/ const absl::Span<int> Shape() const { return shape_; }
170
171 // Data
172 template <typename DType>
Data()173 absl::Span<DType> &Data() {
174 return absl::get<absl::Span<DType>>(data_);
175 }
176 template <typename DType>
Data()177 constexpr absl::Span<DType> Data() const {
178 return absl::get<absl::Span<DType>>(data_);
179 }
180
181 // Reads the tensor given the dtype and its rank and provides an indexing
182 // operator.
183 template <typename DType, int RANK>
As()184 Tensor<DType, RANK> As() {
185 return Tensor<DType, RANK>(this);
186 }
187
188 // Const version of As()
189 template <typename DType, int RANK>
As()190 const Tensor<DType, RANK> As() const {
191 return Tensor<DType, RANK>(this);
192 }
193
194 // Read the given tensor as a scalar or return error if it isn't
195 template <typename DType>
196 DType &AsScalar();
197
198 template <typename DType>
199 const DType &AsScalar() const;
200
201 protected:
202 // Templated constructor. Since it's not possible to specify the template
203 // argument directly we place a dummy argument of that type so compiler
204 // can deduce the right template parameter
205 template <typename DType>
TensorView(const absl::Span<int> shape,void * data,const std::size_t data_size,const DType &)206 TensorView(const absl::Span<int> shape, void *data,
207 const std::size_t data_size, const DType &)
208 : shape_(shape),
209 data_(absl::Span<DType>(reinterpret_cast<DType *>(data),
210 data_size / sizeof(DType))) {}
211
212 // Return the total number of elements given the shape.
NumElements(const absl::Span<int> shape)213 static constexpr std::size_t NumElements(const absl::Span<int> shape) {
214 std::size_t ret = 1;
215 for (const auto dim : shape) ret *= dim;
216 return ret;
217 }
218
219 // Tensor shape
220 // Note: using int rather than size_t to avoid conversion to from TfLite shape
221 absl::Span<int> shape_;
222 // Tensor data
223 DataVariantType data_;
224 };
225
// Gives O the same top-level const-ness as I: `Type` is `const O` when I is
// const-qualified and O with any top-level const stripped otherwise.
// For example
//   MatchConstNess<const TfLiteTensor, TensorView>::Type == const TensorView
//   MatchConstNess<TfLiteTensor, TensorView>::Type == TensorView
//   MatchConstNess<TfLiteTensor, const TensorView>::Type == TensorView
//
// Primary template: I is not const-qualified, so drop const from O.
template <typename I, typename O>
struct MatchConstNess {
  using Type = std::remove_const_t<O>;
};

// Partial specialization selected whenever I is const-qualified.
template <typename I, typename O>
struct MatchConstNess<const I, O> {
  using Type = std::add_const_t<O>;
};
236
237 ///////////////////////////// Implementation
238
239 template <typename DType>
AsScalar()240 DType &TensorView::AsScalar() {
241 DCHECK_EQ(shape_.size(), 0) << "Tensor is not a scalar";
242 return Data<DType>()[0];
243 }
244
245 template <typename DType>
AsScalar()246 const DType &TensorView::AsScalar() const {
247 DCHECK_EQ(shape_.size(), 0) << "Tensor is not a scalar";
248 return Data<DType>().at(0);
249 }
250
251 } // namespace shim
252 } // namespace tflite
253
254 #endif // TENSORFLOW_LITE_KERNELS_SHIM_TENSOR_VIEW_H_
255