1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/data/service/data_transfer.h"
17
#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
23
24 #include "absl/strings/str_join.h"
25 #include "tensorflow/core/data/dataset.pb.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/framework/types.pb.h"
28 #include "tensorflow/core/framework/variant.h"
29 #include "tensorflow/core/platform/errors.h"
30 #include "tensorflow/core/platform/mutex.h"
31 #include "tensorflow/core/platform/status.h"
32
33 namespace tensorflow {
34 namespace data {
35
36 namespace {
get_lock()37 mutex* get_lock() {
38 static mutex lock(LINKER_INITIALIZED);
39 return &lock;
40 }
41
42 using DataTransferServerFactories =
43 std::unordered_map<std::string,
44 std::function<std::shared_ptr<DataTransferServer>(
45 DataTransferServer::GetElementT)>>;
transfer_server_factories()46 DataTransferServerFactories& transfer_server_factories() {
47 static auto& factories = *new DataTransferServerFactories();
48 return factories;
49 }
50
51 using DataTransferClientFactories =
52 std::unordered_map<std::string, DataTransferClient::FactoryT>;
transfer_client_factories()53 DataTransferClientFactories& transfer_client_factories() {
54 static auto& factories = *new DataTransferClientFactories();
55 return factories;
56 }
57 } // namespace
58
Copy() const59 GetElementResult GetElementResult::Copy() const {
60 GetElementResult copy;
61 copy.components = components;
62 copy.element_index = element_index;
63 copy.end_of_sequence = end_of_sequence;
64 copy.skip = skip;
65 return copy;
66 }
67
EstimatedMemoryUsageBytes() const68 size_t GetElementResult::EstimatedMemoryUsageBytes() const {
69 size_t size_bytes = components.size() * sizeof(Tensor) +
70 sizeof(element_index) + sizeof(end_of_sequence) +
71 sizeof(skip);
72 for (const Tensor& tensor : components) {
73 size_bytes += tensor.TotalBytes();
74 if (tensor.dtype() != DT_VARIANT) {
75 continue;
76 }
77
78 // Estimates the memory usage of a compressed element.
79 const Variant& variant = tensor.scalar<Variant>()();
80 const CompressedElement* compressed = variant.get<CompressedElement>();
81 if (compressed) {
82 size_bytes += compressed->SpaceUsedLong();
83 }
84 }
85 return size_bytes;
86 }
87
Register(std::string name,std::function<std::shared_ptr<DataTransferServer> (GetElementT)> factory)88 void DataTransferServer::Register(
89 std::string name,
90 std::function<std::shared_ptr<DataTransferServer>(GetElementT)> factory) {
91 mutex_lock l(*get_lock());
92 if (!transfer_server_factories().insert({name, factory}).second) {
93 LOG(ERROR)
94 << "Two data transfer server factories are being registered with name "
95 << name << ". Which one gets used is undefined.";
96 }
97 }
98
Build(std::string name,GetElementT get_element,std::shared_ptr<DataTransferServer> * out)99 Status DataTransferServer::Build(std::string name, GetElementT get_element,
100 std::shared_ptr<DataTransferServer>* out) {
101 mutex_lock l(*get_lock());
102 auto it = transfer_server_factories().find(name);
103 if (it != transfer_server_factories().end()) {
104 *out = it->second(get_element);
105 return OkStatus();
106 }
107
108 std::vector<std::string> available_names;
109 for (const auto& factory : transfer_server_factories()) {
110 available_names.push_back(factory.first);
111 }
112
113 return errors::NotFound(
114 "No data transfer server factory has been registered for name ", name,
115 ". The available names are: [ ", absl::StrJoin(available_names, ", "),
116 " ]");
117 }
118
Register(std::string name,FactoryT factory)119 void DataTransferClient::Register(std::string name, FactoryT factory) {
120 mutex_lock l(*get_lock());
121 if (!transfer_client_factories().insert({name, factory}).second) {
122 LOG(ERROR)
123 << "Two data transfer client factories are being registered with name "
124 << name << ". Which one gets used is undefined.";
125 }
126 }
127
Build(std::string name,Config config,std::unique_ptr<DataTransferClient> * out)128 Status DataTransferClient::Build(std::string name, Config config,
129 std::unique_ptr<DataTransferClient>* out) {
130 mutex_lock l(*get_lock());
131 auto it = transfer_client_factories().find(name);
132 if (it != transfer_client_factories().end()) {
133 return it->second(config, out);
134 }
135
136 std::vector<string> available_names;
137 for (const auto& factory : transfer_client_factories()) {
138 available_names.push_back(factory.first);
139 }
140
141 return errors::NotFound(
142 "No data transfer client factory has been registered for name ", name,
143 ". The available names are: [ ", absl::StrJoin(available_names, ", "),
144 " ]");
145 }
146
147 } // namespace data
148 } // namespace tensorflow
149