1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
16
17 #include <map>
18 #include <memory>
19 #include <utility>
20 #include <variant>
21 #include <vector>
22
23 #include "tensorflow/core/common_runtime/eager/context.h"
24 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
25 #include "tensorflow/core/distributed_runtime/call_options.h"
26 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
27 #include "tensorflow/core/distributed_runtime/worker_session.h"
28 #include "tensorflow/core/framework/cancellation.h"
29 #include "tensorflow/core/framework/function.h"
30 #include "tensorflow/core/framework/graph_def_util.h"
31
32 namespace tensorflow {
33 namespace eager {
34 namespace {
StripDefaultAttributesInRegisterFunctionOp(RegisterFunctionOp * register_function)35 void StripDefaultAttributesInRegisterFunctionOp(
36 RegisterFunctionOp* register_function) {
37 StripDefaultAttributes(
38 *OpRegistry::Global(),
39 register_function->mutable_function_def()->mutable_node_def());
40 for (auto& function :
41 *register_function->mutable_library()->mutable_function()) {
42 StripDefaultAttributes(*OpRegistry::Global(), function.mutable_node_def());
43 }
44 }
45 } // namespace
46
Instantiate(const string & function_name,const FunctionLibraryDefinition & lib_def,AttrSlice attrs,const FunctionLibraryRuntime::InstantiateOptions & options,FunctionLibraryRuntime::LocalHandle * handle,FunctionLibraryRuntime::DoneCallback done)47 void EagerClusterFunctionLibraryRuntime::Instantiate(
48 const string& function_name, const FunctionLibraryDefinition& lib_def,
49 AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options,
50 FunctionLibraryRuntime::LocalHandle* handle,
51 FunctionLibraryRuntime::DoneCallback done) {
52 auto target = options.target;
53 auto released_op = std::make_unique<EagerOperation>(ctx_);
54 Status s =
55 released_op->Reset(function_name.c_str(), target.c_str(), true, nullptr);
56 if (!s.ok()) {
57 done(s);
58 return;
59 }
60 if (!released_op->is_function()) {
61 done(errors::Internal(function_name, " is not a function."));
62 return;
63 }
64
65 VLOG(1) << "CFLR::Instantiate: " << function_name << " on " << target
66 << " (this: " << this << ")";
67 core::RefCountPtr<eager::EagerClient> eager_client;
68 s = ctx_->GetClient(target, &eager_client);
69 if (!s.ok()) {
70 done(s);
71 return;
72 }
73
74 if (eager_client == nullptr) {
75 done(errors::InvalidArgument("Could not find eager client for target: ",
76 target));
77 return;
78 }
79
80 const FunctionLibraryDefinition& func_lib_def =
81 options.lib_def ? *options.lib_def : lib_def;
82
83 auto request = std::make_shared<EnqueueRequest>();
84 auto response = std::make_shared<EnqueueResponse>();
85
86 request->set_context_id(context_id_);
87
88 RegisterFunctionOp* register_function =
89 request->add_queue()->mutable_register_function();
90 *register_function->mutable_function_def() =
91 *func_lib_def.Find(function_name);
92 register_function->set_is_component_function(true);
93 *register_function->mutable_library() =
94 func_lib_def.ReachableDefinitions(register_function->function_def())
95 .ToProto();
96 StripDefaultAttributesInRegisterFunctionOp(register_function);
97
98 const absl::optional<std::vector<int>>& ret_indices = options.ret_indices;
99 eager_client->EnqueueAsync(
100 /*call_opts=*/nullptr, request.get(), response.get(),
101 [this, request, response, handle, released_op = released_op.release(),
102 target, ret_indices, eager_client = eager_client.get(),
103 done](const Status& s) {
104 {
105 mutex_lock l(mu_);
106 *handle = function_data_.size();
107 function_data_.emplace_back(target, ret_indices, eager_client,
108 absl::WrapUnique(released_op));
109 }
110 done(s);
111 });
112 }
113
Run(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::LocalHandle handle,gtl::ArraySlice<Tensor> args,std::vector<Tensor> * rets,FunctionLibraryRuntime::DoneCallback done)114 void EagerClusterFunctionLibraryRuntime::Run(
115 const FunctionLibraryRuntime::Options& opts,
116 FunctionLibraryRuntime::LocalHandle handle, gtl::ArraySlice<Tensor> args,
117 std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
118 std::vector<FunctionArg> function_args;
119 for (const auto& tensor : args) {
120 function_args.push_back(tensor);
121 }
122 std::vector<FunctionRet>* function_rets = new std::vector<FunctionRet>;
123 Run(opts, handle, function_args, function_rets,
124 [rets, function_rets, done = std::move(done)](const Status& s) {
125 Status status = s;
126 if (status.ok()) {
127 for (const auto& t : *function_rets) {
128 if (t.index() == 0) {
129 rets->push_back(std::get<Tensor>(t));
130 } else {
131 status.Update(
132 errors::Internal("Expect a Tensor as a remote function "
133 "output but got a TensorShape."));
134 break;
135 }
136 }
137 }
138 delete function_rets;
139 done(status);
140 });
141 }
142
Run(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::LocalHandle handle,gtl::ArraySlice<FunctionArg> args,std::vector<FunctionRet> * rets,FunctionLibraryRuntime::DoneCallback done)143 void EagerClusterFunctionLibraryRuntime::Run(
144 const FunctionLibraryRuntime::Options& opts,
145 FunctionLibraryRuntime::LocalHandle handle,
146 gtl::ArraySlice<FunctionArg> args, std::vector<FunctionRet>* rets,
147 FunctionLibraryRuntime::DoneCallback done) {
148 FunctionData* function_data = nullptr;
149 {
150 mutex_lock l(mu_);
151 DCHECK_LE(handle, function_data_.size());
152 function_data = &function_data_[handle];
153 }
154
155 EagerClient* eager_client = function_data->eager_client.get();
156 if (eager_client == nullptr) {
157 done(errors::Internal("Could not find eager client"));
158 return;
159 }
160
161 EagerOperation* op = function_data->op.get();
162 if (!op->Inputs().empty()) {
163 done(errors::Internal("Inputs should not be set during instantiation."));
164 return;
165 }
166
167 auto request = std::make_shared<RunComponentFunctionRequest>();
168 auto response = std::make_shared<RunComponentFunctionResponse>();
169 request->set_context_id(context_id_);
170 eager::Operation* remote_op = request->mutable_operation();
171
172 if (function_data->ret_indices.has_value()) {
173 for (const int ret_index : function_data->ret_indices.value()) {
174 request->add_output_num(ret_index);
175 }
176 }
177
178 for (const auto& arg : args) {
179 if (arg.index() == 0) {
180 std::get<Tensor>(arg).AsProtoTensorContent(
181 remote_op->add_op_inputs()->mutable_tensor());
182 } else {
183 remote_op->add_op_inputs()->mutable_remote_handle()->Swap(
184 std::get<RemoteTensorHandle*>(arg));
185 }
186 }
187
188 // The remote component function should use the same op_id as its parent
189 // multi-device function's in order to get the global unique op_id generated
190 // by the master context.
191 if (opts.op_id.has_value()) {
192 remote_op->set_id(opts.op_id.value());
193 } else {
194 remote_op->set_id(kInvalidOpId);
195 }
196 remote_op->set_is_function(true);
197 remote_op->set_is_component_function(true);
198 remote_op->set_func_step_id(opts.step_id);
199 remote_op->set_name(op->Name());
200 op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
201 remote_op->set_device(function_data->target);
202
203 CancellationManager* cm = opts.cancellation_manager;
204 CancellationToken token = 0;
205 auto call_opts = std::make_shared<CallOptions>();
206 call_opts->SetTimeout(
207 ctx_->session_options().config.operation_timeout_in_ms());
208 if (cm != nullptr) {
209 token = cm->get_cancellation_token();
210 const bool already_cancelled = !cm->RegisterCallback(
211 token,
212 [call_opts, request, response, done]() { call_opts->StartCancel(); });
213 if (already_cancelled) {
214 done(errors::Cancelled("EagerClusterFunctionLibraryRuntime::Run"));
215 return;
216 }
217 }
218
219 // Execute component function on remote worker using RunComponentFunction RPC.
220 // Different from executing remote functions with Enqueue, this method runs
221 // a function on remote worker without tying up a thread (i.e., pure
222 // asynchronously).
223 eager_client->RunComponentFunctionAsync(
224 call_opts.get(), request.get(), response.get(),
225 [request, response, rets, call_opts, cm, token,
226 done = std::move(done)](const Status& s) {
227 if (cm != nullptr) {
228 cm->TryDeregisterCallback(token);
229 }
230 if (!s.ok()) {
231 done(s);
232 return;
233 }
234 if (!response->shape().empty() && !response->tensor().empty()) {
235 done(errors::Internal(
236 "Both shape and tensor are specified in the same response"));
237 return;
238 }
239 for (const auto& shape : response->shape()) {
240 rets->push_back(shape);
241 }
242 for (const auto& tensor_proto : response->tensor()) {
243 Tensor t;
244 if (t.FromProto(tensor_proto)) {
245 rets->push_back(std::move(t));
246 } else {
247 done(errors::Internal("Could not convert tensor proto: ",
248 tensor_proto.DebugString()));
249 return;
250 }
251 }
252 done(OkStatus());
253 });
254 }
255
CleanUp(uint64 step_id,FunctionLibraryRuntime::LocalHandle handle,FunctionLibraryRuntime::DoneCallback done)256 void EagerClusterFunctionLibraryRuntime::CleanUp(
257 uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
258 FunctionLibraryRuntime::DoneCallback done) {
259 FunctionData* function_data = nullptr;
260 {
261 mutex_lock l(mu_);
262 DCHECK_LE(handle, function_data_.size());
263 function_data = &function_data_[handle];
264 }
265
266 EagerClient* eager_client = function_data->eager_client.get();
267 if (eager_client == nullptr) {
268 done(errors::Internal("Could not find eager client"));
269 return;
270 }
271
272 auto request = std::make_shared<EnqueueRequest>();
273 auto response = std::make_shared<EnqueueResponse>();
274 request->set_context_id(context_id_);
275 CleanupFunctionOp* cleanup_function =
276 request->add_queue()->mutable_cleanup_function();
277 cleanup_function->set_step_id(step_id);
278 // StreamingEnqueueAsync could be blocking when streaming RPC is disabled.
279 // CleanUp() needs to be non-blocking since it would be invoked inside the
280 // enqueue done callback of Run(). So we don't use StreamingEnqueueAsync here.
281 eager_client->EnqueueAsync(
282 /*call_opts=*/nullptr, request.get(), response.get(),
283 [request, response, done](const Status& status) { done(status); });
284 }
285
CreateClusterFLR(const uint64 context_id,EagerContext * ctx,WorkerSession * worker_session)286 DistributedFunctionLibraryRuntime* CreateClusterFLR(
287 const uint64 context_id, EagerContext* ctx, WorkerSession* worker_session) {
288 return new EagerClusterFunctionLibraryRuntime(
289 context_id, ctx, worker_session->remote_device_mgr());
290 }
291
292 } // namespace eager
293 } // namespace tensorflow
294