xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/service/data_transfer.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/data/service/data_transfer.h"
17 
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <unordered_map>
22 #include <vector>
23 
24 #include "absl/strings/str_join.h"
25 #include "tensorflow/core/data/dataset.pb.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/framework/types.pb.h"
28 #include "tensorflow/core/framework/variant.h"
29 #include "tensorflow/core/platform/errors.h"
30 #include "tensorflow/core/platform/mutex.h"
31 #include "tensorflow/core/platform/status.h"
32 
33 namespace tensorflow {
34 namespace data {
35 
36 namespace {
get_lock()37 mutex* get_lock() {
38   static mutex lock(LINKER_INITIALIZED);
39   return &lock;
40 }
41 
42 using DataTransferServerFactories =
43     std::unordered_map<std::string,
44                        std::function<std::shared_ptr<DataTransferServer>(
45                            DataTransferServer::GetElementT)>>;
transfer_server_factories()46 DataTransferServerFactories& transfer_server_factories() {
47   static auto& factories = *new DataTransferServerFactories();
48   return factories;
49 }
50 
51 using DataTransferClientFactories =
52     std::unordered_map<std::string, DataTransferClient::FactoryT>;
transfer_client_factories()53 DataTransferClientFactories& transfer_client_factories() {
54   static auto& factories = *new DataTransferClientFactories();
55   return factories;
56 }
57 }  // namespace
58 
Copy() const59 GetElementResult GetElementResult::Copy() const {
60   GetElementResult copy;
61   copy.components = components;
62   copy.element_index = element_index;
63   copy.end_of_sequence = end_of_sequence;
64   copy.skip = skip;
65   return copy;
66 }
67 
EstimatedMemoryUsageBytes() const68 size_t GetElementResult::EstimatedMemoryUsageBytes() const {
69   size_t size_bytes = components.size() * sizeof(Tensor) +
70                       sizeof(element_index) + sizeof(end_of_sequence) +
71                       sizeof(skip);
72   for (const Tensor& tensor : components) {
73     size_bytes += tensor.TotalBytes();
74     if (tensor.dtype() != DT_VARIANT) {
75       continue;
76     }
77 
78     // Estimates the memory usage of a compressed element.
79     const Variant& variant = tensor.scalar<Variant>()();
80     const CompressedElement* compressed = variant.get<CompressedElement>();
81     if (compressed) {
82       size_bytes += compressed->SpaceUsedLong();
83     }
84   }
85   return size_bytes;
86 }
87 
Register(std::string name,std::function<std::shared_ptr<DataTransferServer> (GetElementT)> factory)88 void DataTransferServer::Register(
89     std::string name,
90     std::function<std::shared_ptr<DataTransferServer>(GetElementT)> factory) {
91   mutex_lock l(*get_lock());
92   if (!transfer_server_factories().insert({name, factory}).second) {
93     LOG(ERROR)
94         << "Two data transfer server factories are being registered with name "
95         << name << ". Which one gets used is undefined.";
96   }
97 }
98 
Build(std::string name,GetElementT get_element,std::shared_ptr<DataTransferServer> * out)99 Status DataTransferServer::Build(std::string name, GetElementT get_element,
100                                  std::shared_ptr<DataTransferServer>* out) {
101   mutex_lock l(*get_lock());
102   auto it = transfer_server_factories().find(name);
103   if (it != transfer_server_factories().end()) {
104     *out = it->second(get_element);
105     return OkStatus();
106   }
107 
108   std::vector<std::string> available_names;
109   for (const auto& factory : transfer_server_factories()) {
110     available_names.push_back(factory.first);
111   }
112 
113   return errors::NotFound(
114       "No data transfer server factory has been registered for name ", name,
115       ". The available names are: [ ", absl::StrJoin(available_names, ", "),
116       " ]");
117 }
118 
Register(std::string name,FactoryT factory)119 void DataTransferClient::Register(std::string name, FactoryT factory) {
120   mutex_lock l(*get_lock());
121   if (!transfer_client_factories().insert({name, factory}).second) {
122     LOG(ERROR)
123         << "Two data transfer client factories are being registered with name "
124         << name << ". Which one gets used is undefined.";
125   }
126 }
127 
Build(std::string name,Config config,std::unique_ptr<DataTransferClient> * out)128 Status DataTransferClient::Build(std::string name, Config config,
129                                  std::unique_ptr<DataTransferClient>* out) {
130   mutex_lock l(*get_lock());
131   auto it = transfer_client_factories().find(name);
132   if (it != transfer_client_factories().end()) {
133     return it->second(config, out);
134   }
135 
136   std::vector<string> available_names;
137   for (const auto& factory : transfer_client_factories()) {
138     available_names.push_back(factory.first);
139   }
140 
141   return errors::NotFound(
142       "No data transfer client factory has been registered for name ", name,
143       ". The available names are: [ ", absl::StrJoin(available_names, ", "),
144       " ]");
145 }
146 
147 }  // namespace data
148 }  // namespace tensorflow
149