/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DATA_DATASET_UTILS_H_
#define TENSORFLOW_CORE_DATA_DATASET_UTILS_H_

#include <atomic>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {
namespace data {

// Constant used to indicate that the argument of tf.data.Dataset.shard should
// be supplied by the auto-sharding rewrite.
constexpr int kShardHint = -1;

// The initial parallelism value before Autotune has a chance to optimize.
constexpr int kAutotuneDefaultParallelism = 16;

// Creates a resource handle with a unique name for the given resource, where
// the resource is managed by the Resource Manager.
template <typename T>
Status CreateWeakHandle(OpKernelContext* ctx, T* resource,
                        const string& container_name, ResourceHandle* handle) {
  static std::atomic<int64_t> resource_id_counter(0);
  string unique_name =
      strings::StrCat(container_name, resource_id_counter.fetch_add(1));
  ResourceMgr* mgr = ctx->resource_manager();
  TF_RETURN_IF_ERROR(mgr->Create<T>(container_name, unique_name, resource));

  *handle = MakeResourceHandle(container_name, unique_name, *ctx->device(),
                               TypeIndex::Make<T>());
  return OkStatus();
}

// Creates a ref-counting resource handle for the given resource, where the
// resource is owned by the handle.
template <typename T>
Status CreateHandle(OpKernelContext* ctx, T* resource, ResourceHandle* handle) {
  ResourceMgr* mgr = ctx->resource_manager();
  *handle =
      ResourceHandle::MakeRefCountingHandle(resource, ctx->device()->name());
  TF_RETURN_IF_ERROR(
      mgr->CreateUnowned<T>(handle->container(), handle->name(), resource));
  return OkStatus();
}
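
// Example (illustrative sketch, not part of this header): creating a
// ref-counting handle for a hypothetical resource type `MyResource` inside an
// op kernel's Compute method.
//
//   MyResource* resource = new MyResource(/*...*/);
//   ResourceHandle handle;
//   OP_REQUIRES_OK(ctx, CreateHandle(ctx, resource, &handle));
//   // The handle now owns `resource`; no manual deletion is needed.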

// TODO(b/198162355): Merge this class with ResourceOpKernel.
template <typename T>
class AnonymousResourceOp : public OpKernel {
 public:
  // Creates an AnonymousResourceOp.
  // ref_counting: Determines if the op returns a ref-counting ResourceHandle.
  // See go/tf-resource-handle-ref-count.
  // return_deleter: Determines if the op outputs a deleter tensor in addition
  // to the resource handle tensor. If the resource handle is ref-counting, a
  // no-op deleter is returned.
  explicit AnonymousResourceOp(OpKernelConstruction* context, bool ref_counting,
                               bool return_deleter)
      : OpKernel(context),
        ref_counting_(ref_counting),
        return_deleter_(return_deleter) {}

  void Compute(OpKernelContext* ctx) override {
    FunctionLibraryRuntime* lib;
    std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
    OP_REQUIRES_OK(
        ctx, ctx->function_library()->Clone(&flib_def, &pflr, &lib, true));
    T* resource;
    OP_REQUIRES_OK(ctx, CreateResource(ctx, std::move(flib_def),
                                       std::move(pflr), lib, &resource));

    ResourceHandle handle;
    if (ref_counting_) {
      OP_REQUIRES_OK(ctx, CreateHandle(ctx, resource, &handle));
    } else {
      OP_REQUIRES_OK(ctx, CreateWeakHandle(ctx, resource, name(), &handle));
    }
    Tensor* handle_t;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle_t));
    handle_t->scalar<ResourceHandle>()() = handle;

    if (return_deleter_) {
      Tensor* deleter_t;
      AllocatorAttributes attr;
      attr.set_on_host(true);
      OP_REQUIRES_OK(
          ctx, ctx->allocate_output(1, TensorShape({}), &deleter_t, attr));
      // TODO(feyu): Consider returning an OptionalVariant.
      if (!ref_counting_) {
        // A deleter output that deletes the resource when destroyed.
        deleter_t->scalar<Variant>()() =
            ResourceDeleter(handle, ctx->resource_manager());
      }
    }
  }

 protected:
  virtual string name() = 0;

  virtual Status CreateResource(
      OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
      std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
      FunctionLibraryRuntime* lib, T** resource) = 0;

 private:
  const bool ref_counting_;
  const bool return_deleter_;
};
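
// Example (illustrative sketch, with hypothetical names): a minimal subclass
// of AnonymousResourceOp for a resource type `MyResource`.
//
//   class AnonymousMyResourceOp : public AnonymousResourceOp<MyResource> {
//    public:
//     explicit AnonymousMyResourceOp(OpKernelConstruction* ctx)
//         : AnonymousResourceOp(ctx, /*ref_counting=*/true,
//                               /*return_deleter=*/false) {}
//
//    protected:
//     string name() override { return "AnonymousMyResource"; }
//
//     Status CreateResource(
//         OpKernelContext* ctx,
//         std::unique_ptr<FunctionLibraryDefinition> flib_def,
//         std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
//         FunctionLibraryRuntime* lib, MyResource** resource) override {
//       *resource = new MyResource();
//       return OkStatus();
//     }
//   };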

// Returns Status::OK() if `expected` and `received` types match,
// errors::InvalidArgument otherwise.
Status VerifyTypesMatch(const DataTypeVector& expected,
                        const DataTypeVector& received);

Status VerifyTypesMatch(const DataTypeVector& expected,
                        const std::vector<Tensor>& received);

// Returns Status::OK() if `expected` and `received` shapes are compatible,
// errors::InvalidArgument otherwise.
Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<PartialTensorShape>& received);

Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<Tensor>& received);

// Dataset op level determinism policy.
class DeterminismPolicy {
 public:
  enum class Type : int {
    // The op must produce elements deterministically.
    kDeterministic,
    // The op may relax determinism to improve performance.
    kNondeterministic,
    // The determinism policy is not specified at the op level. In this case we
    // use the `experimental_deterministic` dataset option to determine the
    // determinism policy.
    kDefault,
  };
  static constexpr const char* const kDeterministic = "true";
  static constexpr const char* const kNondeterministic = "false";
  static constexpr const char* const kDefault = "default";

  DeterminismPolicy() : determinism_(Type::kDefault) {}
  explicit DeterminismPolicy(Type determinism) : determinism_(determinism) {}
  // Creates a DeterminismPolicy with Type kDeterministic or
  // kNondeterministic, depending on the value of `is_deterministic`.
  explicit DeterminismPolicy(bool is_deterministic);

  static Status FromString(const std::string& s, DeterminismPolicy* out);

  // Returns the string representing the determinism policy. This will be one
  // of the string constants defined above.
  std::string String() const;

  // Convenience methods for checking the DeterminismPolicy::Type.
  bool IsDeterministic() const { return determinism_ == Type::kDeterministic; }
  bool IsNondeterministic() const {
    return determinism_ == Type::kNondeterministic;
  }
  bool IsDefault() const { return determinism_ == Type::kDefault; }

 private:
  Type determinism_;
};
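
// Example (illustrative sketch): parsing a determinism policy from one of the
// string constants above and branching on the result.
//
//   DeterminismPolicy policy;
//   TF_RETURN_IF_ERROR(
//       DeterminismPolicy::FromString(DeterminismPolicy::kDefault, &policy));
//   if (policy.IsDefault()) {
//     // Fall back to the `experimental_deterministic` dataset option.
//   }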

// Resolves non-deterministic seeds if necessary, returning either the original
// seeds or the resolved seeds.
//
// By TensorFlow convention, if both seeds are 0, they should be replaced with
// non-deterministically chosen seeds.
std::pair<int64_t, int64_t> MaybeOverrideSeeds(
    std::pair<int64_t, int64_t> seeds);
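
// For example, MaybeOverrideSeeds({1, 2}) returns {1, 2} unchanged, while
// MaybeOverrideSeeds({0, 0}) returns a pair of non-deterministically chosen
// seeds.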

// Adds the functions in `to_add` to `base`. If a function with a matching
// signature already exists in `base`, replaces it with the function from
// `to_add`.
Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionLibraryDefinition& to_add);
Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionDefLibrary& to_add);

// Determines whether the given function is stateful.
Status IsFunctionStateful(const FunctionLibraryDefinition& library,
                          const FunctionDef& function_def);

// Determines whether the given node is stateful.
Status IsNodeStateful(const FunctionLibraryDefinition& library,
                      const NodeDef& node);

// Creates a runner that runs functions with limited parallelism.
std::function<void(std::function<void()>)> RunnerWithMaxParallelism(
    std::function<void(std::function<void()>)> runner, int max_parallelism);
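
// Example (illustrative sketch): wrapping a context's runner so that at most
// two scheduled closures execute concurrently.
//
//   auto limited_runner =
//       RunnerWithMaxParallelism(*ctx->runner(), /*max_parallelism=*/2);
//   limited_runner([] { /* ... work ... */ });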

// Op for creating a typed dummy resource.
//
// This op is used to provide a resource "placeholder" for ops such as
// `CacheDatasetV2` or `ShuffleDatasetV2` that expect a resource input.
// Originally, the lifetime of the resources passed into these ops was managed
// externally. After the implementation changed to manage the lifetime of the
// resources (including creation) by the ops themselves, the resource input is
// only needed to pass a resource handle through graph rewrites. When they are
// invoked from user code, the implementation passes in a dummy resource.
template <typename ResourceType>
class DummyResourceOp : public OpKernel {
 public:
  explicit DummyResourceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    Tensor* tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &tensor));
    tensor->scalar<ResourceHandle>()() = MakeResourceHandle<ResourceType>(
        ctx, /*container=*/"", /*name=*/"dummy_resource");
  }
};
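
// Example (illustrative sketch, with a hypothetical resource type and op
// name): registering a dummy resource kernel.
//
//   REGISTER_KERNEL_BUILDER(Name("DummyMyResource").Device(DEVICE_CPU),
//                           DummyResourceOp<MyResource>);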

// Given an op prefix and an op to match, returns whether the op to match
// is a match for any version of the op prefix. For example,
// MatchesAnyVersion("BatchDataset", "BatchDataset") == true
// MatchesAnyVersion("BatchDataset", "BatchDatasetV2") == true
// MatchesAnyVersion("BatchDataset", "BatchDatasetV3") == true
// MatchesAnyVersion("PaddedBatchDataset", "BatchDataset") == false
bool MatchesAnyVersion(StringPiece op_prefix, StringPiece op_to_match);

// Returns the `index`-th slice of a given tensor. If the `index`-th slice of
// the tensor is not aligned, returns a deep copy of the tensor.
Tensor MaybeCopySubSlice(const Tensor& tensor, int64 index);

// Removes device placements from the ops of all functions in `library`.
void StripDevicePlacement(FunctionDefLibrary* library);

// Copies a partial batch, i.e. the first `num_elements` slices of `value`,
// into `output`.
Status CopyPartialBatch(int64_t num_elements, const Tensor& value,
                        Tensor* output);

// Reads a batch when restoring the iterator.
Status ReadBatch(IteratorContext* ctx, IteratorStateReader* reader,
                 int64_t batch_size, const string& iterator_prefix,
                 const string& batch_prefix, std::vector<Tensor>* batch);

// Writes a batch when saving the iterator.
Status WriteBatch(int64_t batch_size, int64_t num_elements,
                  const string& iterator_prefix, const string& batch_prefix,
                  IteratorStateWriter* writer, std::vector<Tensor>* batch);

// Reads a status when restoring the iterator.
Status ReadStatus(const string& iterator_prefix, const string& prefix,
                  IteratorStateReader* reader, Status* status);

// Writes a status when saving the iterator.
Status WriteStatus(const string& iterator_prefix, const string& prefix,
                   const Status& status, IteratorStateWriter* writer);
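
// Example (illustrative sketch, with assumed member names): using the batch
// save/restore helpers from an iterator's SaveInternal/RestoreInternal, where
// `batch_size_` and `batch_` are hypothetical members and `prefix()` is the
// iterator's checkpoint prefix.
//
//   // In SaveInternal:
//   TF_RETURN_IF_ERROR(WriteBatch(batch_size_, batch_.size(), prefix(),
//                                 "batch", writer, &batch_));
//   // In RestoreInternal:
//   TF_RETURN_IF_ERROR(
//       ReadBatch(ctx, reader, batch_size_, prefix(), "batch", &batch_));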

// Processes a batch of elements into `output`. If a partial batch is
// encountered, copies only the valid elements of the batch.
Status ProcessBatch(int64_t batch_size, int64_t num_elements,
                    bool drop_remainder, const Status& status,
                    IteratorContext* ctx, std::vector<Tensor>* output,
                    bool* end_of_sequence, std::vector<Tensor>* batch);

// Returns the size of the runner threadpool for the given OpKernelContext.
int64_t GetRunnerThreadpoolSizeFromOpKernelContext(OpKernelContext* ctx);

// Constructs and stores the parameters for the CopyBatch function.
struct CopyBatchParams {
  Allocator* allocator;
  std::function<void(std::function<void()>)>* runner;
  int64 runner_threadpool_size;

  explicit CopyBatchParams(IteratorContext* ctx) {
    allocator = ctx->allocator({});
    runner = ctx->runner();
    runner_threadpool_size = ctx->runner_threadpool_size();
  }

  explicit CopyBatchParams(OpKernelContext* ctx) {
    allocator = ctx->get_allocator({});
    runner = ctx->runner();
    runner_threadpool_size = GetRunnerThreadpoolSizeFromOpKernelContext(ctx);
  }
};

// Copies the input elements to a batch.
//
// The `batch_elements` argument contains the individual elements to copy into
// a batch. The `parallel_copy` argument indicates whether to parallelize the
// copy. The `allocation_callback` argument can be used to pass a callback to
// invoke upon successful allocation of the memory for the batch. The
// `out_tensors` argument will be used to store the resulting batch (one for
// each component of the input).
Status CopyBatch(CopyBatchParams params,
                 const std::vector<std::vector<Tensor>>& batch_elements,
                 bool parallel_copy,
                 std::function<Status()> allocation_callback,
                 std::vector<Tensor>* out_tensors);
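
// Example (illustrative sketch): copying buffered elements into a batch from
// inside an iterator; passing a null `allocation_callback` is assumed here to
// skip the allocation hook.
//
//   std::vector<Tensor> out_tensors;
//   TF_RETURN_IF_ERROR(CopyBatch(CopyBatchParams(ctx), batch_elements,
//                                /*parallel_copy=*/false,
//                                /*allocation_callback=*/nullptr,
//                                &out_tensors));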

// Computes the set of experiments to apply based on the job name, rollout
// percentage of registered experiments, and the TF_DATA_EXPERIMENT_OPT_IN and
// TF_DATA_EXPERIMENT_OPT_OUT environment variables.
absl::flat_hash_set<string> GetExperiments();
absl::flat_hash_set<string> GetExperiments(
    const string& job_name, std::function<uint64(const string&)> hash_func);

// Logs and records the experiments that will be applied.
void LogAndRecordExperiments(const absl::flat_hash_set<string>& experiments);

// Computes the set of enabled, disabled, and default optimizations based on
// the given options. An optimization must be a graph optimizer name that has
// been registered with Grappler.
void GetOptimizations(const Options& options,
                      absl::flat_hash_set<tstring>* optimizations_enabled,
                      absl::flat_hash_set<tstring>* optimizations_disabled,
                      absl::flat_hash_set<tstring>* optimizations_default);

// Creates graph rewrite configs based on the given options. The configs will
// only be used if their corresponding optimizers registered with Grappler are
// enabled.
// A config is a string with the following format:
//   <optimizer name>:<attribute name>:<attribute value>
absl::flat_hash_set<tstring> CreateGraphRewriteConfigs(const Options& options);
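
// For example, CreateGraphRewriteConfigs might yield a config such as
// "my_optimizer:apply_fusion:true", where `my_optimizer` and `apply_fusion`
// are a hypothetical optimizer name and attribute.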

// Determines whether max intra-op parallelism should be configured.
bool ShouldConfigureMaxIntraOpParallelism(const Options& options);

// Determines whether a private threadpool should be used.
bool ShouldUsePrivateThreadPool(const Options& options);

// Determines whether autotuning should be used.
bool ShouldUseAutotuning(const Options& options);

// Determines whether optimizations should be applied.
bool ShouldApplyOptimizations(
    const Options& options,
    const absl::flat_hash_set<tstring>& optimizations_enabled,
    const absl::flat_hash_set<tstring>& optimizations_default);

// Returns the default CPU budget.
inline int GetCpuBudget() {
  static bool in_experiment = GetExperiments().contains("tune_cpu_budget");
  return (in_experiment ? 1.2 : 1.0) * port::NumSchedulableCPUs();
}

// Returns the initial value for the parallelism parameter before the first
// Autotune optimization.
int64 GetAutotuneDefaultParallelism(IteratorContext* ctx);

// Registry of tf.data experiments.
class DatasetExperimentRegistry {
 public:
  // Registers the experiment.
  static void Register(const string& experiment, int64_t rollout_pct);

  // Returns all registered experiments.
  static absl::flat_hash_map<string, int64_t> Experiments();
};

// Helper class to register a dataset experiment.
class DatasetExperimentRegistrar {
 public:
  explicit DatasetExperimentRegistrar(const string& experiment,
                                      int64_t rollout_pct) {
    DatasetExperimentRegistry::Register(experiment, rollout_pct);
  }
};

// Macro that can be used to register a dataset experiment.
#define REGISTER_DATASET_EXPERIMENT(experiment, rollout_pct) \
  REGISTER_DATASET_OP_NAME_UNIQ_HELPER(__COUNTER__, experiment, rollout_pct)

#define REGISTER_DATASET_OP_NAME_UNIQ_HELPER(ctr, experiment, rollout_pct) \
  REGISTER_DATASET_OP_NAME_UNIQ(ctr, experiment, rollout_pct)

#define REGISTER_DATASET_OP_NAME_UNIQ(ctr, experiment, rollout_pct) \
  static ::tensorflow::data::DatasetExperimentRegistrar             \
      registrar__body__##ctr##__object(experiment, rollout_pct)
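
// For example, a hypothetical experiment rolled out to 50% of eligible jobs
// would be registered as:
//
//   REGISTER_DATASET_EXPERIMENT("my_experiment", /*rollout_pct=*/50);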

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DATA_DATASET_UTILS_H_