1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_LITE_DELEGATES_SERIALIZATION_H_ 16 #define TENSORFLOW_LITE_DELEGATES_SERIALIZATION_H_ 17 18 #include <cstdint> 19 #include <map> 20 #include <string> 21 #include <utility> 22 #include <vector> 23 24 #include "tensorflow/lite/c/common.h" 25 26 // This file implements a serialization utility that TFLite delegates can use to 27 // read/write initialization data. 28 // 29 // Example code: 30 // 31 // Initialization 32 // ============== 33 // SerializationParams params; 34 // // Acts as a namespace for all data entries for a given model. 35 // // See StrFingerprint(). 36 // params.model_token = options->model_token; 37 // // Location where data is stored, should be private to the app using this. 38 // params.serialization_dir = options->serialization_dir; 39 // Serialization serialization(params); 40 // 41 // Writing data 42 // ============ 43 // TfLiteContext* context = ...; 44 // TfLiteDelegateParams* params = ...; 45 // SerializationEntry kernels_entry = serialization->GetEntryForKernel( 46 // "gpuv2_kernels", context, delegate_params); 47 // 48 // TfLiteStatus kernels_save_status = kernels_entry.SetData( 49 // reinterpret_cast<char*>(data_ptr), 50 // data_size); 51 // if (kernels_save_status == kTfLiteOk) { 52 // //...serialization successful... 53 // } else if (kernels_save_status == kTfLiteDelegateDataWriteError) { 54 // //...error in serializing data to disk... 55 // } else { 56 // //...unexpected error... 57 // } 58 // 59 // Reading data 60 // ============ 61 // std::string kernels_data; 62 // TfLiteStatus kernels_data_status = kernels_entry.GetData(&kernels_data); 63 // if (kernels_data_status == kTfLiteOk) { 64 // //...serialized data found... 65 // } else if (kernels_data_status == kTfLiteDelegateDataNotFound) { 66 // //...serialized data missing... 67 // } else { 68 // //...unexpected error... 69 // } 70 namespace tflite { 71 namespace delegates { 72 73 // Helper to generate a unique string (converted from 64-bit farmhash) given 74 // some data. Intended for use by: 75 // 76 // 1. Delegates, to 'fingerprint' some custom data (like options), 77 // and provide it as custom_key to Serialization::GetEntryForDelegate or 78 // GetEntryForKernel. 79 // 2. TFLite clients, to fingerprint a model flatbuffer & get a unique 80 // model_token. 81 std::string StrFingerprint(const void* data, const size_t num_bytes); 82 83 // Encapsulates a unique blob of data serialized by a delegate. 84 // Needs to be initialized with a Serialization instance. 85 // Any data set with this entry is 'keyed' by a 64-bit fingerprint unique to the 86 // parameters used during initialization via 87 // Serialization::GetEntryForDelegate/GetEntryForKernel. 88 // 89 // NOTE: TFLite cannot guarantee that the read data is always fully valid, 90 // especially if the directory is accessible to other applications/processes. 91 // It is the delegate's responsibility to validate the retrieved data. 92 class SerializationEntry { 93 public: 94 friend class Serialization; 95 96 // Returns a 64-bit fingerprint unique to the parameters provided during the 97 // generation of this SerializationEntry. 98 // Produces same value on every run. GetFingerprint()99 uint64_t GetFingerprint() const { return fingerprint_; } 100 101 // Stores `data` into a file that is unique to this SerializationKey. 102 // Overwrites any existing data if present. 103 // 104 // Returns: 105 // kTfLiteOk if data is successfully stored 106 // kTfLiteDelegateDataWriteError for data writing issues 107 // kTfLiteError for unexpected error. 108 // 109 // NOTE: We use a temp file & rename it as file renaming is an atomic 110 // operation in most systems. 111 TfLiteStatus SetData(TfLiteContext* context, const char* data, 112 const size_t size) const; 113 114 // Get `data` corresponding to this key, if available. 115 // 116 // Returns: 117 // kTfLiteOk if data is successfully stored 118 // kTfLiteDataError for data writing issues 119 // kTfLiteError for unexpected error. 120 TfLiteStatus GetData(TfLiteContext* context, std::string* data) const; 121 122 // Non-copyable. 123 SerializationEntry(const SerializationEntry&) = delete; 124 SerializationEntry& operator=(const SerializationEntry&) = delete; 125 SerializationEntry(SerializationEntry&& src) = default; 126 127 protected: 128 SerializationEntry(const std::string& cache_dir, 129 const std::string& model_token, 130 const uint64_t fingerprint_64); 131 132 // Caching directory. 133 const std::string cache_dir_; 134 // Model Token. 135 const std::string model_token_; 136 // For most applications, 64-bit fingerprints are enough. 137 const uint64_t fingerprint_ = 0; 138 }; 139 140 // Encapsulates all the data that clients can use to parametrize a Serialization 141 // interface. 142 typedef struct SerializationParams { 143 // Acts as a 'namespace' for all SerializationEntry instances. 144 // Clients should ensure that the token is unique to the model graph & data. 145 // StrFingerprint() can be used with the flatbuffer data to generate a unique 146 // 64-bit token. 147 // TODO(b/190055017): Add 64-bit fingerprints to TFLite flatbuffers to ensure 148 // different model constants automatically lead to different fingerprints. 149 // Required. 150 const char* model_token; 151 // Denotes the directory to be used to store data. 152 // It is the client's responsibility to ensure this location is valid and 153 // application-specific to avoid unintended data access issues. 154 // On Android, `getCodeCacheDir()` is recommended. 155 // Required. 156 const char* cache_dir; 157 } SerializationParams; 158 159 // Utility to enable caching abilities for delegates. 160 // See documentation at the top of the file for usage details. 161 // 162 // WARNING: Experimental interface, subject to change. 163 class Serialization { 164 public: 165 // Initialize a Serialization interface for applicable delegates. Serialization(const SerializationParams & params)166 explicit Serialization(const SerializationParams& params) 167 : cache_dir_(params.cache_dir), model_token_(params.model_token) {} 168 169 // Generate a SerializationEntry that incorporates both `custom_key` & 170 // `context` into its unique fingerprint. 171 // Should be used to handle data common to all delegate kernels. 172 // Delegates can incorporate versions & init arguments in custom_key using 173 // StrFingerprint(). GetEntryForDelegate(const std::string & custom_key,TfLiteContext * context)174 SerializationEntry GetEntryForDelegate(const std::string& custom_key, 175 TfLiteContext* context) { 176 return GetEntryImpl(custom_key, context); 177 } 178 179 // Generate a SerializationEntry that incorporates `custom_key`, `context`, 180 // and `delegate_params` into its unique fingerprint. 181 // Should be used to handle data specific to a delegate kernel, since 182 // the context+delegate_params combination is node-specific. 183 // Delegates can incorporate versions & init arguments in custom_key using 184 // StrFingerprint(). GetEntryForKernel(const std::string & custom_key,TfLiteContext * context,const TfLiteDelegateParams * partition_params)185 SerializationEntry GetEntryForKernel( 186 const std::string& custom_key, TfLiteContext* context, 187 const TfLiteDelegateParams* partition_params) { 188 return GetEntryImpl(custom_key, context, partition_params); 189 } 190 191 // Non-copyable. 192 Serialization(const Serialization&) = delete; 193 Serialization& operator=(const Serialization&) = delete; 194 195 protected: 196 SerializationEntry GetEntryImpl( 197 const std::string& custom_key, TfLiteContext* context = nullptr, 198 const TfLiteDelegateParams* delegate_params = nullptr); 199 200 const std::string cache_dir_; 201 const std::string model_token_; 202 }; 203 204 // Helper for delegates to save their delegation decisions (which nodes to 205 // delegate) in TfLiteDelegate::Prepare(). 206 // Internally, this uses a unique SerializationEntry based on the `context` & 207 // `delegate_id` to save the `node_ids`. It is recommended that `delegate_id` be 208 // unique to a backend/version to avoid reading back stale delegation decisions. 209 // 210 // NOTE: This implementation is platform-specific, so this method & the 211 // subsequent call to GetDelegatedNodes should happen on the same device. 212 TfLiteStatus SaveDelegatedNodes(TfLiteContext* context, 213 Serialization* serialization, 214 const std::string& delegate_id, 215 const TfLiteIntArray* node_ids); 216 217 // Retrieves list of delegated nodes that were saved earlier with 218 // SaveDelegatedNodes. 219 // Caller assumes ownership of data pointed by *nodes_ids. 220 // 221 // NOTE: This implementation is platform-specific, so SaveDelegatedNodes & 222 // corresponding GetDelegatedNodes should be called on the same device. 223 TfLiteStatus GetDelegatedNodes(TfLiteContext* context, 224 Serialization* serialization, 225 const std::string& delegate_id, 226 TfLiteIntArray** node_ids); 227 228 } // namespace delegates 229 } // namespace tflite 230 231 #endif // TENSORFLOW_LITE_DELEGATES_SERIALIZATION_H_ 232