1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DATA_STANDALONE_H_ 17 #define TENSORFLOW_CORE_DATA_STANDALONE_H_ 18 19 #include <functional> 20 #include <memory> 21 22 #include "tensorflow/core/common_runtime/device_mgr.h" 23 #include "tensorflow/core/data/unbounded_thread_pool.h" 24 #include "tensorflow/core/framework/dataset.h" 25 #include "tensorflow/core/framework/function_handle_cache.h" 26 #include "tensorflow/core/lib/core/threadpool.h" 27 #include "tensorflow/core/public/session_options.h" 28 29 namespace tensorflow { 30 namespace data { 31 namespace standalone { 32 33 // The purpose of the API in this file is to facilitate standalone execution of 34 // a tf.data input pipeline graph. 35 // 36 // The API exposes two abstractions -- a `Dataset` and an `Iterator` -- which 37 // encapsulate TensorFlow runtime. 38 // 39 // The `Dataset` abstraction represents an input pipeline as a collection 40 // of data sources and a logical plan of transformations that operate over the 41 // data. 42 // 43 // The `Iterator` abstraction represents an execution of an input pipeline that 44 // can be used to enumerate its elements. 45 // 46 // Example usage: 47 // 48 // // Create a `Dataset` by running the `graph_def` graph. 49 // tensorflow::data:standalone::Dataset::Params params; 50 // std::unique_ptr<tensorflow::data::standalone::Dataset> dataset; 51 // Status s = tensorflow::data::standalone::Dataset::FromGraph( 52 // params, graph_def, &dataset); 53 // if (!s.ok()) { /* error handling */ } 54 // 55 // std::unique_ptr<tensorflow::data::standalone::Iterator> iterator; 56 // s = dataset->MakeIterator(&iterator); 57 // if (!s.ok()) { /* error handling */ } 58 // 59 // bool end_of_input = false; 60 // while (!end_of_input) { 61 // std::vector<tensorflow::Tensor> outputs; 62 // s = iterator->GetNext(&outputs, &end_of_input); 63 // if (!s.ok()) { /* error handling */ } 64 // if (!end_of_input) { /* output handling */ } 65 // } 66 67 class Dataset; 68 69 // Represents an execution of an input pipeline that can be used to enumerate 70 // its elements. 71 class Iterator { 72 public: 73 // Returns the next element of the input pipeline (if there is one) and an 74 // indication of whether the end of the input pipeline has been reached. 75 Status GetNext(std::vector<Tensor>* outputs, bool* end_of_input); 76 77 private: 78 friend class Dataset; 79 80 Iterator(IteratorBase* iterator, IteratorContext* ctx); 81 82 std::unique_ptr<IteratorBase> iterator_; 83 std::unique_ptr<IteratorContext> ctx_; 84 }; 85 86 // Represents an input pipeline as a collection of data sources and a logical 87 // plan of transformations that operate over the data. 88 class Dataset { 89 public: 90 // Parameters for `Dataset` creation (e.g. TensorFlow runtime configuration). 91 struct Params { 92 SessionOptions session_options; 93 }; 94 95 // Creates a new `Dataset` instance by running the given dataset graph. 96 static Status FromGraph(Params params, const GraphDef& graph_def, 97 std::unique_ptr<Dataset>* result); 98 99 ~Dataset(); 100 101 // Creates an iterator for this dataset. 102 Status MakeIterator(std::unique_ptr<Iterator>* result); 103 // Creates an iterator, optionally with a split provider. 104 Status MakeIterator( 105 std::vector<std::unique_ptr<SplitProvider>> split_providers, 106 std::unique_ptr<Iterator>* result); 107 108 // Creates split providers for this dataset. 109 Status MakeSplitProviders( 110 std::vector<std::unique_ptr<SplitProvider>>* result); 111 // Returns a pointer to the underlying dataset. 112 const DatasetBase* Get() const; 113 114 private: 115 Dataset(DatasetBase* finalized_dataset, DatasetBase* original_dataset, 116 DeviceMgr* device_mgr, ProcessFunctionLibraryRuntime* pflr, 117 FunctionLibraryDefinition* flib_def, thread::ThreadPool* pool, 118 std::function<void(std::function<void()>)> runner); 119 120 DatasetBase* finalized_dataset_; // owned 121 DatasetBase* original_dataset_; // owned 122 std::unique_ptr<DeviceMgr> device_mgr_; 123 std::unique_ptr<FunctionLibraryDefinition> flib_def_; 124 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_; 125 std::unique_ptr<thread::ThreadPool> interop_threadpool_; 126 std::unique_ptr<FunctionHandleCache> function_handle_cache_; 127 std::function<void(std::function<void()>)> runner_; 128 ResourceMgr resource_mgr_; 129 CancellationManager cancellation_manager_; 130 UnboundedThreadPool unbounded_thread_pool_; 131 }; 132 133 } // namespace standalone 134 } // namespace data 135 } // namespace tensorflow 136 137 #endif // TENSORFLOW_CORE_DATA_STANDALONE_H_ 138