xref: /aosp_15_r20/external/tensorflow/tensorflow/core/util/saved_tensor_slice_util.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Utilities for saving/restoring tensor slice checkpoints.
17 
18 #ifndef TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
19 #define TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
20 
21 #include <string>  // for string
22 #include "tensorflow/core/framework/tensor.pb.h"
23 #include "tensorflow/core/framework/tensor_slice.h"
24 #include "tensorflow/core/framework/types.h"
25 #include "tensorflow/core/lib/core/status.h"  // for Status
26 #include "tensorflow/core/platform/protobuf.h"
27 
28 namespace tensorflow {
29 
30 namespace checkpoint {
31 
32 // The key for the metadata in the tensor slice checkpoint files. It is "" so
33 // that the metadata is always at the beginning of a checkpoint file.
34 extern const char kSavedTensorSlicesKey[];
35 
36 // Encode a tensor name + a tensor slice into an ordered code and outputs it as
37 // a string.
38 // The format is
39 //  <0>
40 //  <tensor_name>
41 //  <rank>
42 //  <dim-0-start><dim-0-length>
43 //  <dim-1-start><dim-1-length>
44 //  ...
45 
46 string EncodeTensorNameSlice(const string& name,
47                              const tensorflow::TensorSlice& slice);
48 
49 // Parse out the name and the slice from string encoded as an ordered code.
50 Status DecodeTensorNameSlice(const string& code, string* name,
51                              tensorflow::TensorSlice* slice);
52 
53 // Extracts the full shape, slice spec, and shape of the slice from
54 // "shape_and_slice".  On non-OK return, caller must clear the out-arguments
55 // before reusing.
56 Status ParseShapeAndSlice(const string& shape_and_slice, TensorShape* shape,
57                           TensorSlice* slice, TensorShape* shape_slice);
58 
// Describes how elements of type T are stored in a TensorProto.
//
// Only declared here: this header supplies explicit specializations for the
// types it handles, and each of them exposes
//   supported     - compile-time bool, true in every specialization,
//   SavedType     - element type seen through TensorProtoData<T>(),
//   RepeatedField - type of the repeated proto field that holds the values.
template <class T>
struct SaveTypeTraits;
61 
62 template <typename T>
63 int TensorProtoDataSize(const TensorProto& t);
64 
65 template <typename T>
66 const typename SaveTypeTraits<T>::SavedType* TensorProtoData(
67     const TensorProto& t);
68 
69 template <typename T>
70 typename SaveTypeTraits<T>::RepeatedField* MutableTensorProtoData(
71     TensorProto* t);
72 
73 template <typename T>
74 void Fill(T* data, size_t n, TensorProto* t);
75 
// Emits the pieces shared by every type whose values live in a repeated
// `<FIELD>_val` field of TensorProto:
//   * SaveTypeTraits<TYPE>, with SavedType = STYPE and
//     RepeatedField = protobuf::RepeatedField<FTYPE>;
//   * TensorProtoData<TYPE>(), which returns the field's data pointer
//     reinterpreted as `const STYPE*`;
//   * MutableTensorProtoData<TYPE>(), which returns the mutable field
//     reinterpreted as `protobuf::RepeatedField<FTYPE>*`.
//
// Arguments:
//   TYPE  - C++ element type being specialized.
//   FIELD - accessor prefix; the proto field used is `FIELD##_val`.
//   FTYPE - element type of the repeated proto field.
//   STYPE - element type exposed by TensorProtoData<TYPE>().
#define TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, STYPE)           \
  template <>                                                                 \
  struct SaveTypeTraits<TYPE> {                                               \
    using SavedType = STYPE;                                                  \
    using RepeatedField = protobuf::RepeatedField<FTYPE>;                     \
    static constexpr bool supported = true;                                   \
  };                                                                          \
  template <>                                                                 \
  inline const STYPE* TensorProtoData<TYPE>(const TensorProto& t) {           \
    static_assert(SaveTypeTraits<TYPE>::supported,                            \
                  "Specified type " #TYPE " not supported for Restore");      \
    const auto& stored_values = t.FIELD##_val();                              \
    return reinterpret_cast<const STYPE*>(stored_values.data());              \
  }                                                                           \
  template <>                                                                 \
  inline protobuf::RepeatedField<FTYPE>* MutableTensorProtoData<TYPE>(        \
      TensorProto* t) {                                                       \
    static_assert(SaveTypeTraits<TYPE>::supported,                            \
                  "Specified type " #TYPE " not supported for Save");         \
    auto* const mutable_values = t->mutable_##FIELD##_val();                  \
    return reinterpret_cast<protobuf::RepeatedField<FTYPE>*>(mutable_values); \
  }
97 
// Full set of specializations for a type stored one-element-per-entry in the
// repeated `<FIELD>_val` field (element type FTYPE):
//   * everything from TENSOR_PROTO_EXTRACT_TYPE_HELPER, with SavedType = FTYPE;
//   * TensorProtoDataSize<TYPE>(), the number of entries in the field;
//   * Fill(), which builds a new repeated field from data[0..n) and swaps it
//     into the proto, so the field's previous contents are replaced.
#define TENSOR_PROTO_EXTRACT_TYPE(TYPE, FIELD, FTYPE)              \
  TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, FTYPE)      \
  template <>                                                      \
  inline int TensorProtoDataSize<TYPE>(const TensorProto& t) {     \
    const int element_count = t.FIELD##_val_size();                \
    return element_count;                                          \
  }                                                                \
  template <>                                                      \
  inline void Fill(const TYPE* data, size_t n, TensorProto* t) {   \
    const TYPE* const data_end = data + n;                         \
    protobuf::RepeatedField<FTYPE> replacement(data, data_end);    \
    t->mutable_##FIELD##_val()->Swap(&replacement);                \
  }
109 
// Complex needs special treatment since proto doesn't have native complex.
// Each complex element occupies two consecutive FTYPE entries of the repeated
// `<FIELD>_val` field, therefore:
//   * SavedType is TYPE itself (the field data is reinterpreted as TYPE);
//   * TensorProtoDataSize<TYPE>() is half the number of entries in the field;
//   * Fill() views data[0..n) as 2 * n FTYPE scalars, copies them into a new
//     repeated field and swaps that into the proto.
#define TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(TYPE, FIELD, FTYPE)          \
  TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, TYPE)           \
  template <>                                                          \
  inline int TensorProtoDataSize<TYPE>(const TensorProto& t) {         \
    const int scalar_count = t.FIELD##_val_size();                     \
    return scalar_count / 2;                                           \
  }                                                                    \
  template <>                                                          \
  inline void Fill(const TYPE* data, size_t n, TensorProto* t) {       \
    const FTYPE* const first = reinterpret_cast<const FTYPE*>(data);   \
    const FTYPE* const last = first + 2 * n;                           \
    protobuf::RepeatedField<FTYPE> replacement(first, last);           \
    t->mutable_##FIELD##_val()->Swap(&replacement);                    \
  }
123 
124 TENSOR_PROTO_EXTRACT_TYPE(bool, bool, bool);
125 TENSOR_PROTO_EXTRACT_TYPE(float, float, float);
126 TENSOR_PROTO_EXTRACT_TYPE(double, double, double);
127 TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(complex64, scomplex, float);
128 TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(complex128, dcomplex, double);
129 TENSOR_PROTO_EXTRACT_TYPE(int32, int, int32);
130 TENSOR_PROTO_EXTRACT_TYPE(uint32, uint32, uint32);
131 TENSOR_PROTO_EXTRACT_TYPE(int64_t, int64, protobuf_int64);
132 TENSOR_PROTO_EXTRACT_TYPE(uint64, uint64, protobuf_uint64);
133 TENSOR_PROTO_EXTRACT_TYPE(uint16, int, int32);
134 TENSOR_PROTO_EXTRACT_TYPE(uint8, int, int32);
135 TENSOR_PROTO_EXTRACT_TYPE(int8, int, int32);
136 TENSOR_PROTO_EXTRACT_TYPE(int16, int, int32);
137 TENSOR_PROTO_EXTRACT_TYPE(qint8, int, int32);
138 TENSOR_PROTO_EXTRACT_TYPE(quint8, int, int32);
139 TENSOR_PROTO_EXTRACT_TYPE(quint16, int, int32);
140 
141 #undef TENSOR_PROTO_EXTRACT_TYPE_COMPLEX
142 #undef TENSOR_PROTO_EXTRACT_TYPE_HELPER
143 #undef TENSOR_PROTO_EXTRACT_TYPE
144 
145 // Custom implementation for qint32, based on the one for int32.
146 
147 template <>
148 struct SaveTypeTraits<qint32> : SaveTypeTraits<int32> {};
149 
150 template <>
151 inline int TensorProtoDataSize<qint32>(const TensorProto& t) {
152   return t.int_val_size();
153 }
154 
155 template <>
156 inline const int32* TensorProtoData<qint32>(const TensorProto& t) {
157   static_assert(SaveTypeTraits<qint32>::supported,
158                 "Specified type qint32 not supported for Restore");
159   return reinterpret_cast<const int32*>(t.int_val().data());
160 }
161 
162 inline void Fill(const qint32* data, size_t n, TensorProto* t) {
163   const int32* p = reinterpret_cast<const int32*>(data);
164   typename protobuf::RepeatedField<int32> copy(p, p + n);
165   t->mutable_int_val()->Swap(&copy);
166 }
167 
168 // Custom implementation for Eigen::half.
169 
170 template <>
171 struct SaveTypeTraits<Eigen::half> {
172   static constexpr bool supported = true;
173   typedef int SavedType;
174   typedef protobuf::RepeatedField<int32> RepeatedField;
175 };
176 
177 template <>
178 inline int TensorProtoDataSize<Eigen::half>(const TensorProto& t) {
179   return t.half_val_size();
180 }
181 
182 template <>
183 inline const int* TensorProtoData<Eigen::half>(const TensorProto& t) {
184   return t.half_val().data();
185 }
186 
187 template <>
188 inline protobuf::RepeatedField<int32>* MutableTensorProtoData<Eigen::half>(
189     TensorProto* t) {
190   return t->mutable_half_val();
191 }
192 
193 template <>
194 inline void Fill(const Eigen::half* data, size_t n, TensorProto* t) {
195   typename protobuf::RepeatedField<int32>* val = t->mutable_half_val();
196   val->Resize(n, 0);
197   for (size_t i = 0; i < n; ++i) {
198     val->Set(i, Eigen::numext::bit_cast<uint16>(data[i]));
199   }
200 }
201 
202 // Custom implementation for string.
203 
204 template <>
205 struct SaveTypeTraits<tstring> {
206   static constexpr bool supported = true;
207   typedef const string* SavedType;
208   typedef protobuf::RepeatedPtrField<string> RepeatedField;
209 };
210 
211 template <>
212 inline int TensorProtoDataSize<tstring>(const TensorProto& t) {
213   return t.string_val_size();
214 }
215 
216 template <>
217 inline const string* const* TensorProtoData<tstring>(const TensorProto& t) {
218   static_assert(SaveTypeTraits<tstring>::supported,
219                 "Specified type tstring not supported for Restore");
220   return t.string_val().data();
221 }
222 
223 template <>
224 inline protobuf::RepeatedPtrField<string>* MutableTensorProtoData<tstring>(
225     TensorProto* t) {
226   static_assert(SaveTypeTraits<tstring>::supported,
227                 "Specified type tstring not supported for Save");
228   return t->mutable_string_val();
229 }
230 
231 template <>
232 inline void Fill(const tstring* data, size_t n, TensorProto* t) {
233   typename protobuf::RepeatedPtrField<string> copy(data, data + n);
234   t->mutable_string_val()->Swap(&copy);
235 }
236 
237 }  // namespace checkpoint
238 
239 }  // namespace tensorflow
240 
241 #endif  // TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
242