/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DATA_DATASET_UTILS_H_
#define TENSORFLOW_CORE_DATA_DATASET_UTILS_H_

#include <functional>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
namespace tensorflow {
namespace data {

// Constant used for indicating that the argument of tf.data.Dataset.shard
// should be supplied by the auto-sharding rewrite.
constexpr int kShardHint = -1;

// The initial parallelism value before Autotune has a chance to optimize.
constexpr int kAutotuneDefaultParallelism = 16;

// Creates a resource handle with a unique name for the given resource, where
// the resource is managed by the resource manager.
template <typename T>
Status CreateWeakHandle(OpKernelContext* ctx, T* resource,
                        const string& container_name, ResourceHandle* handle) {
  static std::atomic<int64_t> resource_id_counter(0);
  string unique_name =
      strings::StrCat(container_name, resource_id_counter.fetch_add(1));
  ResourceMgr* mgr = ctx->resource_manager();
  TF_RETURN_IF_ERROR(mgr->Create<T>(container_name, unique_name, resource));

  *handle = MakeResourceHandle(container_name, unique_name, *ctx->device(),
                               TypeIndex::Make<T>());
  return OkStatus();
}
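
// A minimal usage sketch (illustrative only; `MyResource` is a hypothetical
// ResourceBase subclass):
//
//   MyResource* resource = new MyResource();
//   ResourceHandle handle;
//   OP_REQUIRES_OK(ctx, CreateWeakHandle<MyResource>(
//                           ctx, resource, /*container_name=*/"MyOp",
//                           &handle));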

// Creates a ref-counting resource handle for the given resource, where the
// resource is owned by the handle.
template <typename T>
Status CreateHandle(OpKernelContext* ctx, T* resource, ResourceHandle* handle) {
  ResourceMgr* mgr = ctx->resource_manager();
  *handle =
      ResourceHandle::MakeRefCountingHandle(resource, ctx->device()->name());
  TF_RETURN_IF_ERROR(
      mgr->CreateUnowned<T>(handle->container(), handle->name(), resource));
  return OkStatus();
}
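
// A corresponding sketch for the ref-counting variant (illustrative; the
// handle owns the resource, so no explicit ResourceDeleter is needed):
//
//   ResourceHandle handle;
//   OP_REQUIRES_OK(
//       ctx, CreateHandle<MyResource>(ctx, new MyResource(), &handle));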

// TODO(b/198162355): Merge this class with ResourceOpKernel.
template <typename T>
class AnonymousResourceOp : public OpKernel {
 public:
  // Creates an AnonymousResourceOp.
  // ref_counting: Determines if the Op returns a ref-counting ResourceHandle.
  //   See go/tf-resource-handle-ref-count.
  // return_deleter: Determines if the Op outputs a deleter tensor in addition
  //   to the resource handle tensor. If the resource handle is ref-counting,
  //   a no-op deleter is returned.
  explicit AnonymousResourceOp(OpKernelConstruction* context, bool ref_counting,
                               bool return_deleter)
      : OpKernel(context),
        ref_counting_(ref_counting),
        return_deleter_(return_deleter) {}

  void Compute(OpKernelContext* ctx) override {
    FunctionLibraryRuntime* lib;
    std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
    OP_REQUIRES_OK(
        ctx, ctx->function_library()->Clone(&flib_def, &pflr, &lib, true));
    T* resource;
    OP_REQUIRES_OK(ctx, CreateResource(ctx, std::move(flib_def),
                                       std::move(pflr), lib, &resource));

    ResourceHandle handle;
    if (ref_counting_) {
      OP_REQUIRES_OK(ctx, CreateHandle(ctx, resource, &handle));
    } else {
      OP_REQUIRES_OK(ctx, CreateWeakHandle(ctx, resource, name(), &handle));
    }
    Tensor* handle_t;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle_t));
    handle_t->scalar<ResourceHandle>()() = handle;

    if (return_deleter_) {
      Tensor* deleter_t;
      AllocatorAttributes attr;
      attr.set_on_host(true);
      OP_REQUIRES_OK(
          ctx, ctx->allocate_output(1, TensorShape({}), &deleter_t, attr));
      // TODO(feyu): Consider returning an OptionalVariant.
      if (!ref_counting_) {
        // A deleter output that deletes the resource when destroyed.
        deleter_t->scalar<Variant>()() =
            ResourceDeleter(handle, ctx->resource_manager());
      }
    }
  }

 protected:
  virtual string name() = 0;

  virtual Status CreateResource(
      OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
      std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
      FunctionLibraryRuntime* lib, T** resource) = 0;

 private:
  const bool ref_counting_;
  const bool return_deleter_;
};
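
// A sketch of a typical subclass (illustrative only; `MyResource` and
// `AnonymousMyResourceOp` are hypothetical):
//
//   class AnonymousMyResourceOp : public AnonymousResourceOp<MyResource> {
//    public:
//     explicit AnonymousMyResourceOp(OpKernelConstruction* ctx)
//         : AnonymousResourceOp(ctx, /*ref_counting=*/true,
//                               /*return_deleter=*/false) {}
//
//    protected:
//     string name() override { return "MyResource"; }
//
//     Status CreateResource(
//         OpKernelContext* ctx,
//         std::unique_ptr<FunctionLibraryDefinition> flib_def,
//         std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
//         FunctionLibraryRuntime* lib, MyResource** resource) override {
//       *resource = new MyResource();
//       return OkStatus();
//     }
//   };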

// Returns Status::OK() if `expected` and `received` types match,
// errors::InvalidArgument otherwise.
Status VerifyTypesMatch(const DataTypeVector& expected,
                        const DataTypeVector& received);

Status VerifyTypesMatch(const DataTypeVector& expected,
                        const std::vector<Tensor>& received);

// Returns Status::OK() if `expected` and `received` shapes are compatible,
// errors::InvalidArgument otherwise.
Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<PartialTensorShape>& received);

Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<Tensor>& received);

// Dataset op level determinism policy.
class DeterminismPolicy {
 public:
  enum class Type : int {
    // The op must produce elements deterministically.
    kDeterministic,
    // The op may relax determinism to improve performance.
    kNondeterministic,
    // The determinism policy is not specified at the op level. In this case we
    // use the experimental_deterministic dataset option to determine the
    // determinism policy.
    kDefault,
  };
  static constexpr const char* const kDeterministic = "true";
  static constexpr const char* const kNondeterministic = "false";
  static constexpr const char* const kDefault = "default";

  DeterminismPolicy() : determinism_(Type::kDefault) {}
  explicit DeterminismPolicy(Type determinism) : determinism_(determinism) {}
  // Creates a DeterminismPolicy with Type kDeterministic or
  // kNondeterministic, depending on the value of `is_deterministic`.
  explicit DeterminismPolicy(bool is_deterministic);

  static Status FromString(const std::string& s, DeterminismPolicy* out);

  // Returns the string representing the determinism policy. This will be one
  // of the string constants defined above.
  std::string String() const;

  /// Convenience methods for checking the DeterminismPolicy::Type.
  bool IsDeterministic() const { return determinism_ == Type::kDeterministic; }
  bool IsNondeterministic() const {
    return determinism_ == Type::kNondeterministic;
  }
  bool IsDefault() const { return determinism_ == Type::kDefault; }

 private:
  Type determinism_;
};
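
// A usage sketch (illustrative): parse an op-level attribute and fall back to
// the dataset option when the policy is unspecified.
//
//   DeterminismPolicy policy;
//   TF_RETURN_IF_ERROR(
//       DeterminismPolicy::FromString(DeterminismPolicy::kDefault, &policy));
//   if (policy.IsDefault()) {
//     // Consult the experimental_deterministic dataset option instead.
//   }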

// Resolves non-deterministic seeds if necessary, returning either the original
// seeds or the resolved seeds.
//
// By TensorFlow convention, if both seeds are 0, they should be replaced with
// non-deterministically chosen seeds.
std::pair<int64_t, int64_t> MaybeOverrideSeeds(
    std::pair<int64_t, int64_t> seeds);
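
// For example (illustrative):
//
//   auto random = MaybeOverrideSeeds({0, 0});   // both zero: replaced with
//                                               // nondeterministic seeds
//   auto fixed = MaybeOverrideSeeds({7, 42});   // returned unchanged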

// Adds the functions in `to_add` to `base`. If a function with a matching
// signature already exists in `base`, replaces it with the function from
// `to_add`.
Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionLibraryDefinition& to_add);
Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionDefLibrary& to_add);

// Determines whether the given function is stateful.
Status IsFunctionStateful(const FunctionLibraryDefinition& library,
                          const FunctionDef& function_def);

// Determines whether the given node is stateful.
Status IsNodeStateful(const FunctionLibraryDefinition& library,
                      const NodeDef& node);

// Creates a runner that runs functions with limited parallelism.
std::function<void(std::function<void()>)> RunnerWithMaxParallelism(
    std::function<void(std::function<void()>)> runner, int max_parallelism);
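
// A usage sketch (illustrative): wrap an existing runner so that at most 4
// tasks run concurrently.
//
//   auto limited = RunnerWithMaxParallelism(*ctx->runner(),
//                                           /*max_parallelism=*/4);
//   limited([] { /* do some work */ });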

// Op for creating a typed dummy resource.
//
// This op is used to provide a resource "placeholder" for ops such as
// `CacheDatasetV2` or `ShuffleDatasetV2` that expect a resource input.
// Originally, the lifetime of the resources passed into these ops was managed
// externally. After the implementation changed to manage the lifetime of the
// resources (including creation) by the ops themselves, the resource input is
// only needed to pass a resource handle through graph rewrites. When they are
// invoked from user code, the implementation passes in a dummy resource.
template <typename ResourceType>
class DummyResourceOp : public OpKernel {
 public:
  explicit DummyResourceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    Tensor* tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &tensor));
    tensor->scalar<ResourceHandle>()() = MakeResourceHandle<ResourceType>(
        ctx, /*container=*/"", /*name=*/"dummy_resource");
  }
};
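
// A registration sketch (illustrative; `MyResource` and the op name are
// hypothetical):
//
//   REGISTER_KERNEL_BUILDER(Name("DummyMyResource").Device(DEVICE_CPU),
//                           DummyResourceOp<MyResource>);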

// Given an op prefix and an op to match, returns whether the op to match
// is a match for any version of the op prefix. For example,
// MatchesAnyVersion("BatchDataset", "BatchDataset") == true
// MatchesAnyVersion("BatchDataset", "BatchDatasetV2") == true
// MatchesAnyVersion("BatchDataset", "BatchDatasetV3") == true
// MatchesAnyVersion("PaddedBatchDataset", "BatchDataset") == false
bool MatchesAnyVersion(StringPiece op_prefix, StringPiece op_to_match);

// Returns the index-th slice of a given tensor. If the index-th slice of
// the tensor is not aligned, returns a deep copy of the tensor.
Tensor MaybeCopySubSlice(const Tensor& tensor, int64 index);

// Removes device placements from the ops of all functions in `library`.
void StripDevicePlacement(FunctionDefLibrary* library);

// Copies a partial batch of `num_elements` from `value` into `output`.
Status CopyPartialBatch(int64_t num_elements, const Tensor& value,
                        Tensor* output);

// Reads a batch when restoring the iterator.
Status ReadBatch(IteratorContext* ctx, IteratorStateReader* reader,
                 int64_t batch_size, const string& iterator_prefix,
                 const string& batch_prefix, std::vector<Tensor>* batch);

// Writes a batch when saving the iterator.
Status WriteBatch(int64_t batch_size, int64_t num_elements,
                  const string& iterator_prefix, const string& batch_prefix,
                  IteratorStateWriter* writer, std::vector<Tensor>* batch);

// Reads a status when restoring the iterator.
Status ReadStatus(const string& iterator_prefix, const string& prefix,
                  IteratorStateReader* reader, Status* status);

// Writes a status when saving the iterator.
Status WriteStatus(const string& iterator_prefix, const string& prefix,
                   const Status& status, IteratorStateWriter* writer);

// Processes a batch into the output. If a partial batch is encountered, copies
// only the filled portion of the batch.
Status ProcessBatch(int64_t batch_size, int64_t num_elements,
                    bool drop_remainder, const Status& status,
                    IteratorContext* ctx, std::vector<Tensor>* output,
                    bool* end_of_sequence, std::vector<Tensor>* batch);

// Returns the runner threadpool size from an OpKernelContext.
int32_t GetRunnerThreadpoolSizeFromOpKernelContext(OpKernelContext* ctx);

// Constructs and stores the parameters for the CopyBatch function.
struct CopyBatchParams {
  Allocator* allocator;
  std::function<void(std::function<void()>)>* runner;
  int64 runner_threadpool_size;

  explicit CopyBatchParams(IteratorContext* ctx) {
    allocator = ctx->allocator({});
    runner = ctx->runner();
    runner_threadpool_size = ctx->runner_threadpool_size();
  }

  explicit CopyBatchParams(OpKernelContext* ctx) {
    allocator = ctx->get_allocator({});
    runner = ctx->runner();
    runner_threadpool_size = GetRunnerThreadpoolSizeFromOpKernelContext(ctx);
  }
};

// Copies the input elements to a batch.
//
// The `batch_elements` argument contains the individual elements to copy into
// a batch. The `parallel_copy` argument indicates whether to parallelize the
// copy. The `allocation_callback` argument can be used to pass a callback to
// invoke upon successful allocation of the memory for the batch. The
// `out_tensors` argument will be used to store the resulting batch (one for
// each component of the input).
Status CopyBatch(CopyBatchParams params,
                 const std::vector<std::vector<Tensor>>& batch_elements,
                 bool parallel_copy,
                 std::function<Status()> allocation_callback,
                 std::vector<Tensor>* out_tensors);
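
// A usage sketch (illustrative; assumes `batch_elements` was gathered from an
// input iterator):
//
//   std::vector<Tensor> out_tensors;
//   TF_RETURN_IF_ERROR(CopyBatch(CopyBatchParams(ctx), batch_elements,
//                                /*parallel_copy=*/true,
//                                /*allocation_callback=*/
//                                [] { return OkStatus(); },
//                                &out_tensors));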

// Computes the set of experiments to apply based on the job name, rollout
// percentage of registered experiments, and the TF_DATA_EXPERIMENT_OPT_IN and
// TF_DATA_EXPERIMENT_OPT_OUT environment variables.
absl::flat_hash_set<string> GetExperiments();
absl::flat_hash_set<string> GetExperiments(
    const string& job_name, std::function<uint64(const string&)> hash_func);

// Logs and records the experiments that will be applied.
void LogAndRecordExperiments(const absl::flat_hash_set<string>& experiments);
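
// A usage sketch (illustrative):
//
//   absl::flat_hash_set<string> experiments = GetExperiments();
//   LogAndRecordExperiments(experiments);
//   if (experiments.contains("tune_cpu_budget")) {
//     // Adjust behavior for the experiment.
//   }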

// Computes the set of enabled, disabled, and default optimizations based on
// the given options. An optimization must be a graph optimizer name that has
// been registered with Grappler.
void GetOptimizations(const Options& options,
                      absl::flat_hash_set<tstring>* optimizations_enabled,
                      absl::flat_hash_set<tstring>* optimizations_disabled,
                      absl::flat_hash_set<tstring>* optimizations_default);

// Creates graph rewrite configs based on the given options. The configs will
// only be used if their corresponding optimizers registered with Grappler are
// enabled.
// A config is a string with the following format:
// <optimizer name>:<attribute name>:<attribute value>
absl::flat_hash_set<tstring> CreateGraphRewriteConfigs(const Options& options);
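
// For example, a config of the form "my_optimizer:my_attr:true" (hypothetical
// names, following the format above) would set the attribute `my_attr` of the
// Grappler optimizer `my_optimizer` to `true`, if that optimizer is enabled.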

// Determines whether max intra-op parallelism should be configured.
bool ShouldConfigureMaxIntraOpParallelism(const Options& options);

// Determines whether a private threadpool should be used.
bool ShouldUsePrivateThreadPool(const Options& options);

// Determines whether autotuning should be used.
bool ShouldUseAutotuning(const Options& options);

// Determines whether optimizations should be applied.
bool ShouldApplyOptimizations(
    const Options& options,
    const absl::flat_hash_set<tstring>& optimizations_enabled,
    const absl::flat_hash_set<tstring>& optimizations_default);

// Returns the default CPU budget.
inline int GetCpuBudget() {
  static bool in_experiment = GetExperiments().contains("tune_cpu_budget");
  return (in_experiment ? 1.2 : 1.0) * port::NumSchedulableCPUs();
}
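
// For example, with 8 schedulable CPUs this returns 8 by default, or
// static_cast<int>(1.2 * 8) == 9 when the "tune_cpu_budget" experiment is on.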

// Returns the initial value for the parallelism parameter before the first
// Autotune optimization.
int64 GetAutotuneDefaultParallelism(IteratorContext* ctx);

// Registry of tf.data experiments.
class DatasetExperimentRegistry {
 public:
  // Registers the experiment.
  static void Register(const string& experiment, int64_t rollout_pct);

  // Returns all registered experiments.
  static absl::flat_hash_map<string, int64_t> Experiments();
};

// Helper class to register a dataset experiment.
class DatasetExperimentRegistrar {
 public:
  explicit DatasetExperimentRegistrar(const string& experiment,
                                      int64_t rollout_pct) {
    DatasetExperimentRegistry::Register(experiment, rollout_pct);
  }
};

// Macro that can be used to register a dataset experiment.
#define REGISTER_DATASET_EXPERIMENT(experiment, rollout_pct) \
  REGISTER_DATASET_OP_NAME_UNIQ_HELPER(__COUNTER__, experiment, rollout_pct)

#define REGISTER_DATASET_OP_NAME_UNIQ_HELPER(ctr, experiment, rollout_pct) \
  REGISTER_DATASET_OP_NAME_UNIQ(ctr, experiment, rollout_pct)

#define REGISTER_DATASET_OP_NAME_UNIQ(ctr, experiment, rollout_pct) \
  static ::tensorflow::data::DatasetExperimentRegistrar             \
      registrar__body__##ctr##__object(experiment, rollout_pct)
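
// A usage sketch (illustrative; the experiment name is hypothetical):
//
//   REGISTER_DATASET_EXPERIMENT("my_experiment", /*rollout_pct=*/50);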

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DATA_DATASET_UTILS_H_