xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/serialization.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_DELEGATES_SERIALIZATION_H_
16 #define TENSORFLOW_LITE_DELEGATES_SERIALIZATION_H_
17 
18 #include <cstdint>
19 #include <map>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "tensorflow/lite/c/common.h"
25 
26 // This file implements a serialization utility that TFLite delegates can use to
27 // read/write initialization data.
28 //
29 // Example code:
30 //
31 // Initialization
32 // ==============
33 // SerializationParams params;
34 // // Acts as a namespace for all data entries for a given model.
35 // // See StrFingerprint().
36 // params.model_token = options->model_token;
37 // // Location where data is stored, should be private to the app using this.
// params.cache_dir = options->serialization_dir;
39 // Serialization serialization(params);
40 //
41 // Writing data
42 // ============
// TfLiteContext* context = ...;
// const TfLiteDelegateParams* delegate_params = ...;
// SerializationEntry kernels_entry = serialization.GetEntryForKernel(
//     "gpuv2_kernels", context, delegate_params);
//
// TfLiteStatus kernels_save_status = kernels_entry.SetData(
//     context, reinterpret_cast<const char*>(data_ptr), data_size);
51 // if (kernels_save_status == kTfLiteOk) {
52 //   //...serialization successful...
53 // } else if (kernels_save_status == kTfLiteDelegateDataWriteError) {
54 //   //...error in serializing data to disk...
55 // } else {
56 //   //...unexpected error...
57 // }
58 //
59 // Reading data
60 // ============
61 // std::string kernels_data;
// TfLiteStatus kernels_data_status =
//     kernels_entry.GetData(context, &kernels_data);
63 // if (kernels_data_status == kTfLiteOk) {
64 //   //...serialized data found...
65 // } else if (kernels_data_status == kTfLiteDelegateDataNotFound) {
66 //   //...serialized data missing...
67 // } else {
68 //   //...unexpected error...
69 // }
70 namespace tflite {
71 namespace delegates {
72 
// Produces a fingerprint string for the `num_bytes` bytes starting at `data`.
// The string is the textual form of a 64-bit farmhash of that data.
//
// Typical callers:
//   * Delegates that want to fold custom data (for example, their options)
//     into the `custom_key` handed to Serialization::GetEntryForDelegate or
//     Serialization::GetEntryForKernel.
//   * TFLite clients that need a `model_token` and derive it from the model
//     flatbuffer.
std::string StrFingerprint(const void* data, size_t num_bytes);
82 
83 // Encapsulates a unique blob of data serialized by a delegate.
84 // Needs to be initialized with a Serialization instance.
85 // Any data set with this entry is 'keyed' by a 64-bit fingerprint unique to the
86 // parameters used during initialization via
87 // Serialization::GetEntryForDelegate/GetEntryForKernel.
88 //
89 // NOTE: TFLite cannot guarantee that the read data is always fully valid,
90 // especially if the directory is accessible to other applications/processes.
91 // It is the delegate's responsibility to validate the retrieved data.
92 class SerializationEntry {
93  public:
94   friend class Serialization;
95 
96   // Returns a 64-bit fingerprint unique to the parameters provided during the
97   // generation of this SerializationEntry.
98   // Produces same value on every run.
GetFingerprint()99   uint64_t GetFingerprint() const { return fingerprint_; }
100 
101   // Stores `data` into a file that is unique to this SerializationKey.
102   // Overwrites any existing data if present.
103   //
104   // Returns:
105   //   kTfLiteOk if data is successfully stored
106   //   kTfLiteDelegateDataWriteError for data writing issues
107   //   kTfLiteError for unexpected error.
108   //
109   // NOTE: We use a temp file & rename it as file renaming is an atomic
110   // operation in most systems.
111   TfLiteStatus SetData(TfLiteContext* context, const char* data,
112                        const size_t size) const;
113 
114   // Get `data` corresponding to this key, if available.
115   //
116   // Returns:
117   //   kTfLiteOk if data is successfully stored
118   //   kTfLiteDataError for data writing issues
119   //   kTfLiteError for unexpected error.
120   TfLiteStatus GetData(TfLiteContext* context, std::string* data) const;
121 
122   // Non-copyable.
123   SerializationEntry(const SerializationEntry&) = delete;
124   SerializationEntry& operator=(const SerializationEntry&) = delete;
125   SerializationEntry(SerializationEntry&& src) = default;
126 
127  protected:
128   SerializationEntry(const std::string& cache_dir,
129                      const std::string& model_token,
130                      const uint64_t fingerprint_64);
131 
132   // Caching directory.
133   const std::string cache_dir_;
134   // Model Token.
135   const std::string model_token_;
136   // For most applications, 64-bit fingerprints are enough.
137   const uint64_t fingerprint_ = 0;
138 };
139 
// Encapsulates all the data that clients can use to parametrize a Serialization
// interface.
//
// Both fields default to nullptr so that a field the client forgot to set is
// detectable rather than holding an indeterminate pointer.
struct SerializationParams {
  // Acts as a 'namespace' for all SerializationEntry instances.
  // Clients should ensure that the token is unique to the model graph & data.
  // StrFingerprint() can be used with the flatbuffer data to generate a unique
  // 64-bit token.
  // TODO(b/190055017): Add 64-bit fingerprints to TFLite flatbuffers to ensure
  // different model constants automatically lead to different fingerprints.
  // Required.
  const char* model_token = nullptr;
  // Denotes the directory to be used to store data.
  // It is the client's responsibility to ensure this location is valid and
  // application-specific to avoid unintended data access issues.
  // On Android, `getCodeCacheDir()` is recommended.
  // Required.
  const char* cache_dir = nullptr;
};
158 
159 // Utility to enable caching abilities for delegates.
160 // See documentation at the top of the file for usage details.
161 //
162 // WARNING: Experimental interface, subject to change.
163 class Serialization {
164  public:
165   // Initialize a Serialization interface for applicable delegates.
Serialization(const SerializationParams & params)166   explicit Serialization(const SerializationParams& params)
167       : cache_dir_(params.cache_dir), model_token_(params.model_token) {}
168 
169   // Generate a SerializationEntry that incorporates both `custom_key` &
170   // `context` into its unique fingerprint.
171   //  Should be used to handle data common to all delegate kernels.
172   // Delegates can incorporate versions & init arguments in custom_key using
173   // StrFingerprint().
GetEntryForDelegate(const std::string & custom_key,TfLiteContext * context)174   SerializationEntry GetEntryForDelegate(const std::string& custom_key,
175                                          TfLiteContext* context) {
176     return GetEntryImpl(custom_key, context);
177   }
178 
179   // Generate a SerializationEntry that incorporates `custom_key`, `context`,
180   // and `delegate_params` into its unique fingerprint.
181   // Should be used to handle data specific to a delegate kernel, since
182   // the context+delegate_params combination is node-specific.
183   // Delegates can incorporate versions & init arguments in custom_key using
184   // StrFingerprint().
GetEntryForKernel(const std::string & custom_key,TfLiteContext * context,const TfLiteDelegateParams * partition_params)185   SerializationEntry GetEntryForKernel(
186       const std::string& custom_key, TfLiteContext* context,
187       const TfLiteDelegateParams* partition_params) {
188     return GetEntryImpl(custom_key, context, partition_params);
189   }
190 
191   // Non-copyable.
192   Serialization(const Serialization&) = delete;
193   Serialization& operator=(const Serialization&) = delete;
194 
195  protected:
196   SerializationEntry GetEntryImpl(
197       const std::string& custom_key, TfLiteContext* context = nullptr,
198       const TfLiteDelegateParams* delegate_params = nullptr);
199 
200   const std::string cache_dir_;
201   const std::string model_token_;
202 };
203 
204 // Helper for delegates to save their delegation decisions (which nodes to
205 // delegate) in TfLiteDelegate::Prepare().
206 // Internally, this uses a unique SerializationEntry based on the `context` &
207 // `delegate_id` to save the `node_ids`. It is recommended that `delegate_id` be
208 // unique to a backend/version to avoid reading back stale delegation decisions.
209 //
210 // NOTE: This implementation is platform-specific, so this method & the
211 // subsequent call to GetDelegatedNodes should happen on the same device.
212 TfLiteStatus SaveDelegatedNodes(TfLiteContext* context,
213                                 Serialization* serialization,
214                                 const std::string& delegate_id,
215                                 const TfLiteIntArray* node_ids);
216 
217 // Retrieves list of delegated nodes that were saved earlier with
218 // SaveDelegatedNodes.
219 // Caller assumes ownership of data pointed by *nodes_ids.
220 //
221 // NOTE: This implementation is platform-specific, so SaveDelegatedNodes &
222 // corresponding GetDelegatedNodes should be called on the same device.
223 TfLiteStatus GetDelegatedNodes(TfLiteContext* context,
224                                Serialization* serialization,
225                                const std::string& delegate_id,
226                                TfLiteIntArray** node_ids);
227 
228 }  // namespace delegates
229 }  // namespace tflite
230 
231 #endif  // TENSORFLOW_LITE_DELEGATES_SERIALIZATION_H_
232