1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Utilities for saving/restoring tensor slice checkpoints. 17 18 #ifndef TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_ 19 #define TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_ 20 21 #include <string> // for string 22 #include "tensorflow/core/framework/tensor.pb.h" 23 #include "tensorflow/core/framework/tensor_slice.h" 24 #include "tensorflow/core/framework/types.h" 25 #include "tensorflow/core/lib/core/status.h" // for Status 26 #include "tensorflow/core/platform/protobuf.h" 27 28 namespace tensorflow { 29 30 namespace checkpoint { 31 32 // The key for the metadata in the tensor slice checkpoint files. It is "" so 33 // that the metadata is always at the beginning of a checkpoint file. 34 extern const char kSavedTensorSlicesKey[]; 35 36 // Encode a tensor name + a tensor slice into an ordered code and outputs it as 37 // a string. 38 // The format is 39 // <0> 40 // <tensor_name> 41 // <rank> 42 // <dim-0-start><dim-0-length> 43 // <dim-1-start><dim-1-length> 44 // ... 45 46 string EncodeTensorNameSlice(const string& name, 47 const tensorflow::TensorSlice& slice); 48 49 // Parse out the name and the slice from string encoded as an ordered code. 50 Status DecodeTensorNameSlice(const string& code, string* name, 51 tensorflow::TensorSlice* slice); 52 53 // Extracts the full shape, slice spec, and shape of the slice from 54 // "shape_and_slice". On non-OK return, caller must clear the out-arguments 55 // before reusing. 56 Status ParseShapeAndSlice(const string& shape_and_slice, TensorShape* shape, 57 TensorSlice* slice, TensorShape* shape_slice); 58 59 template <typename T> 60 struct SaveTypeTraits; 61 62 template <typename T> 63 int TensorProtoDataSize(const TensorProto& t); 64 65 template <typename T> 66 const typename SaveTypeTraits<T>::SavedType* TensorProtoData( 67 const TensorProto& t); 68 69 template <typename T> 70 typename SaveTypeTraits<T>::RepeatedField* MutableTensorProtoData( 71 TensorProto* t); 72 73 template <typename T> 74 void Fill(T* data, size_t n, TensorProto* t); 75 76 #define TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, STYPE) \ 77 template <> \ 78 struct SaveTypeTraits<TYPE> { \ 79 static constexpr bool supported = true; \ 80 typedef STYPE SavedType; \ 81 typedef protobuf::RepeatedField<FTYPE> RepeatedField; \ 82 }; \ 83 template <> \ 84 inline const STYPE* TensorProtoData<TYPE>(const TensorProto& t) { \ 85 static_assert(SaveTypeTraits<TYPE>::supported, \ 86 "Specified type " #TYPE " not supported for Restore"); \ 87 return reinterpret_cast<const STYPE*>(t.FIELD##_val().data()); \ 88 } \ 89 template <> \ 90 inline protobuf::RepeatedField<FTYPE>* MutableTensorProtoData<TYPE>( \ 91 TensorProto * t) { \ 92 static_assert(SaveTypeTraits<TYPE>::supported, \ 93 "Specified type " #TYPE " not supported for Save"); \ 94 return reinterpret_cast<protobuf::RepeatedField<FTYPE>*>( \ 95 t->mutable_##FIELD##_val()); \ 96 } 97 98 #define TENSOR_PROTO_EXTRACT_TYPE(TYPE, FIELD, FTYPE) \ 99 TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, FTYPE) \ 100 template <> \ 101 inline int TensorProtoDataSize<TYPE>(const TensorProto& t) { \ 102 return t.FIELD##_val_size(); \ 103 } \ 104 template <> \ 105 inline void Fill(const TYPE* data, size_t n, TensorProto* t) { \ 106 typename protobuf::RepeatedField<FTYPE> copy(data, data + n); \ 107 t->mutable_##FIELD##_val()->Swap(©); \ 108 } 109 110 // Complex needs special treatment since proto doesn't have native complex 111 #define TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(TYPE, FIELD, FTYPE) \ 112 TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, TYPE) \ 113 template <> \ 114 inline int TensorProtoDataSize<TYPE>(const TensorProto& t) { \ 115 return t.FIELD##_val_size() / 2; \ 116 } \ 117 template <> \ 118 inline void Fill(const TYPE* data, size_t n, TensorProto* t) { \ 119 const FTYPE* sub = reinterpret_cast<const FTYPE*>(data); \ 120 typename protobuf::RepeatedField<FTYPE> copy(sub, sub + 2 * n); \ 121 t->mutable_##FIELD##_val()->Swap(©); \ 122 } 123 124 TENSOR_PROTO_EXTRACT_TYPE(bool, bool, bool); 125 TENSOR_PROTO_EXTRACT_TYPE(float, float, float); 126 TENSOR_PROTO_EXTRACT_TYPE(double, double, double); 127 TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(complex64, scomplex, float); 128 TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(complex128, dcomplex, double); 129 TENSOR_PROTO_EXTRACT_TYPE(int32, int, int32); 130 TENSOR_PROTO_EXTRACT_TYPE(uint32, uint32, uint32); 131 TENSOR_PROTO_EXTRACT_TYPE(int64_t, int64, protobuf_int64); 132 TENSOR_PROTO_EXTRACT_TYPE(uint64, uint64, protobuf_uint64); 133 TENSOR_PROTO_EXTRACT_TYPE(uint16, int, int32); 134 TENSOR_PROTO_EXTRACT_TYPE(uint8, int, int32); 135 TENSOR_PROTO_EXTRACT_TYPE(int8, int, int32); 136 TENSOR_PROTO_EXTRACT_TYPE(int16, int, int32); 137 TENSOR_PROTO_EXTRACT_TYPE(qint8, int, int32); 138 TENSOR_PROTO_EXTRACT_TYPE(quint8, int, int32); 139 TENSOR_PROTO_EXTRACT_TYPE(quint16, int, int32); 140 141 #undef TENSOR_PROTO_EXTRACT_TYPE_COMPLEX 142 #undef TENSOR_PROTO_EXTRACT_TYPE_HELPER 143 #undef TENSOR_PROTO_EXTRACT_TYPE 144 145 // Custom implementation for qint32, based on the one for int32. 146 147 template <> 148 struct SaveTypeTraits<qint32> : SaveTypeTraits<int32> {}; 149 150 template <> 151 inline int TensorProtoDataSize<qint32>(const TensorProto& t) { 152 return t.int_val_size(); 153 } 154 155 template <> 156 inline const int32* TensorProtoData<qint32>(const TensorProto& t) { 157 static_assert(SaveTypeTraits<qint32>::supported, 158 "Specified type qint32 not supported for Restore"); 159 return reinterpret_cast<const int32*>(t.int_val().data()); 160 } 161 162 inline void Fill(const qint32* data, size_t n, TensorProto* t) { 163 const int32* p = reinterpret_cast<const int32*>(data); 164 typename protobuf::RepeatedField<int32> copy(p, p + n); 165 t->mutable_int_val()->Swap(©); 166 } 167 168 // Custom implementation for Eigen::half. 169 170 template <> 171 struct SaveTypeTraits<Eigen::half> { 172 static constexpr bool supported = true; 173 typedef int SavedType; 174 typedef protobuf::RepeatedField<int32> RepeatedField; 175 }; 176 177 template <> 178 inline int TensorProtoDataSize<Eigen::half>(const TensorProto& t) { 179 return t.half_val_size(); 180 } 181 182 template <> 183 inline const int* TensorProtoData<Eigen::half>(const TensorProto& t) { 184 return t.half_val().data(); 185 } 186 187 template <> 188 inline protobuf::RepeatedField<int32>* MutableTensorProtoData<Eigen::half>( 189 TensorProto* t) { 190 return t->mutable_half_val(); 191 } 192 193 template <> 194 inline void Fill(const Eigen::half* data, size_t n, TensorProto* t) { 195 typename protobuf::RepeatedField<int32>* val = t->mutable_half_val(); 196 val->Resize(n, 0); 197 for (size_t i = 0; i < n; ++i) { 198 val->Set(i, Eigen::numext::bit_cast<uint16>(data[i])); 199 } 200 } 201 202 // Custom implementation for string. 203 204 template <> 205 struct SaveTypeTraits<tstring> { 206 static constexpr bool supported = true; 207 typedef const string* SavedType; 208 typedef protobuf::RepeatedPtrField<string> RepeatedField; 209 }; 210 211 template <> 212 inline int TensorProtoDataSize<tstring>(const TensorProto& t) { 213 return t.string_val_size(); 214 } 215 216 template <> 217 inline const string* const* TensorProtoData<tstring>(const TensorProto& t) { 218 static_assert(SaveTypeTraits<tstring>::supported, 219 "Specified type tstring not supported for Restore"); 220 return t.string_val().data(); 221 } 222 223 template <> 224 inline protobuf::RepeatedPtrField<string>* MutableTensorProtoData<tstring>( 225 TensorProto* t) { 226 static_assert(SaveTypeTraits<tstring>::supported, 227 "Specified type tstring not supported for Save"); 228 return t->mutable_string_val(); 229 } 230 231 template <> 232 inline void Fill(const tstring* data, size_t n, TensorProto* t) { 233 typename protobuf::RepeatedPtrField<string> copy(data, data + n); 234 t->mutable_string_val()->Swap(©); 235 } 236 237 } // namespace checkpoint 238 239 } // namespace tensorflow 240 241 #endif // TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_ 242