1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
16 
17 #include <map>
18 #include <memory>
19 #include <utility>
20 #include <variant>
21 #include <vector>
22 
23 #include "tensorflow/core/common_runtime/eager/context.h"
24 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
25 #include "tensorflow/core/distributed_runtime/call_options.h"
26 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
27 #include "tensorflow/core/distributed_runtime/worker_session.h"
28 #include "tensorflow/core/framework/cancellation.h"
29 #include "tensorflow/core/framework/function.h"
30 #include "tensorflow/core/framework/graph_def_util.h"
31 
32 namespace tensorflow {
33 namespace eager {
34 namespace {
StripDefaultAttributesInRegisterFunctionOp(RegisterFunctionOp * register_function)35 void StripDefaultAttributesInRegisterFunctionOp(
36     RegisterFunctionOp* register_function) {
37   StripDefaultAttributes(
38       *OpRegistry::Global(),
39       register_function->mutable_function_def()->mutable_node_def());
40   for (auto& function :
41        *register_function->mutable_library()->mutable_function()) {
42     StripDefaultAttributes(*OpRegistry::Global(), function.mutable_node_def());
43   }
44 }
45 }  // namespace
46 
Instantiate(const string & function_name,const FunctionLibraryDefinition & lib_def,AttrSlice attrs,const FunctionLibraryRuntime::InstantiateOptions & options,FunctionLibraryRuntime::LocalHandle * handle,FunctionLibraryRuntime::DoneCallback done)47 void EagerClusterFunctionLibraryRuntime::Instantiate(
48     const string& function_name, const FunctionLibraryDefinition& lib_def,
49     AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options,
50     FunctionLibraryRuntime::LocalHandle* handle,
51     FunctionLibraryRuntime::DoneCallback done) {
52   auto target = options.target;
53   auto released_op = std::make_unique<EagerOperation>(ctx_);
54   Status s =
55       released_op->Reset(function_name.c_str(), target.c_str(), true, nullptr);
56   if (!s.ok()) {
57     done(s);
58     return;
59   }
60   if (!released_op->is_function()) {
61     done(errors::Internal(function_name, " is not a function."));
62     return;
63   }
64 
65   VLOG(1) << "CFLR::Instantiate: " << function_name << " on " << target
66           << " (this: " << this << ")";
67   core::RefCountPtr<eager::EagerClient> eager_client;
68   s = ctx_->GetClient(target, &eager_client);
69   if (!s.ok()) {
70     done(s);
71     return;
72   }
73 
74   if (eager_client == nullptr) {
75     done(errors::InvalidArgument("Could not find eager client for target: ",
76                                  target));
77     return;
78   }
79 
80   const FunctionLibraryDefinition& func_lib_def =
81       options.lib_def ? *options.lib_def : lib_def;
82 
83   auto request = std::make_shared<EnqueueRequest>();
84   auto response = std::make_shared<EnqueueResponse>();
85 
86   request->set_context_id(context_id_);
87 
88   RegisterFunctionOp* register_function =
89       request->add_queue()->mutable_register_function();
90   *register_function->mutable_function_def() =
91       *func_lib_def.Find(function_name);
92   register_function->set_is_component_function(true);
93   *register_function->mutable_library() =
94       func_lib_def.ReachableDefinitions(register_function->function_def())
95           .ToProto();
96   StripDefaultAttributesInRegisterFunctionOp(register_function);
97 
98   const absl::optional<std::vector<int>>& ret_indices = options.ret_indices;
99   eager_client->EnqueueAsync(
100       /*call_opts=*/nullptr, request.get(), response.get(),
101       [this, request, response, handle, released_op = released_op.release(),
102        target, ret_indices, eager_client = eager_client.get(),
103        done](const Status& s) {
104         {
105           mutex_lock l(mu_);
106           *handle = function_data_.size();
107           function_data_.emplace_back(target, ret_indices, eager_client,
108                                       absl::WrapUnique(released_op));
109         }
110         done(s);
111       });
112 }
113 
Run(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::LocalHandle handle,gtl::ArraySlice<Tensor> args,std::vector<Tensor> * rets,FunctionLibraryRuntime::DoneCallback done)114 void EagerClusterFunctionLibraryRuntime::Run(
115     const FunctionLibraryRuntime::Options& opts,
116     FunctionLibraryRuntime::LocalHandle handle, gtl::ArraySlice<Tensor> args,
117     std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
118   std::vector<FunctionArg> function_args;
119   for (const auto& tensor : args) {
120     function_args.push_back(tensor);
121   }
122   std::vector<FunctionRet>* function_rets = new std::vector<FunctionRet>;
123   Run(opts, handle, function_args, function_rets,
124       [rets, function_rets, done = std::move(done)](const Status& s) {
125         Status status = s;
126         if (status.ok()) {
127           for (const auto& t : *function_rets) {
128             if (t.index() == 0) {
129               rets->push_back(std::get<Tensor>(t));
130             } else {
131               status.Update(
132                   errors::Internal("Expect a Tensor as a remote function "
133                                    "output but got a TensorShape."));
134               break;
135             }
136           }
137         }
138         delete function_rets;
139         done(status);
140       });
141 }
142 
Run(const FunctionLibraryRuntime::Options & opts,FunctionLibraryRuntime::LocalHandle handle,gtl::ArraySlice<FunctionArg> args,std::vector<FunctionRet> * rets,FunctionLibraryRuntime::DoneCallback done)143 void EagerClusterFunctionLibraryRuntime::Run(
144     const FunctionLibraryRuntime::Options& opts,
145     FunctionLibraryRuntime::LocalHandle handle,
146     gtl::ArraySlice<FunctionArg> args, std::vector<FunctionRet>* rets,
147     FunctionLibraryRuntime::DoneCallback done) {
148   FunctionData* function_data = nullptr;
149   {
150     mutex_lock l(mu_);
151     DCHECK_LE(handle, function_data_.size());
152     function_data = &function_data_[handle];
153   }
154 
155   EagerClient* eager_client = function_data->eager_client.get();
156   if (eager_client == nullptr) {
157     done(errors::Internal("Could not find eager client"));
158     return;
159   }
160 
161   EagerOperation* op = function_data->op.get();
162   if (!op->Inputs().empty()) {
163     done(errors::Internal("Inputs should not be set during instantiation."));
164     return;
165   }
166 
167   auto request = std::make_shared<RunComponentFunctionRequest>();
168   auto response = std::make_shared<RunComponentFunctionResponse>();
169   request->set_context_id(context_id_);
170   eager::Operation* remote_op = request->mutable_operation();
171 
172   if (function_data->ret_indices.has_value()) {
173     for (const int ret_index : function_data->ret_indices.value()) {
174       request->add_output_num(ret_index);
175     }
176   }
177 
178   for (const auto& arg : args) {
179     if (arg.index() == 0) {
180       std::get<Tensor>(arg).AsProtoTensorContent(
181           remote_op->add_op_inputs()->mutable_tensor());
182     } else {
183       remote_op->add_op_inputs()->mutable_remote_handle()->Swap(
184           std::get<RemoteTensorHandle*>(arg));
185     }
186   }
187 
188   // The remote component function should use the same op_id as its parent
189   // multi-device function's in order to get the global unique op_id generated
190   // by the master context.
191   if (opts.op_id.has_value()) {
192     remote_op->set_id(opts.op_id.value());
193   } else {
194     remote_op->set_id(kInvalidOpId);
195   }
196   remote_op->set_is_function(true);
197   remote_op->set_is_component_function(true);
198   remote_op->set_func_step_id(opts.step_id);
199   remote_op->set_name(op->Name());
200   op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
201   remote_op->set_device(function_data->target);
202 
203   CancellationManager* cm = opts.cancellation_manager;
204   CancellationToken token = 0;
205   auto call_opts = std::make_shared<CallOptions>();
206   call_opts->SetTimeout(
207       ctx_->session_options().config.operation_timeout_in_ms());
208   if (cm != nullptr) {
209     token = cm->get_cancellation_token();
210     const bool already_cancelled = !cm->RegisterCallback(
211         token,
212         [call_opts, request, response, done]() { call_opts->StartCancel(); });
213     if (already_cancelled) {
214       done(errors::Cancelled("EagerClusterFunctionLibraryRuntime::Run"));
215       return;
216     }
217   }
218 
219   // Execute component function on remote worker using RunComponentFunction RPC.
220   // Different from executing remote functions with Enqueue, this method runs
221   // a function on remote worker without tying up a thread (i.e., pure
222   // asynchronously).
223   eager_client->RunComponentFunctionAsync(
224       call_opts.get(), request.get(), response.get(),
225       [request, response, rets, call_opts, cm, token,
226        done = std::move(done)](const Status& s) {
227         if (cm != nullptr) {
228           cm->TryDeregisterCallback(token);
229         }
230         if (!s.ok()) {
231           done(s);
232           return;
233         }
234         if (!response->shape().empty() && !response->tensor().empty()) {
235           done(errors::Internal(
236               "Both shape and tensor are specified in the same response"));
237           return;
238         }
239         for (const auto& shape : response->shape()) {
240           rets->push_back(shape);
241         }
242         for (const auto& tensor_proto : response->tensor()) {
243           Tensor t;
244           if (t.FromProto(tensor_proto)) {
245             rets->push_back(std::move(t));
246           } else {
247             done(errors::Internal("Could not convert tensor proto: ",
248                                   tensor_proto.DebugString()));
249             return;
250           }
251         }
252         done(OkStatus());
253       });
254 }
255 
CleanUp(uint64 step_id,FunctionLibraryRuntime::LocalHandle handle,FunctionLibraryRuntime::DoneCallback done)256 void EagerClusterFunctionLibraryRuntime::CleanUp(
257     uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
258     FunctionLibraryRuntime::DoneCallback done) {
259   FunctionData* function_data = nullptr;
260   {
261     mutex_lock l(mu_);
262     DCHECK_LE(handle, function_data_.size());
263     function_data = &function_data_[handle];
264   }
265 
266   EagerClient* eager_client = function_data->eager_client.get();
267   if (eager_client == nullptr) {
268     done(errors::Internal("Could not find eager client"));
269     return;
270   }
271 
272   auto request = std::make_shared<EnqueueRequest>();
273   auto response = std::make_shared<EnqueueResponse>();
274   request->set_context_id(context_id_);
275   CleanupFunctionOp* cleanup_function =
276       request->add_queue()->mutable_cleanup_function();
277   cleanup_function->set_step_id(step_id);
278   // StreamingEnqueueAsync could be blocking when streaming RPC is disabled.
279   // CleanUp() needs to be non-blocking since it would be invoked inside the
280   // enqueue done callback of Run(). So we don't use StreamingEnqueueAsync here.
281   eager_client->EnqueueAsync(
282       /*call_opts=*/nullptr, request.get(), response.get(),
283       [request, response, done](const Status& status) { done(status); });
284 }
285 
CreateClusterFLR(const uint64 context_id,EagerContext * ctx,WorkerSession * worker_session)286 DistributedFunctionLibraryRuntime* CreateClusterFLR(
287     const uint64 context_id, EagerContext* ctx, WorkerSession* worker_session) {
288   return new EagerClusterFunctionLibraryRuntime(
289       context_id, ctx, worker_session->remote_device_mgr());
290 }
291 
292 }  // namespace eager
293 }  // namespace tensorflow
294