/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/master_session.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/profile_handler.h"
#include "tensorflow/core/common_runtime/stats_publisher_interface.h"
#include "tensorflow/core/debug/debug_graph_utils.h"
#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/distributed_runtime/scheduler.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/coordination_config.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

// MasterSession wraps ClientGraph in a reference counted object.
// This way, MasterSession can clear up the cache mapping Run requests to
// compiled graphs while the compiled graph is still being used.
//
// TODO(zhifengc): Cleanup this class. It's becoming messy.
class MasterSession::ReffedClientGraph : public core::RefCounted {
 public:
  ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts,
                    std::unique_ptr<ClientGraph> client_graph,
                    const SessionOptions& session_opts,
                    const StatsPublisherFactory& stats_publisher_factory,
                    bool is_partial, WorkerCacheInterface* worker_cache,
                    bool should_deregister)
      : session_handle_(handle),
        bg_opts_(bopts),
        client_graph_before_register_(std::move(client_graph)),
        session_opts_(session_opts),
        is_partial_(is_partial),
        callable_opts_(bopts.callable_options),
        worker_cache_(worker_cache),
        should_deregister_(should_deregister),
        collective_graph_key_(
            client_graph_before_register_->collective_graph_key) {
93     VLOG(1) << "Created ReffedClientGraph for node with "
94             << client_graph_before_register_->graph.num_node_ids();

    stats_publisher_ = stats_publisher_factory(handle, bopts, session_opts);

    // Initialize a name to node map for processing device stats.
    for (Node* n : client_graph_before_register_->graph.nodes()) {
      name_to_node_details_.emplace(
          n->name(),
          NodeDetails(n->type_string(),
                      strings::StrCat(
                          "(", absl::StrJoin(n->requested_inputs(), ", "),
                          ")")));
    }
  }

  ~ReffedClientGraph() override {
    if (should_deregister_) {
      DeregisterPartitions();
    } else {
      for (Part& part : partitions_) {
        worker_cache_->ReleaseWorker(part.name, part.worker);
      }
    }
  }

  const CallableOptions& callable_options() { return callable_opts_; }

  const BuildGraphOptions& build_graph_options() { return bg_opts_; }

  int64_t collective_graph_key() { return collective_graph_key_; }

  std::unique_ptr<ProfileHandler> GetProfileHandler(uint64 step,
                                                    int64_t execution_count,
                                                    const RunOptions& ropts) {
    return stats_publisher_->GetProfileHandler(step, execution_count, ropts);
  }

  int64_t get_and_increment_execution_count() {
    return execution_count_.fetch_add(1);
  }

  // Turn RPC logging on or off, both at the WorkerCache used by this
  // master process, and at each remote worker in use for the current
  // partitions.
  void SetRPCLogging(bool active) {
    worker_cache_->SetLogging(active);
    // Logging is a best-effort activity, so we make async calls to turn
    // it on/off and don't make use of the responses.
    for (auto& p : partitions_) {
      LoggingRequest* req = new LoggingRequest;
      if (active) {
        req->set_enable_rpc_logging(true);
      } else {
        req->set_disable_rpc_logging(true);
      }
      LoggingResponse* resp = new LoggingResponse;
      Ref();
      p.worker->LoggingAsync(req, resp, [this, req, resp](const Status& s) {
        delete req;
        delete resp;
        // ReffedClientGraph owns p.worker so we need to hold a ref to
        // ensure that the method doesn't attempt to access p.worker after
        // ReffedClientGraph has deleted it.
        // TODO(suharshs): Simplify this ownership model.
        Unref();
      });
    }
  }

  // Retrieve all RPC log data accumulated for the current step, both
  // from the local WorkerCache in use by this master process and from
  // all the remote workers executing the remote partitions.
  void RetrieveLogs(int64_t step_id, StepStats* ss) {
    // Get the local data first, because it sets *ss without merging.
    worker_cache_->RetrieveLogs(step_id, ss);

    // Then merge in data from all the remote workers.
    LoggingRequest req;
    req.add_fetch_step_id(step_id);
    int waiting_for = partitions_.size();
    if (waiting_for > 0) {
      mutex scoped_mu;
      BlockingCounter all_done(waiting_for);
      for (auto& p : partitions_) {
        LoggingResponse* resp = new LoggingResponse;
        p.worker->LoggingAsync(
            &req, resp,
            [step_id, ss, resp, &scoped_mu, &all_done](const Status& s) {
              {
                mutex_lock l(scoped_mu);
                if (s.ok()) {
                  for (auto& lss : resp->step()) {
                    if (step_id != lss.step_id()) {
                      LOG(ERROR) << "Wrong step_id in LoggingResponse";
                      continue;
                    }
                    ss->MergeFrom(lss.step_stats());
                  }
                }
                delete resp;
              }
              // Must not decrement all_done until out of critical section where
              // *ss is updated.
              all_done.DecrementCount();
            });
      }
      all_done.Wait();
    }
  }

  // Local execution methods.

  // Partitions the graph into subgraphs and registers them on
  // workers.
  Status RegisterPartitions(PartitionOptions popts);

  // Runs one step of all partitions.
  Status RunPartitions(const MasterEnv* env, int64_t step_id,
                       int64_t execution_count, PerStepState* pss,
                       CallOptions* opts, const RunStepRequestWrapper& req,
                       MutableRunStepResponseWrapper* resp,
                       CancellationManager* cm, const bool is_last_partial_run);
  Status RunPartitions(const MasterEnv* env, int64_t step_id,
                       int64_t execution_count, PerStepState* pss,
                       CallOptions* call_opts, const RunCallableRequest& req,
                       RunCallableResponse* resp, CancellationManager* cm);

  // Calls workers to cleanup states for the step "step_id".  Calls
  // `done` when all cleanup RPCs have completed.
  void CleanupPartitionsAsync(int64_t step_id, StatusCallback done);

  // Post-processing of any runtime statistics gathered during execution.
  void ProcessStats(int64_t step_id, PerStepState* pss, ProfileHandler* ph,
                    const RunOptions& options, RunMetadata* resp);
  void ProcessDeviceStats(ProfileHandler* ph, const DeviceStepStats& ds,
                          bool is_rpc);
  // Checks that the requested fetches can be computed from the provided feeds.
  Status CheckFetches(const RunStepRequestWrapper& req,
                      const RunState* run_state,
                      GraphExecutionState* execution_state);

 private:
  const string session_handle_;
  const BuildGraphOptions bg_opts_;

  // NOTE(mrry): This pointer will be null after `RegisterPartitions()` returns.
  std::unique_ptr<ClientGraph> client_graph_before_register_ TF_GUARDED_BY(mu_);
  const SessionOptions session_opts_;
  const bool is_partial_;
  const CallableOptions callable_opts_;
  WorkerCacheInterface* const worker_cache_;  // Not owned.

  struct NodeDetails {
    explicit NodeDetails(string type_string, string detail_text)
        : type_string(std::move(type_string)),
          detail_text(std::move(detail_text)) {}
    const string type_string;
    const string detail_text;
  };
  std::unordered_map<string, NodeDetails> name_to_node_details_;

  const bool should_deregister_;
  const int64_t collective_graph_key_;
  std::atomic<int64_t> execution_count_ = {0};

  // Graph partitioned into per-location subgraphs.
  struct Part {
    // Worker name.
    string name;

    // Maps feed names to rendezvous keys. Empty most of the time.
    std::unordered_map<string, string> feed_key;

    // Maps rendezvous keys to fetch names. Empty most of the time.
    std::unordered_map<string, string> key_fetch;

    // The interface to the worker. Owned.
    WorkerInterface* worker = nullptr;

    // After registration with the worker, graph_handle identifies
    // this partition on the worker.
    string graph_handle;

    Part() : feed_key(3), key_fetch(3) {}
  };

  // partitions_ is immutable after RegisterPartitions() call
  // finishes.  RunPartitions() can access partitions_ safely without
  // acquiring locks.
  std::vector<Part> partitions_;

  mutable mutex mu_;

  // Partition initialization and registration only needs to happen
  // once. `!client_graph_before_register_ && !init_done_.HasBeenNotified()`
  // indicates the initialization is ongoing.
  Notification init_done_;

  // init_result_ remembers the initialization error if any.
  Status init_result_ TF_GUARDED_BY(mu_);

  std::unique_ptr<StatsPublisherInterface> stats_publisher_;

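  // Renders a one-line, human-readable summary of `stats` for profiling
  // output; the result is roughly of the form (illustrative only):
  //   "[1.5MB] mul = Mul(x, y)"
  // where the size prefix appears only when the node's outputs total at
  // least ~0.1MB.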
  string DetailText(const NodeDetails& details, const NodeExecStats& stats) {
    int64_t tot = 0;
    for (auto& no : stats.output()) {
      tot += no.tensor_description().allocation_description().requested_bytes();
    }
    string bytes;
    if (tot >= 0.1 * 1048576.0) {
      bytes = strings::Printf("[%.1fMB] ", tot / 1048576.0);
    }
    return strings::StrCat(bytes, stats.node_name(), " = ", details.type_string,
                           details.detail_text);
  }

  // Send/Recv nodes that are the result of client-added
  // feeds and fetches must be tracked so that the tensors
  // can be added to the local rendezvous.
  static void TrackFeedsAndFetches(Part* part, const GraphDef& graph_def,
                                   const PartitionOptions& popts);

  // The actual graph partitioning and registration implementation.
  Status DoBuildPartitions(
      PartitionOptions popts, ClientGraph* client_graph,
      std::unordered_map<string, GraphDef>* out_partitions);
  Status DoRegisterPartitions(
      const PartitionOptions& popts,
      std::unordered_map<string, GraphDef> graph_partitions);

  // Prepares a number of calls to workers. One call per partition.
  // This is a generic method that handles Run, PartialRun, and RunCallable.
  template <class FetchListType, class ClientRequestType,
            class ClientResponseType>
  Status RunPartitionsHelper(
      const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds,
      const FetchListType& fetches, const MasterEnv* env, int64_t step_id,
      int64_t execution_count, PerStepState* pss, CallOptions* call_opts,
      const ClientRequestType& req, ClientResponseType* resp,
      CancellationManager* cm, bool is_last_partial_run);

  // Deregisters the partitions on the workers.  Called in the
  // destructor and does not wait for the rpc completion.
  void DeregisterPartitions();

  TF_DISALLOW_COPY_AND_ASSIGN(ReffedClientGraph);
};

Status MasterSession::ReffedClientGraph::RegisterPartitions(
    PartitionOptions popts) {
  {  // Ensure register once.
    mu_.lock();
    if (client_graph_before_register_) {
      // The `ClientGraph` is no longer needed after partitions are registered.
      // Since it can account for a large amount of memory, we consume it here,
      // and it will be freed after concluding with registration.

      std::unique_ptr<ClientGraph> client_graph;
      std::swap(client_graph_before_register_, client_graph);
      mu_.unlock();
      std::unordered_map<string, GraphDef> graph_defs;
      popts.flib_def = client_graph->flib_def.get();
      Status s = DoBuildPartitions(popts, client_graph.get(), &graph_defs);
      if (s.ok()) {
        // NOTE(mrry): The pointers in `graph_defs_for_publishing` do not remain
        // valid after the call to DoRegisterPartitions begins, so
        // `stats_publisher_` must make a copy if it wants to retain the
        // GraphDef objects.
        std::vector<const GraphDef*> graph_defs_for_publishing;
        graph_defs_for_publishing.reserve(graph_defs.size());
        for (const auto& name_def : graph_defs) {
          graph_defs_for_publishing.push_back(&name_def.second);
        }
        stats_publisher_->PublishGraphProto(graph_defs_for_publishing);
        s = DoRegisterPartitions(popts, std::move(graph_defs));
      }
      mu_.lock();
      init_result_ = s;
      init_done_.Notify();
    } else {
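      // Another thread is performing (or has already performed) the
      // registration; wait for it to finish, then re-acquire the lock and
      // read the shared result below.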
      mu_.unlock();
      init_done_.WaitForNotification();
      mu_.lock();
    }
    const Status result = init_result_;
    mu_.unlock();
    return result;
  }
}

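// Extracts the task name from `node`'s assigned device name, e.g. (for
// illustration) "/job:worker/replica:0/task:7/device:CPU:0" yields
// "/job:worker/replica:0/task:7". Dies if the device name is malformed.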
static string SplitByWorker(const Node* node) {
  string task;
  string device;
  CHECK(DeviceNameUtils::SplitDeviceName(node->assigned_device_name(), &task,
                                         &device))
      << "node: " << node->name() << " dev: " << node->assigned_device_name();
  return task;
}

void MasterSession::ReffedClientGraph::TrackFeedsAndFetches(
    Part* part, const GraphDef& graph_def, const PartitionOptions& popts) {
  for (int i = 0; i < graph_def.node_size(); ++i) {
    const NodeDef& ndef = graph_def.node(i);
    const bool is_recv = ndef.op() == "_Recv";
    const bool is_send = ndef.op() == "_Send";

    if (is_recv || is_send) {
      // Only send/recv nodes that were added as feeds and fetches
      // (client-terminated) should be tracked.  Other send/recv nodes
      // are for transferring data between partitions / memory spaces.
      bool client_terminated;
      TF_CHECK_OK(GetNodeAttr(ndef, "client_terminated", &client_terminated));
      if (client_terminated) {
        string name;
        TF_CHECK_OK(GetNodeAttr(ndef, "tensor_name", &name));
        string send_device;
        TF_CHECK_OK(GetNodeAttr(ndef, "send_device", &send_device));
        string recv_device;
        TF_CHECK_OK(GetNodeAttr(ndef, "recv_device", &recv_device));
        uint64 send_device_incarnation;
        TF_CHECK_OK(
            GetNodeAttr(ndef, "send_device_incarnation",
                        reinterpret_cast<int64_t*>(&send_device_incarnation)));
        const string& key =
            Rendezvous::CreateKey(send_device, send_device_incarnation,
                                  recv_device, name, FrameAndIter(0, 0));
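        // `key` is the rendezvous wire-format string, roughly of the form
        // (illustrative only):
        //   "<send_device>;<incarnation>;<recv_device>;<tensor_name>;0:0"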

        if (is_recv) {
          part->feed_key.insert({name, key});
        } else {
          part->key_fetch.insert({key, name});
        }
      }
    }
  }
}

Status MasterSession::ReffedClientGraph::DoBuildPartitions(
    PartitionOptions popts, ClientGraph* client_graph,
    std::unordered_map<string, GraphDef>* out_partitions) {
  if (popts.need_to_record_start_times) {
    CostModel cost_model(true);
    cost_model.InitFromGraph(client_graph->graph);
    // TODO(yuanbyu): Use the real cost model.
    // execution_state_->MergeFromGlobal(&cost_model);
    SlackAnalysis sa(&client_graph->graph, &cost_model);
    sa.ComputeAsap(&popts.start_times);
  }

  // Partition the graph.
  return Partition(popts, &client_graph->graph, out_partitions);
}

Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
    const PartitionOptions& popts,
    std::unordered_map<string, GraphDef> graph_partitions) {
  partitions_.reserve(graph_partitions.size());
  Status s;
  for (auto& name_def : graph_partitions) {
    partitions_.emplace_back();
    Part* part = &partitions_.back();
    part->name = name_def.first;
    TrackFeedsAndFetches(part, name_def.second, popts);
    part->worker = worker_cache_->GetOrCreateWorker(part->name);
    if (part->worker == nullptr) {
      s = errors::NotFound("worker ", part->name);
      break;
    }
  }
  if (!s.ok()) {
    for (Part& part : partitions_) {
      worker_cache_->ReleaseWorker(part.name, part.worker);
      part.worker = nullptr;
    }
    return s;
  }
  struct Call {
    RegisterGraphRequest req;
    RegisterGraphResponse resp;
    Status status;
  };
  const int num = partitions_.size();
  gtl::InlinedVector<Call, 4> calls(num);
  BlockingCounter done(num);
  for (int i = 0; i < num; ++i) {
    const Part& part = partitions_[i];
    Call* c = &calls[i];
    c->req.set_session_handle(session_handle_);
    c->req.set_create_worker_session_called(!should_deregister_);
    c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
    StripDefaultAttributes(*OpRegistry::Global(),
                           c->req.mutable_graph_def()->mutable_node());
    *c->req.mutable_config_proto() = session_opts_.config;
    *c->req.mutable_graph_options() = session_opts_.config.graph_options();
    *c->req.mutable_debug_options() =
        callable_opts_.run_options().debug_options();
    c->req.set_collective_graph_key(collective_graph_key_);
    VLOG(2) << "Register " << c->req.graph_def().DebugString();
    auto cb = [c, &done](const Status& s) {
      c->status = s;
      done.DecrementCount();
    };
    part.worker->RegisterGraphAsync(&c->req, &c->resp, cb);
  }
  done.Wait();
  for (int i = 0; i < num; ++i) {
    Call* c = &calls[i];
    s.Update(c->status);
    partitions_[i].graph_handle = c->resp.graph_handle();
  }
  return s;
}

namespace {
// Helper class to manage "num" parallel RunGraph calls.
class RunManyGraphs {
 public:
  explicit RunManyGraphs(int num) : calls_(num), pending_(num) {}

  ~RunManyGraphs() {}

  // Returns the index-th call.
  struct Call {
    CallOptions opts;
    const string* worker_name;
    std::atomic<bool> done{false};
    std::unique_ptr<MutableRunGraphRequestWrapper> req;
    std::unique_ptr<MutableRunGraphResponseWrapper> resp;
  };
  Call* get(int index) { return &calls_[index]; }

  // When the index-th call is done, updates the overall status.
  void WhenDone(int index, const Status& s) {
    TRACEPRINTF("Partition %d %s", index, s.ToString().c_str());
    Call* call = get(index);
    call->done = true;
    auto resp = call->resp.get();
    if (resp->status_code() != error::Code::OK) {
      // resp->status_code will only be non-OK if s.ok().
      mutex_lock l(mu_);
      Status resp_status = call->resp->status();
      ReportBadStatus(errors::CreateWithUpdatedMessage(
          resp_status, strings::StrCat("From ", *call->worker_name, ":\n",
                                       resp_status.error_message())));
    } else if (!s.ok()) {
      mutex_lock l(mu_);
      ReportBadStatus(errors::CreateWithUpdatedMessage(
          s, strings::StrCat("From ", *call->worker_name, ":\n",
                             s.error_message())));
    }
    pending_.DecrementCount();
  }

  void StartCancel() {
    mutex_lock l(mu_);
    ReportBadStatus(errors::Cancelled("RunManyGraphs"));
  }

  void Wait() {
    // Check the error status every 60 seconds in order to print a log message
    // in the event of a hang.
    const std::chrono::milliseconds kCheckErrorPeriod(1000 * 60);
    while (true) {
      if (pending_.WaitFor(kCheckErrorPeriod)) {
        return;
      }
      if (!status().ok()) {
        break;
      }
    }

    // The step has failed. Wait for another 60 seconds before diagnosing a
    // hang.
    DCHECK(!status().ok());
    if (pending_.WaitFor(kCheckErrorPeriod)) {
      return;
    }
    LOG(ERROR)
        << "RunStep still blocked after 60 seconds. Failed with error status: "
        << status();
    for (const Call& call : calls_) {
      if (!call.done) {
        LOG(ERROR) << "- No response from RunGraph call to worker: "
                   << *call.worker_name;
      }
    }
    pending_.Wait();
  }

  Status status() const {
    mutex_lock l(mu_);
    // Concatenate the status objects in this StatusGroup to get the
    // aggregated status, as each status in status_group_ is already a
    // summarized status.
    return status_group_.as_concatenated_status();
  }

 private:
  gtl::InlinedVector<Call, 4> calls_;

  BlockingCounter pending_;
  mutable mutex mu_;
  StatusGroup status_group_ TF_GUARDED_BY(mu_);
  bool cancel_issued_ TF_GUARDED_BY(mu_) = false;

  void ReportBadStatus(const Status& s) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    VLOG(1) << "Master received error status " << s;
    if (!cancel_issued_ && !StatusGroup::IsDerived(s)) {
      // Only start cancelling other workers upon receiving a non-derived
      // error
      cancel_issued_ = true;

      VLOG(1) << "Master received error report. Cancelling remaining workers.";
      for (Call& call : calls_) {
        call.opts.StartCancel();
      }
    }

    status_group_.Update(s);
  }

  TF_DISALLOW_COPY_AND_ASSIGN(RunManyGraphs);
};

Status AddSendFromClientRequest(const RunStepRequestWrapper& client_req,
                                MutableRunGraphRequestWrapper* worker_req,
                                size_t index, const string& send_key) {
  return worker_req->AddSendFromRunStepRequest(client_req, index, send_key);
}

Status AddSendFromClientRequest(const RunCallableRequest& client_req,
                                MutableRunGraphRequestWrapper* worker_req,
                                size_t index, const string& send_key) {
  return worker_req->AddSendFromRunCallableRequest(client_req, index, send_key);
}

// TODO(mrry): Add a full-fledged wrapper that avoids TensorProto copies for
// in-process messages.
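// Collects tensors fetched via RunGraph responses, keyed by fetch name, so
// that RunPartitions() can reorder them to match the callable's fetch order.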
struct RunCallableResponseWrapper {
  RunCallableResponse* resp;  // Not owned.
  std::unordered_map<string, TensorProto> fetch_key_to_protos;

  RunMetadata* mutable_metadata() { return resp->mutable_metadata(); }

  Status AddTensorFromRunGraphResponse(
      const string& tensor_name, MutableRunGraphResponseWrapper* worker_resp,
      size_t index) {
    return worker_resp->RecvValue(index, &fetch_key_to_protos[tensor_name]);
  }
};
}  // namespace

template <class FetchListType, class ClientRequestType,
          class ClientResponseType>
Status MasterSession::ReffedClientGraph::RunPartitionsHelper(
    const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds,
    const FetchListType& fetches, const MasterEnv* env, int64_t step_id,
    int64_t execution_count, PerStepState* pss, CallOptions* call_opts,
    const ClientRequestType& req, ClientResponseType* resp,
    CancellationManager* cm, bool is_last_partial_run) {
  // Collect execution cost stats at a smoothly decreasing frequency.
  ExecutorOpts exec_opts;
  if (pss->report_tensor_allocations_upon_oom) {
    exec_opts.set_report_tensor_allocations_upon_oom(true);
  }
  if (pss->collect_costs) {
    exec_opts.set_record_costs(true);
  }
  if (pss->collect_timeline) {
    exec_opts.set_record_timeline(true);
  }
  if (pss->collect_rpcs) {
    SetRPCLogging(true);
  }
  if (pss->collect_partition_graphs) {
    exec_opts.set_record_partition_graphs(true);
  }
  if (pss->collect_costs || pss->collect_timeline) {
    pss->step_stats.resize(partitions_.size());
  }

  const int num = partitions_.size();
  RunManyGraphs calls(num);

  for (int i = 0; i < num; ++i) {
    const Part& part = partitions_[i];
    RunManyGraphs::Call* c = calls.get(i);
    c->worker_name = &part.name;
    c->req.reset(part.worker->CreateRunGraphRequest());
    c->resp.reset(part.worker->CreateRunGraphResponse());
    if (is_partial_) {
      c->req->set_is_partial(is_partial_);
      c->req->set_is_last_partial_run(is_last_partial_run);
    }
    c->req->set_session_handle(session_handle_);
    c->req->set_create_worker_session_called(!should_deregister_);
    c->req->set_graph_handle(part.graph_handle);
    c->req->set_step_id(step_id);
    *c->req->mutable_exec_opts() = exec_opts;
    c->req->set_store_errors_in_response_body(true);
    c->req->set_request_id(GetUniqueRequestId());
    // If any feeds are provided, send the feed values together
    // in the RunGraph request.
    // In the partial case, we only want to include feeds provided in the req.
    // In the non-partial case, all feeds in the request are in the part.
    // We keep these as separate paths for now, to ensure we aren't
    // inadvertently slowing down the normal run path.
    if (is_partial_) {
      for (const auto& name_index : feeds) {
        const auto iter = part.feed_key.find(string(name_index.first));
        if (iter == part.feed_key.end()) {
          // The provided feed must be for a different partition.
          continue;
        }
        const string& key = iter->second;
        TF_RETURN_IF_ERROR(AddSendFromClientRequest(req, c->req.get(),
                                                    name_index.second, key));
      }
      // TODO(suharshs): Make a map from feed to fetch_key to make this faster.
      // For now, we just iterate through partitions to find the matching key.
      for (const string& req_fetch : fetches) {
        for (const auto& key_fetch : part.key_fetch) {
          if (key_fetch.second == req_fetch) {
            c->req->add_recv_key(key_fetch.first);
            break;
          }
        }
      }
    } else {
      for (const auto& feed_key : part.feed_key) {
        const string& feed = feed_key.first;
        const string& key = feed_key.second;
        auto iter = feeds.find(feed);
        if (iter == feeds.end()) {
          return errors::Internal("No feed index found for feed: ", feed);
        }
        const int64_t feed_index = iter->second;
        TF_RETURN_IF_ERROR(
            AddSendFromClientRequest(req, c->req.get(), feed_index, key));
      }
      for (const auto& key_fetch : part.key_fetch) {
        const string& key = key_fetch.first;
        c->req->add_recv_key(key);
      }
    }
  }

  // Issues RunGraph calls.
  for (int i = 0; i < num; ++i) {
    const Part& part = partitions_[i];
    RunManyGraphs::Call* call = calls.get(i);
    TRACEPRINTF("Partition %d %s", i, part.name.c_str());
    part.worker->RunGraphAsync(
        &call->opts, call->req.get(), call->resp.get(),
        std::bind(&RunManyGraphs::WhenDone, &calls, i, std::placeholders::_1));
  }

  // Waits for the RunGraph calls.
  call_opts->SetCancelCallback([&calls]() {
    LOG(INFO) << "Client requested cancellation for RunStep, cancelling "
                 "worker operations.";
    calls.StartCancel();
  });
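  // Also register with the per-step cancellation manager; if registration
  // fails, the step has already been cancelled, so cancel the worker calls
  // immediately and report the cancellation below.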
  auto token = cm->get_cancellation_token();
  const bool success =
      cm->RegisterCallback(token, [&calls]() { calls.StartCancel(); });
  if (!success) {
    calls.StartCancel();
  }
  calls.Wait();
  call_opts->ClearCancelCallback();
  if (success) {
    cm->DeregisterCallback(token);
  } else {
    return errors::Cancelled("Step was cancelled");
  }
  TF_RETURN_IF_ERROR(calls.status());

  // Collects fetches and metadata.
  Status status;
  for (int i = 0; i < num; ++i) {
    const Part& part = partitions_[i];
    MutableRunGraphResponseWrapper* run_graph_resp = calls.get(i)->resp.get();
    for (size_t j = 0; j < run_graph_resp->num_recvs(); ++j) {
      auto iter = part.key_fetch.find(run_graph_resp->recv_key(j));
      if (iter == part.key_fetch.end()) {
        status.Update(errors::Internal("Unexpected fetch key: ",
                                       run_graph_resp->recv_key(j)));
        break;
      }
      const string& fetch = iter->second;
      status.Update(
          resp->AddTensorFromRunGraphResponse(fetch, run_graph_resp, j));
      if (!status.ok()) {
        break;
      }
    }
    if (pss->collect_timeline) {
      pss->step_stats[i].Swap(run_graph_resp->mutable_step_stats());
    }
    if (pss->collect_costs) {
      CostGraphDef* cost_graph = run_graph_resp->mutable_cost_graph();
      for (int j = 0; j < cost_graph->node_size(); ++j) {
        resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap(
            cost_graph->mutable_node(j));
      }
    }
    if (pss->collect_partition_graphs) {
      protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
          resp->mutable_metadata()->mutable_partition_graphs();
      for (size_t i = 0; i < run_graph_resp->num_partition_graphs(); i++) {
        partition_graph_defs->Add()->Swap(
            run_graph_resp->mutable_partition_graph(i));
      }
    }
  }
  return status;
}

Status MasterSession::ReffedClientGraph::RunPartitions(
    const MasterEnv* env, int64_t step_id, int64_t execution_count,
    PerStepState* pss, CallOptions* call_opts, const RunStepRequestWrapper& req,
    MutableRunStepResponseWrapper* resp, CancellationManager* cm,
    const bool is_last_partial_run) {
  VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
          << execution_count;
  // Maps the names of fed tensors to their index in `req`.
  std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
  for (size_t i = 0; i < req.num_feeds(); ++i) {
    if (!feeds.insert({req.feed_name(i), i}).second) {
      return errors::InvalidArgument("Duplicated feeds: ", req.feed_name(i));
    }
  }

  std::vector<string> fetches;
  fetches.reserve(req.num_fetches());
  for (size_t i = 0; i < req.num_fetches(); ++i) {
    fetches.push_back(req.fetch_name(i));
  }

  return RunPartitionsHelper(feeds, fetches, env, step_id, execution_count, pss,
                             call_opts, req, resp, cm, is_last_partial_run);
}

Status MasterSession::ReffedClientGraph::RunPartitions(
    const MasterEnv* env, int64_t step_id, int64_t execution_count,
    PerStepState* pss, CallOptions* call_opts, const RunCallableRequest& req,
    RunCallableResponse* resp, CancellationManager* cm) {
  VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
          << execution_count;
  // Maps the names of fed tensors to their index in `req`.
  std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
  for (size_t i = 0, end = callable_opts_.feed_size(); i < end; ++i) {
    if (!feeds.insert({callable_opts_.feed(i), i}).second) {
      // MakeCallable will fail if there are two feeds with the same name.
      return errors::Internal("Duplicated feeds in callable: ",
                              callable_opts_.feed(i));
    }
  }

  // Create a wrapped response object to collect the fetched values and
  // rearrange them for the RunCallableResponse.
  RunCallableResponseWrapper wrapped_resp;
  wrapped_resp.resp = resp;

  TF_RETURN_IF_ERROR(RunPartitionsHelper(
      feeds, callable_opts_.fetch(), env, step_id, execution_count, pss,
      call_opts, req, &wrapped_resp, cm, false /* is_last_partial_run */));

  // Collects fetches.
  for (const string& fetch : callable_opts_.fetch()) {
    TensorProto* fetch_proto = resp->mutable_fetch()->Add();
    auto iter = wrapped_resp.fetch_key_to_protos.find(fetch);
    if (iter == wrapped_resp.fetch_key_to_protos.end()) {
      return errors::Internal("Worker did not return a value for fetch: ",
                              fetch);
    }
    fetch_proto->Swap(&iter->second);
  }
  return OkStatus();
}

namespace {

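// Broadcasts a CleanupGraph request for a given step to a set of workers,
// aggregates their statuses, and invokes `done` exactly once when the last
// response arrives. The helper deletes itself after running `done`.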
class CleanupBroadcastHelper {
 public:
  CleanupBroadcastHelper(int64_t step_id, int num_calls, StatusCallback done)
      : resps_(num_calls), num_pending_(num_calls), done_(std::move(done)) {
    req_.set_step_id(step_id);
  }

  // Returns a non-owned pointer to a request buffer for all calls.
  CleanupGraphRequest* request() { return &req_; }

  // Returns a non-owned pointer to a response buffer for the ith call.
  CleanupGraphResponse* response(int i) { return &resps_[i]; }

  // Called when the ith response is received.
  void call_done(int i, const Status& s) {
    bool run_callback = false;
    Status status_copy;
    {
      mutex_lock l(mu_);
      status_.Update(s);
      if (--num_pending_ == 0) {
        run_callback = true;
        status_copy = status_;
      }
    }
    if (run_callback) {
      done_(status_copy);
      // This is the last call, so delete the helper object.
      delete this;
    }
  }

 private:
  // A single request shared between all workers.
  CleanupGraphRequest req_;
  // One response buffer for each worker.
  gtl::InlinedVector<CleanupGraphResponse, 4> resps_;

  mutex mu_;
  // Number of requests remaining to be collected.
  int num_pending_ TF_GUARDED_BY(mu_);
  // Aggregate status of the operation.
  Status status_ TF_GUARDED_BY(mu_);
  // Callback to be called when all operations complete.
  StatusCallback done_;

  TF_DISALLOW_COPY_AND_ASSIGN(CleanupBroadcastHelper);
};

}  // namespace

void MasterSession::ReffedClientGraph::CleanupPartitionsAsync(
    int64_t step_id, StatusCallback done) {
  const int num = partitions_.size();
  // Helper object will be deleted when the final call completes.
  CleanupBroadcastHelper* helper =
      new CleanupBroadcastHelper(step_id, num, std::move(done));
  for (int i = 0; i < num; ++i) {
    const Part& part = partitions_[i];
    part.worker->CleanupGraphAsync(
        helper->request(), helper->response(i),
        [helper, i](const Status& s) { helper->call_done(i, s); });
  }
}

void MasterSession::ReffedClientGraph::ProcessStats(int64_t step_id,
                                                    PerStepState* pss,
                                                    ProfileHandler* ph,
                                                    const RunOptions& options,
                                                    RunMetadata* resp) {
  if (!pss->collect_costs && !pss->collect_timeline) return;

  // Out-of-band logging data is collected now, during post-processing.
  if (pss->collect_timeline) {
    SetRPCLogging(false);
    RetrieveLogs(step_id, &pss->rpc_stats);
  }
  for (size_t i = 0; i < partitions_.size(); ++i) {
    const StepStats& ss = pss->step_stats[i];
    if (ph) {
      for (const auto& ds : ss.dev_stats()) {
        ProcessDeviceStats(ph, ds, false /*is_rpc*/);
      }
    }
  }
  if (ph) {
    for (const auto& ds : pss->rpc_stats.dev_stats()) {
      ProcessDeviceStats(ph, ds, true /*is_rpc*/);
    }
    ph->StepDone(pss->start_micros, pss->end_micros,
                 Microseconds(0) /*cleanup_time*/, 0 /*total_runops*/,
                 OkStatus());
  }
  // Assemble all stats for this timeline into a merged StepStats.
  if (pss->collect_timeline) {
    StepStats step_stats_proto;
    step_stats_proto.Swap(&pss->rpc_stats);
    for (size_t i = 0; i < partitions_.size(); ++i) {
      step_stats_proto.MergeFrom(pss->step_stats[i]);
      pss->step_stats[i].Clear();
    }
    pss->step_stats.clear();
    // Copy the stats back, but only for on-demand profiling to avoid slowing
    // down calls that trigger the automatic profiling.
    if (options.trace_level() == RunOptions::FULL_TRACE) {
      resp->mutable_step_stats()->Swap(&step_stats_proto);
    } else {
      // In the FULL_TRACE case above, the stats are returned to the client
      // through the Session API, so there is no need to publish them again.
      stats_publisher_->PublishStatsProto(step_stats_proto);
    }
  }
}

void MasterSession::ReffedClientGraph::ProcessDeviceStats(
    ProfileHandler* ph, const DeviceStepStats& ds, bool is_rpc) {
  const string& dev_name = ds.device();
  VLOG(1) << "Device " << dev_name << " reports stats for "
          << ds.node_stats_size() << " nodes";
  for (const auto& ns : ds.node_stats()) {
    if (is_rpc) {
      // We don't have access to a good Node pointer, so we rely on
      // sufficient data being present in the NodeExecStats.
      ph->RecordOneOp(dev_name, ns, true /*is_copy*/, "", ns.node_name(),
                      ns.timeline_label());
    } else {
      auto iter = name_to_node_details_.find(ns.node_name());
      const bool found_node_in_graph = iter != name_to_node_details_.end();
      if (!found_node_in_graph && ns.timeline_label().empty()) {
        // The counter incrementing is not thread-safe. But we don't really
        // care.
        // TODO(zhengxq): we should implement a LOG_FIRST_N and LOG_EVERY_N for
        // more general usage.
        static int log_counter = 0;
        if (log_counter < 10) {
          log_counter++;
          LOG(WARNING) << "Failed to find node " << ns.node_name()
                       << " for dev " << dev_name;
        }
        continue;
      }
      const string& optype =
          found_node_in_graph ? iter->second.type_string : ns.node_name();
      string details;
      if (!ns.timeline_label().empty()) {
        details = ns.timeline_label();
      } else if (found_node_in_graph) {
        details = DetailText(iter->second, ns);
      } else {
        // Leave details string empty
      }
      ph->RecordOneOp(dev_name, ns, false /*is_copy*/, ns.node_name(), optype,
                      details);
    }
  }
}

// TODO(suharshs): Merge with CheckFetches in DirectSession.
// TODO(suharsh,mrry): Build a map from fetch target to set of feeds it depends
// on once at setup time to prevent us from computing the dependencies
// every time.
Status MasterSession::ReffedClientGraph::CheckFetches(
    const RunStepRequestWrapper& req, const RunState* run_state,
    GraphExecutionState* execution_state) {
  // Build the set of pending feeds that we haven't seen.
  std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
  for (const auto& input : run_state->pending_inputs) {
    // Skip if already fed.
    if (input.second) continue;
    TensorId id(ParseTensorName(input.first));
    const Node* n = execution_state->get_node_by_name(string(id.first));
    if (n == nullptr) {
      return errors::NotFound("Feed ", input.first, ": not found");
    }
    pending_feeds.insert(id);
  }
  for (size_t i = 0; i < req.num_feeds(); ++i) {
    const TensorId id(ParseTensorName(req.feed_name(i)));
    pending_feeds.erase(id);
  }

  // Initialize the stack with the fetch nodes.
  std::vector<const Node*> stack;
  for (size_t i = 0; i < req.num_fetches(); ++i) {
    const string& fetch = req.fetch_name(i);
    const TensorId id(ParseTensorName(fetch));
    const Node* n = execution_state->get_node_by_name(string(id.first));
    if (n == nullptr) {
      return errors::NotFound("Fetch ", fetch, ": not found");
    }
    stack.push_back(n);
  }

  // Any tensor needed for fetches can't be in pending_feeds.
  // We need to use the original full graph from execution state.
  const Graph* graph = execution_state->full_graph();
  std::vector<bool> visited(graph->num_node_ids(), false);
  while (!stack.empty()) {
    const Node* n = stack.back();
    stack.pop_back();

    for (const Edge* in_edge : n->in_edges()) {
      const Node* in_node = in_edge->src();
      if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
        return errors::InvalidArgument("Fetch ", in_node->name(), ":",
                                       in_edge->src_output(),
                                       " can't be computed from the feeds"
                                       " that have been fed so far.");
      }
      if (!visited[in_node->id()]) {
        visited[in_node->id()] = true;
        stack.push_back(in_node);
      }
    }
  }
  return OkStatus();
}

// Asynchronously deregisters subgraphs on the workers, without waiting for the
// result.
void MasterSession::ReffedClientGraph::DeregisterPartitions() {
  struct Call {
    DeregisterGraphRequest req;
    DeregisterGraphResponse resp;
  };
  for (Part& part : partitions_) {
    // The graph handle may be empty if we failed during partition registration.
    if (!part.graph_handle.empty()) {
      Call* c = new Call;
      c->req.set_session_handle(session_handle_);
      c->req.set_create_worker_session_called(!should_deregister_);
      c->req.set_graph_handle(part.graph_handle);
      // NOTE(mrry): We must capture `worker_cache_` since `this`
      // could be deleted before the callback is called.
      WorkerCacheInterface* worker_cache = worker_cache_;
      const string name = part.name;
      WorkerInterface* w = part.worker;
      CHECK_NOTNULL(w);
      auto cb = [worker_cache, c, name, w](const Status& s) {
        if (!s.ok()) {
          // This error is potentially benign, so we don't log at the
          // error level.
          LOG(INFO) << "DeregisterGraph error: " << s;
        }
        delete c;
        worker_cache->ReleaseWorker(name, w);
      };
      w->DeregisterGraphAsync(&c->req, &c->resp, cb);
    }
  }
}

namespace {
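// Copies the `size` strings produced by `input_accessor` into `output` and
// sorts them, so that logically identical requests yield identical
// CallableOptions regardless of the order in which names were supplied.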
void CopyAndSortStrings(size_t size,
                        const std::function<string(size_t)>& input_accessor,
                        protobuf::RepeatedPtrField<string>* output) {
  for (size_t i = 0; i < size; ++i) {
    output->Add(input_accessor(i));
  }
  std::sort(output->begin(), output->end());
}
}  // namespace

void BuildBuildGraphOptions(const RunStepRequestWrapper& req,
                            const ConfigProto& config,
                            BuildGraphOptions* opts) {
  CallableOptions* callable_opts = &opts->callable_options;
  CopyAndSortStrings(
      req.num_feeds(), [&req](size_t i) { return req.feed_name(i); },
      callable_opts->mutable_feed());
  CopyAndSortStrings(
      req.num_fetches(), [&req](size_t i) { return req.fetch_name(i); },
      callable_opts->mutable_fetch());
  CopyAndSortStrings(
      req.num_targets(), [&req](size_t i) { return req.target_name(i); },
      callable_opts->mutable_target());

  if (!req.options().debug_options().debug_tensor_watch_opts().empty()) {
    *callable_opts->mutable_run_options()->mutable_debug_options() =
        req.options().debug_options();
  }

  opts->collective_graph_key =
      req.options().experimental().collective_graph_key();
  if (config.experimental().collective_deterministic_sequential_execution()) {
    opts->collective_order = GraphCollectiveOrder::kEdges;
  } else if (config.experimental().collective_nccl()) {
    opts->collective_order = GraphCollectiveOrder::kAttrs;
  }
}

void BuildBuildGraphOptions(const PartialRunSetupRequest& req,
                            BuildGraphOptions* opts) {
  CallableOptions* callable_opts = &opts->callable_options;
  CopyAndSortStrings(
      req.feed_size(), [&req](size_t i) { return req.feed(i); },
      callable_opts->mutable_feed());
  CopyAndSortStrings(
      req.fetch_size(), [&req](size_t i) { return req.fetch(i); },
      callable_opts->mutable_fetch());
  CopyAndSortStrings(
      req.target_size(), [&req](size_t i) { return req.target(i); },
      callable_opts->mutable_target());

  // TODO(cais): Add TFDBG support to partial runs.
}

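// Fingerprints the (sorted) feeds, targets, and fetches in `opts`, plus any
// debug tensor watches, so that equivalent Run requests hash to the same
// value; MasterSession uses hashes like this to key its caches of compiled
// graphs (run_graphs_, partial_run_graphs_).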
HashBuildGraphOptions(const BuildGraphOptions & opts)1185 uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
1186   uint64 h = 0x2b992ddfa23249d6ull;
1187   for (const string& name : opts.callable_options.feed()) {
1188     h = Hash64(name.c_str(), name.size(), h);
1189   }
1190   for (const string& name : opts.callable_options.target()) {
1191     h = Hash64(name.c_str(), name.size(), h);
1192   }
1193   for (const string& name : opts.callable_options.fetch()) {
1194     h = Hash64(name.c_str(), name.size(), h);
1195   }
1196 
1197   const DebugOptions& debug_options =
1198       opts.callable_options.run_options().debug_options();
1199   if (!debug_options.debug_tensor_watch_opts().empty()) {
1200     const string watch_summary =
1201         SummarizeDebugTensorWatches(debug_options.debug_tensor_watch_opts());
1202     h = Hash64(watch_summary.c_str(), watch_summary.size(), h);
1203   }
1204 
1205   return h;
1206 }
1207 
string BuildGraphOptionsString(const BuildGraphOptions& opts) {
  string buf;
  for (const string& name : opts.callable_options.feed()) {
    strings::StrAppend(&buf, " FdE: ", name);
  }
  strings::StrAppend(&buf, "\n");
  for (const string& name : opts.callable_options.target()) {
    strings::StrAppend(&buf, " TN: ", name);
  }
  strings::StrAppend(&buf, "\n");
  for (const string& name : opts.callable_options.fetch()) {
    strings::StrAppend(&buf, " FeE: ", name);
  }
  if (opts.collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey) {
    strings::StrAppend(&buf, "\nGK: ", opts.collective_graph_key);
  }
  strings::StrAppend(&buf, "\n");
  return buf;
}

MasterSession::MasterSession(
    const SessionOptions& opt, const MasterEnv* env,
    std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
    std::unique_ptr<WorkerCacheInterface> worker_cache,
    std::unique_ptr<DeviceSet> device_set,
    std::vector<string> filtered_worker_list,
    StatsPublisherFactory stats_publisher_factory)
    : session_opts_(opt),
      env_(env),
      handle_(strings::FpToString(random::New64())),
      remote_devs_(std::move(remote_devs)),
      worker_cache_(std::move(worker_cache)),
      devices_(std::move(device_set)),
      filtered_worker_list_(std::move(filtered_worker_list)),
      stats_publisher_factory_(std::move(stats_publisher_factory)),
      graph_version_(0),
      run_graphs_(5),
      partial_run_graphs_(5) {
  UpdateLastAccessTime();
  CHECK(devices_) << "device_set was null!";

  VLOG(1) << "Session " << handle_ << " #local " << env->local_devices.size()
          << " #remote " << remote_devs_->size();
  VLOG(1) << "Start master session " << handle_
          << " with config: " << session_opts_.config.ShortDebugString();
}

MasterSession::~MasterSession() {
  for (const auto& iter : run_graphs_) iter.second->Unref();
  for (const auto& iter : partial_run_graphs_) iter.second->Unref();
}

void MasterSession::UpdateLastAccessTime() {
  last_access_time_usec_.store(Env::Default()->NowMicros());
}

Status MasterSession::Create(GraphDef&& graph_def,
                             const ClusterDef& cluster_def) {
  if (session_opts_.config.use_per_session_threads() ||
      session_opts_.config.session_inter_op_thread_pool_size() > 0) {
    return errors::InvalidArgument(
        "Distributed session does not support session thread pool options.");
  }
  if (session_opts_.config.graph_options().place_pruned_graph()) {
    // TODO(b/29900832): Fix this or remove the option.
    LOG(WARNING) << "Distributed session does not support the "
                    "place_pruned_graph option.";
    session_opts_.config.mutable_graph_options()->set_place_pruned_graph(false);
  }

  GraphExecutionStateOptions execution_options;
  execution_options.device_set = devices_.get();
  execution_options.session_options = &session_opts_;
  {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
        std::move(graph_def), execution_options, &execution_state_));
  }
  should_delete_worker_sessions_ = true;
  return CreateWorkerSessions(cluster_def);
}

Status MasterSession::CreateWorkerSessions(const ClusterDef& cluster_def) {
  const std::vector<string> worker_names = filtered_worker_list_;
  WorkerCacheInterface* worker_cache = get_worker_cache();

  struct WorkerGroup {
    // The worker name. (Not owned.)
    const string* name;

    // The worker referenced by name. (Not owned.)
    WorkerInterface* worker = nullptr;

    // Request and response messages used for a given worker.
    CreateWorkerSessionRequest request;
    CreateWorkerSessionResponse response;
    Status status = OkStatus();
  };
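  // One CreateWorkerSession RPC is issued per worker below; the
  // BlockingCounter lets this thread wait until every callback has fired,
  // after which the per-worker statuses are merged into a single result.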
  BlockingCounter done(worker_names.size());
  std::vector<WorkerGroup> workers(worker_names.size());

  // Release the workers.
  auto cleanup = gtl::MakeCleanup([&workers, worker_cache] {
    for (auto&& worker_group : workers) {
      if (worker_group.worker != nullptr) {
        worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker);
      }
    }
  });

  string task_name;
  string local_device_name;
  DeviceNameUtils::SplitDeviceName(devices_->client_device()->name(),
                                   &task_name, &local_device_name);
  const int64_t client_device_incarnation =
      devices_->client_device()->attributes().incarnation();

  Status status = OkStatus();
  // Create all the workers & set up the per-worker requests.
  for (size_t i = 0; i < worker_names.size(); ++i) {
    workers[i].name = &worker_names[i];
    workers[i].worker = worker_cache->GetOrCreateWorker(worker_names[i]);
    workers[i].request.set_session_handle(handle_);
    workers[i].request.set_master_task(task_name);
    workers[i].request.set_master_incarnation(client_device_incarnation);
    if (session_opts_.config.share_cluster_devices_in_session() ||
        session_opts_.config.experimental()
            .share_cluster_devices_in_session()) {
      for (const auto& remote_dev : devices_->devices()) {
        *workers[i].request.add_cluster_device_attributes() =
            remote_dev->attributes();
      }

      if (!session_opts_.config.share_cluster_devices_in_session() &&
          session_opts_.config.experimental()
              .share_cluster_devices_in_session()) {
        LOG(WARNING)
            << "ConfigProto.Experimental.share_cluster_devices_in_session has "
               "been promoted to a non-experimental API. Please use "
               "ConfigProto.share_cluster_devices_in_session instead. The "
               "experimental option will be removed in the future.";
      }
    }

    DeviceNameUtils::ParsedName name;
    if (!DeviceNameUtils::ParseFullName(worker_names[i], &name)) {
      status = errors::Internal("Could not parse name ", worker_names[i]);
      LOG(WARNING) << status;
      return status;
    }
    if (!name.has_job || !name.has_task) {
      status = errors::Internal("Incomplete worker name ", worker_names[i]);
      LOG(WARNING) << status;
      return status;
    }

    workers[i].request.mutable_server_def()->set_protocol("grpc");
    workers[i].request.mutable_server_def()->set_job_name(name.job);
    workers[i].request.mutable_server_def()->set_task_index(name.task);
    if (!cluster_def.job().empty()) {
      *workers[i].request.mutable_server_def()->mutable_cluster() = cluster_def;
      // Session state is always isolated when ClusterSpec propagation
      // is in use.
      workers[i].request.set_isolate_session_state(true);
    } else {
      // NOTE(mrry): Do not set any component of the ServerDef,
      // because the worker will use its local configuration.
      workers[i].request.set_isolate_session_state(
          session_opts_.config.isolate_session_state());
    }
    CoordinationServiceConfig coordination_config;
    // For non-local targets, enable the coordination service by default if it
    // is unspecified in the session options.
    if (session_opts_.target != "local" &&
        !session_opts_.config.experimental().has_coordination_config()) {
      coordination_config.set_service_type("standalone");
    } else {
      coordination_config =
          session_opts_.config.experimental().coordination_config();
    }
    // Specify the master task as the coordination service leader.
    coordination_config.set_service_leader(task_name);
    *workers[i]
         .request.mutable_server_def()
         ->mutable_default_session_config()
         ->mutable_experimental()
         ->mutable_coordination_config() = coordination_config;

    if (session_opts_.config.experimental()
            .share_session_state_in_clusterspec_propagation()) {
      // In a dynamic cluster, the ClusterSpec info is usually propagated by
      // master sessions. However, in data parallel training with multiple
      // masters ("between-graph replication"), we need to disable isolation
      // so that different worker sessions can update the same variables in
      // PS tasks.
      workers[i].request.set_isolate_session_state(false);
    }
  }

  for (size_t i = 0; i < worker_names.size(); ++i) {
    auto cb = [i, &workers, &done](const Status& s) {
      workers[i].status = s;
      done.DecrementCount();
    };
    workers[i].worker->CreateWorkerSessionAsync(&workers[i].request,
                                                &workers[i].response, cb);
  }

  done.Wait();
  for (size_t i = 0; i < workers.size(); ++i) {
    status.Update(workers[i].status);
  }
  return status;
}

Status MasterSession::DeleteWorkerSessions() {
  WorkerCacheInterface* worker_cache = get_worker_cache();
  const std::vector<string>& worker_names = filtered_worker_list_;

  struct WorkerGroup {
    // The worker name. (Not owned.)
    const string* name;

    // The worker referenced by name. (Not owned.)
    WorkerInterface* worker = nullptr;

    CallOptions call_opts;

    // Request and response messages used for a given worker.
    DeleteWorkerSessionRequest request;
    DeleteWorkerSessionResponse response;
    Status status = OkStatus();
  };
  BlockingCounter done(worker_names.size());
  std::vector<WorkerGroup> workers(worker_names.size());

  // Release the workers.
  auto cleanup = gtl::MakeCleanup([&workers, worker_cache] {
    for (auto&& worker_group : workers) {
      if (worker_group.worker != nullptr) {
        worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker);
      }
    }
  });

  Status status = OkStatus();
  // Create all the workers & set up the session-deletion requests.
  for (size_t i = 0; i < worker_names.size(); ++i) {
    workers[i].name = &worker_names[i];
    workers[i].worker = worker_cache->GetOrCreateWorker(worker_names[i]);
    workers[i].request.set_session_handle(handle_);
    // Since the worker may have gone away, set a timeout to avoid blocking the
    // session-close operation.
    workers[i].call_opts.SetTimeout(10000);
  }

  for (size_t i = 0; i < worker_names.size(); ++i) {
    auto cb = [i, &workers, &done](const Status& s) {
      workers[i].status = s;
      done.DecrementCount();
    };
    workers[i].worker->DeleteWorkerSessionAsync(
        &workers[i].call_opts, &workers[i].request, &workers[i].response, cb);
  }

  done.Wait();
  for (size_t i = 0; i < workers.size(); ++i) {
    status.Update(workers[i].status);
  }
  return status;
}

Status MasterSession::ListDevices(ListDevicesResponse* resp) const {
  if (worker_cache_) {
    // This is a ClusterSpec-propagated session, and thus env_->local_devices
    // are invalid.

    // Mark the "client_device" as the sole local device.
    const Device* client_device = devices_->client_device();
    for (const Device* dev : devices_->devices()) {
      if (dev != client_device) {
        *(resp->add_remote_device()) = dev->attributes();
      }
    }
    *(resp->add_local_device()) = client_device->attributes();
  } else {
    for (Device* dev : env_->local_devices) {
      *(resp->add_local_device()) = dev->attributes();
    }
    for (auto&& dev : *remote_devs_) {
      *(resp->add_local_device()) = dev->attributes();
    }
  }
  return OkStatus();
}

Status MasterSession::Extend(const ExtendSessionRequest* req,
                             ExtendSessionResponse* resp) {
  UpdateLastAccessTime();
  std::unique_ptr<GraphExecutionState> extended_execution_state;
  {
    mutex_lock l(mu_);
    if (closed_) {
      return errors::FailedPrecondition("Session is closed.");
    }

    if (graph_version_ != req->current_graph_version()) {
      return errors::Aborted("Current version is ", graph_version_,
                             " but caller expected ",
                             req->current_graph_version(), ".");
    }

    CHECK(execution_state_);
    TF_RETURN_IF_ERROR(
        execution_state_->Extend(req->graph_def(), &extended_execution_state));

    CHECK(extended_execution_state);
    // The old execution state will be released outside the lock.
    execution_state_.swap(extended_execution_state);
    ++graph_version_;
    resp->set_new_graph_version(graph_version_);
  }
  return OkStatus();
}

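// Returns the worker cache to use for this session: the session-owned cache
// when one was supplied (the ClusterSpec-propagation case), otherwise the
// process-wide cache owned by the MasterEnv.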
WorkerCacheInterface* MasterSession::get_worker_cache() const {
  if (worker_cache_) {
    return worker_cache_.get();
  }
  return env_->worker_cache;
}

Status MasterSession::StartStep(const BuildGraphOptions& opts, bool is_partial,
                                ReffedClientGraph** out_rcg,
                                int64_t* out_count) {
  const uint64 hash = HashBuildGraphOptions(opts);
  {
    mutex_lock l(mu_);
    // TODO(suharshs): We cache partial run graphs and run graphs separately
    // because there is preprocessing that needs to be run only for partial
    // run calls.
    RCGMap* m = is_partial ? &partial_run_graphs_ : &run_graphs_;
    auto iter = m->find(hash);
    if (iter == m->end()) {
      // We have not seen this subgraph before. Build the subgraph and
      // cache it.
      VLOG(1) << "Unseen hash " << hash << " for "
              << BuildGraphOptionsString(opts) << " is_partial = " << is_partial
              << "\n";
      std::unique_ptr<ClientGraph> client_graph;
      TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
      WorkerCacheInterface* worker_cache = get_worker_cache();
      auto entry = new ReffedClientGraph(
          handle_, opts, std::move(client_graph), session_opts_,
          stats_publisher_factory_, is_partial, worker_cache,
          !should_delete_worker_sessions_);
      iter = m->insert({hash, entry}).first;
      VLOG(1) << "Preparing to execute new graph";
    }
    *out_rcg = iter->second;
    (*out_rcg)->Ref();
    *out_count = (*out_rcg)->get_and_increment_execution_count();
  }
  return OkStatus();
}

void MasterSession::ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
                                   RCGMap* rcg_map) {
  VLOG(1) << "Discarding all reffed graphs";
  for (auto p : *rcg_map) {
    ReffedClientGraph* rcg = p.second;
    if (to_unref) {
      to_unref->push_back(rcg);
    } else {
      rcg->Unref();
    }
  }
  rcg_map->clear();
}

uint64 MasterSession::NewStepId(int64_t graph_key) {
  if (graph_key == BuildGraphOptions::kNoCollectiveGraphKey) {
    // StepId must leave the most-significant 7 bits empty for future use.
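    // The mask below, (((1uLL << 56) - 1) | (1uLL << 56)), equals
    // (1uLL << 57) - 1, i.e. it keeps bits 0..56 and clears bits 57..63.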
    return random::New64() & (((1uLL << 56) - 1) | (1uLL << 56));
  } else {
    uint64 step_id = env_->collective_executor_mgr->NextStepId(graph_key);
    int32_t retry_count = 0;
    while (static_cast<int64_t>(step_id) == CollectiveExecutor::kInvalidId) {
      Notification note;
      Status status;
      env_->collective_executor_mgr->RefreshStepIdSequenceAsync(
          graph_key, [&status, &note](const Status& s) {
            status = s;
            note.Notify();
          });
      note.WaitForNotification();
      if (!status.ok()) {
        LOG(ERROR) << "Bad status from "
                      "collective_executor_mgr->RefreshStepIdSequence: "
                   << status << ".  Retrying.";
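        // Linear backoff: 1s, 2s, 3s, ... capped at 60s between retries.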
        int64_t delay_micros = std::min(60000000LL, 1000000LL * ++retry_count);
        Env::Default()->SleepForMicroseconds(delay_micros);
      } else {
        step_id = env_->collective_executor_mgr->NextStepId(graph_key);
      }
    }
    return step_id;
  }
}

Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req,
                                      PartialRunSetupResponse* resp) {
  std::vector<string> inputs, outputs, targets;
  for (const auto& feed : req->feed()) {
    inputs.push_back(feed);
  }
  for (const auto& fetch : req->fetch()) {
    outputs.push_back(fetch);
  }
  for (const auto& target : req->target()) {
    targets.push_back(target);
  }

  string handle = std::to_string(partial_run_handle_counter_.fetch_add(1));

  ReffedClientGraph* rcg = nullptr;

  // Prepare.
  BuildGraphOptions opts;
  BuildBuildGraphOptions(*req, &opts);
  int64_t count = 0;
  TF_RETURN_IF_ERROR(StartStep(opts, true, &rcg, &count));

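  // This extra reference is handed to the RunState below; it is released in
  // ~RunState() when the partial run completes or is erased.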
  rcg->Ref();
  RunState* run_state =
      new RunState(inputs, outputs, rcg,
                   NewStepId(BuildGraphOptions::kNoCollectiveGraphKey), count);
  {
    mutex_lock l(mu_);
    partial_runs_.emplace(
        std::make_pair(handle, std::unique_ptr<RunState>(run_state)));
  }

  TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));

  resp->set_partial_run_handle(handle);
  return OkStatus();
}

Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req,
                          MutableRunStepResponseWrapper* resp) {
  UpdateLastAccessTime();
  {
    mutex_lock l(mu_);
    if (closed_) {
      return errors::FailedPrecondition("Session is closed.");
    }
    ++num_running_;
    // Note: all code paths must eventually call MarkRunCompletion()
    // in order to appropriately decrement the num_running_ counter.
  }
  Status status;
  if (!req.partial_run_handle().empty()) {
    status = DoPartialRun(opts, req, resp);
  } else {
    status = DoRunWithLocalExecution(opts, req, resp);
  }
  return status;
}

// Decrements num_running_ and broadcasts if num_running_ is zero.
void MasterSession::MarkRunCompletion() {
  mutex_lock l(mu_);
  --num_running_;
  if (num_running_ == 0) {
    num_running_is_zero_.notify_all();
  }
}

Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
  // Registers subgraphs if we haven't done so already.
  PartitionOptions popts;
  popts.node_to_loc = SplitByWorker;
  // The closures popts.{new_name,get_incarnation} are called synchronously in
  // RegisterPartitions() below, so do not need a Ref()/Unref() pair to keep
  // "this" alive during the closure.
  popts.new_name = [this](const string& prefix) {
    mutex_lock l(mu_);
    return strings::StrCat(prefix, "_S", next_node_id_++);
  };
  popts.get_incarnation = [this](const string& name) -> int64 {
    Device* d = devices_->FindDeviceByName(name);
    if (d == nullptr) {
      return PartitionOptions::kIllegalIncarnation;
    } else {
      return d->attributes().incarnation();
    }
  };
  popts.control_flow_added = false;
  const bool enable_bfloat16_sendrecv =
      session_opts_.config.graph_options().enable_bfloat16_sendrecv();
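  // When bfloat16 send/recv is enabled, DT_FLOAT tensors are cast to
  // DT_BFLOAT16 across partition boundaries, presumably to halve transfer
  // size. Control edges carry no data, so DT_FLOAT is returned for them as a
  // placeholder.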
  popts.should_cast = [enable_bfloat16_sendrecv](const Edge* e) {
    if (e->IsControlEdge()) {
      return DT_FLOAT;
    }
    DataType dtype = BaseType(e->src()->output_type(e->src_output()));
    if (enable_bfloat16_sendrecv && dtype == DT_FLOAT) {
      return DT_BFLOAT16;
    } else {
      return dtype;
    }
  };
  if (session_opts_.config.graph_options().enable_recv_scheduling()) {
    popts.scheduling_for_recvs = true;
    popts.need_to_record_start_times = true;
  }

  TF_RETURN_IF_ERROR(rcg->RegisterPartitions(std::move(popts)));

  return OkStatus();
}

Status MasterSession::DoPartialRun(CallOptions* opts,
                                   const RunStepRequestWrapper& req,
                                   MutableRunStepResponseWrapper* resp) {
  auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
  const string& prun_handle = req.partial_run_handle();
  RunState* run_state = nullptr;
  {
    mutex_lock l(mu_);
    auto it = partial_runs_.find(prun_handle);
    if (it == partial_runs_.end()) {
      return errors::InvalidArgument(
          "Must run PartialRunSetup before performing partial runs");
    }
    run_state = it->second.get();
  }
  // CollectiveOps are not supported in partial runs.
  if (req.options().experimental().collective_graph_key() !=
      BuildGraphOptions::kNoCollectiveGraphKey) {
    return errors::InvalidArgument(
        "PartialRun does not support Collective ops.  collective_graph_key "
        "must be kNoCollectiveGraphKey.");
  }

  // If this is the first partial run, initialize the PerStepState.
  if (!run_state->step_started) {
    run_state->step_started = true;
    PerStepState pss;

    const auto count = run_state->count;
    pss.collect_timeline =
        req.options().trace_level() == RunOptions::FULL_TRACE;
    pss.collect_rpcs = req.options().trace_level() == RunOptions::FULL_TRACE;
    pss.report_tensor_allocations_upon_oom =
        req.options().report_tensor_allocations_upon_oom();

    // Build the cost model every 'build_cost_model_every' steps after skipping
    // an initial 'build_cost_model_after' steps.
    const int64_t build_cost_model_after =
        session_opts_.config.graph_options().build_cost_model_after();
    const int64_t build_cost_model_every =
        session_opts_.config.graph_options().build_cost_model();
    pss.collect_costs =
        build_cost_model_every > 0 &&
        ((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
    pss.collect_partition_graphs = req.options().output_partition_graphs();

    std::unique_ptr<ProfileHandler> ph = run_state->rcg->GetProfileHandler(
        run_state->step_id, count, req.options());
    if (ph) {
      pss.collect_timeline = true;
      pss.collect_rpcs = ph->should_collect_rpcs();
    }

    run_state->pss = std::move(pss);
    run_state->ph = std::move(ph);
  }

  // Make sure that this is a new set of feeds that are still pending.
  for (size_t i = 0; i < req.num_feeds(); ++i) {
    const string& feed = req.feed_name(i);
    auto it = run_state->pending_inputs.find(feed);
    if (it == run_state->pending_inputs.end()) {
      return errors::InvalidArgument(
          "The feed ", feed, " was not specified in partial_run_setup.");
    } else if (it->second) {
      return errors::InvalidArgument("The feed ", feed,
                                     " has already been fed.");
    }
  }
  // Check that this is a new set of fetches that are still pending.
  for (size_t i = 0; i < req.num_fetches(); ++i) {
    const string& fetch = req.fetch_name(i);
    auto it = run_state->pending_outputs.find(fetch);
    if (it == run_state->pending_outputs.end()) {
      return errors::InvalidArgument(
          "The fetch ", fetch, " was not specified in partial_run_setup.");
    } else if (it->second) {
      return errors::InvalidArgument("The fetch ", fetch,
                                     " has already been fetched.");
    }
  }

  // Ensure that the requested fetches can be computed from the provided feeds.
  {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(
        run_state->rcg->CheckFetches(req, run_state, execution_state_.get()));
  }

  // Determine if this partial run satisfies all the pending inputs and
  // outputs.
  for (size_t i = 0; i < req.num_feeds(); ++i) {
    auto it = run_state->pending_inputs.find(req.feed_name(i));
    it->second = true;
  }
  for (size_t i = 0; i < req.num_fetches(); ++i) {
    auto it = run_state->pending_outputs.find(req.fetch_name(i));
    it->second = true;
  }
  bool is_last_partial_run = run_state->PendingDone();

  Status s = run_state->rcg->RunPartitions(
      env_, run_state->step_id, run_state->count, &run_state->pss, opts, req,
      resp, &cancellation_manager_, is_last_partial_run);

  // Delete the run state if there is an error or all fetches are done.
  if (!s.ok() || is_last_partial_run) {
    ReffedClientGraph* rcg = run_state->rcg;
    run_state->pss.end_micros = Env::Default()->NowMicros();
    // Schedule post-processing and cleanup to be done asynchronously.
    Ref();
    rcg->Ref();
    rcg->ProcessStats(run_state->step_id, &run_state->pss, run_state->ph.get(),
                      req.options(), resp->mutable_metadata());
    cleanup.release();  // MarkRunCompletion called in done closure.
    rcg->CleanupPartitionsAsync(
        run_state->step_id, [this, rcg, prun_handle](const Status& s) {
          if (!s.ok()) {
            LOG(ERROR) << "Cleanup partition error: " << s;
          }
          rcg->Unref();
          MarkRunCompletion();
          Unref();
        });
    mutex_lock l(mu_);
    partial_runs_.erase(prun_handle);
  }
  return s;
}

Status MasterSession::CreateDebuggerState(
    const DebugOptions& debug_options, const RunStepRequestWrapper& req,
    int64_t rcg_execution_count,
    std::unique_ptr<DebuggerStateInterface>* debugger_state) {
  TF_RETURN_IF_ERROR(
      DebuggerStateRegistry::CreateState(debug_options, debugger_state));

  std::vector<string> input_names;
  for (size_t i = 0; i < req.num_feeds(); ++i) {
    input_names.push_back(req.feed_name(i));
  }
  std::vector<string> output_names;
  for (size_t i = 0; i < req.num_fetches(); ++i) {
    output_names.push_back(req.fetch_name(i));
  }
  std::vector<string> target_names;
  for (size_t i = 0; i < req.num_targets(); ++i) {
    target_names.push_back(req.target_name(i));
  }

  // TODO(cais): We currently use -1 as a dummy value for session run count.
  // While this counter value is straightforward to define and obtain for
  // DirectSessions, it is less so for non-direct Sessions. Devise a better
  // way to get its value when the need arises.
  TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
      debug_options.global_step(), rcg_execution_count, rcg_execution_count,
      input_names, output_names, target_names));

  return OkStatus();
}

void MasterSession::FillPerStepState(MasterSession::ReffedClientGraph* rcg,
                                     const RunOptions& run_options,
                                     uint64 step_id, int64_t count,
                                     PerStepState* out_pss,
                                     std::unique_ptr<ProfileHandler>* out_ph) {
  out_pss->collect_timeline =
      run_options.trace_level() == RunOptions::FULL_TRACE;
  out_pss->collect_rpcs = run_options.trace_level() == RunOptions::FULL_TRACE;
  out_pss->report_tensor_allocations_upon_oom =
      run_options.report_tensor_allocations_upon_oom();
  // Build the cost model every 'build_cost_model_every' steps after skipping
  // an initial 'build_cost_model_after' steps.
  const int64_t build_cost_model_after =
      session_opts_.config.graph_options().build_cost_model_after();
  const int64_t build_cost_model_every =
      session_opts_.config.graph_options().build_cost_model();
  out_pss->collect_costs =
      build_cost_model_every > 0 &&
      ((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
  out_pss->collect_partition_graphs = run_options.output_partition_graphs();

  *out_ph = rcg->GetProfileHandler(step_id, count, run_options);
  if (*out_ph) {
    out_pss->collect_timeline = true;
    out_pss->collect_rpcs = (*out_ph)->should_collect_rpcs();
  }
}

Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg,
                                     uint64 step_id,
                                     const RunOptions& run_options,
                                     PerStepState* pss,
                                     const std::unique_ptr<ProfileHandler>& ph,
                                     const Status& run_status,
                                     RunMetadata* out_run_metadata) {
  Status s = run_status;
  if (s.ok()) {
    pss->end_micros = Env::Default()->NowMicros();
    if (rcg->collective_graph_key() !=
        BuildGraphOptions::kNoCollectiveGraphKey) {
      env_->collective_executor_mgr->RetireStepId(rcg->collective_graph_key(),
                                                  step_id);
    }
    // Schedule post-processing and cleanup to be done asynchronously.
    rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata);
  } else if (errors::IsCancelled(s)) {
    mutex_lock l(mu_);
    if (closed_) {
      if (garbage_collected_) {
        s = errors::Cancelled(
            "Step was cancelled because the session was garbage collected due "
            "to inactivity.");
      } else {
        s = errors::Cancelled(
            "Step was cancelled by an explicit call to `Session::Close()`.");
      }
    }
  }
  Ref();
  rcg->Ref();
  rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) {
    if (!s.ok()) {
      LOG(ERROR) << "Cleanup partition error: " << s;
    }
    rcg->Unref();
    MarkRunCompletion();
    Unref();
  });
  return s;
}

Status MasterSession::DoRunWithLocalExecution(
    CallOptions* opts, const RunStepRequestWrapper& req,
    MutableRunStepResponseWrapper* resp) {
  VLOG(2) << "DoRunWithLocalExecution req: " << req.DebugString();
  PerStepState pss;
  pss.start_micros = Env::Default()->NowMicros();
  auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });

  // Prepare.
  BuildGraphOptions bgopts;
  BuildBuildGraphOptions(req, session_opts_.config, &bgopts);
  ReffedClientGraph* rcg = nullptr;
  int64_t count;
  TF_RETURN_IF_ERROR(StartStep(bgopts, false, &rcg, &count));

  // Unref "rcg" when out of scope.
  core::ScopedUnref unref(rcg);

  std::unique_ptr<DebuggerStateInterface> debugger_state;
  const DebugOptions& debug_options = req.options().debug_options();

  if (!debug_options.debug_tensor_watch_opts().empty()) {
    TF_RETURN_IF_ERROR(
        CreateDebuggerState(debug_options, req, count, &debugger_state));
  }
  TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));

  // The most-significant bits of the step_id are reserved for future use;
  // see NewStepId().
  uint64 step_id = NewStepId(rcg->collective_graph_key());
  TRACEPRINTF("stepid %llu", step_id);

  std::unique_ptr<ProfileHandler> ph;
  FillPerStepState(rcg, req.options(), step_id, count, &pss, &ph);

  if (pss.collect_partition_graphs &&
      session_opts_.config.experimental().disable_output_partition_graphs()) {
    return errors::InvalidArgument(
        "RunOptions.output_partition_graphs() is not supported when "
        "disable_output_partition_graphs is true.");
  }

  Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
                                &cancellation_manager_, false);

  cleanup.release();  // MarkRunCompletion called in PostRunCleanup().
  return PostRunCleanup(rcg, step_id, req.options(), &pss, ph, s,
                        resp->mutable_metadata());
}

Status MasterSession::MakeCallable(const MakeCallableRequest& req,
                                   MakeCallableResponse* resp) {
  UpdateLastAccessTime();

  BuildGraphOptions opts;
  opts.callable_options = req.options();
  opts.use_function_convention = false;

  ReffedClientGraph* callable;

  {
    mutex_lock l(mu_);
    if (closed_) {
      return errors::FailedPrecondition("Session is closed.");
    }
    std::unique_ptr<ClientGraph> client_graph;
    TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
    callable = new ReffedClientGraph(handle_, opts, std::move(client_graph),
                                     session_opts_, stats_publisher_factory_,
                                     false /* is_partial */, get_worker_cache(),
                                     !should_delete_worker_sessions_);
  }

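  // NOTE: Partition registration happens outside the lock; it registers the
  // partitioned subgraphs on the workers, which may block on RPCs.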
  Status s = BuildAndRegisterPartitions(callable);
  if (!s.ok()) {
    callable->Unref();
    return s;
  }

  uint64 handle;
  {
    mutex_lock l(mu_);
    handle = next_callable_handle_++;
    callables_[handle] = callable;
  }

  resp->set_handle(handle);
  return OkStatus();
}

Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
                                    const RunCallableRequest& req,
                                    RunCallableResponse* resp) {
  VLOG(2) << "DoRunCallable req: " << req.DebugString();
  PerStepState pss;
  pss.start_micros = Env::Default()->NowMicros();
  auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });

  // Prepare.
  int64_t count = rcg->get_and_increment_execution_count();

  const uint64 step_id = NewStepId(rcg->collective_graph_key());
  TRACEPRINTF("stepid %llu", step_id);

  const RunOptions& run_options = rcg->callable_options().run_options();

  if (run_options.timeout_in_ms() != 0) {
    opts->SetTimeout(run_options.timeout_in_ms());
  }

  std::unique_ptr<ProfileHandler> ph;
  FillPerStepState(rcg, run_options, step_id, count, &pss, &ph);
  Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
                                &cancellation_manager_);
  cleanup.release();  // MarkRunCompletion called in PostRunCleanup().
  return PostRunCleanup(rcg, step_id, run_options, &pss, ph, s,
                        resp->mutable_metadata());
}

Status MasterSession::RunCallable(CallOptions* opts,
                                  const RunCallableRequest& req,
                                  RunCallableResponse* resp) {
  UpdateLastAccessTime();
  ReffedClientGraph* callable;
  {
    mutex_lock l(mu_);
    if (closed_) {
      return errors::FailedPrecondition("Session is closed.");
    }
    int64_t handle = req.handle();
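    // A handle >= next_callable_handle_ was never issued by MakeCallable();
    // a handle that is missing from callables_ below was issued but has since
    // been released.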
    if (handle >= next_callable_handle_) {
      return errors::InvalidArgument("No such callable handle: ", handle);
    }
    auto iter = callables_.find(req.handle());
    if (iter == callables_.end()) {
      return errors::InvalidArgument(
          "Attempted to run callable after handle was released: ", handle);
    }
    callable = iter->second;
    callable->Ref();
    ++num_running_;
  }
  core::ScopedUnref unref_callable(callable);
  return DoRunCallable(opts, callable, req, resp);
}

Status MasterSession::ReleaseCallable(const ReleaseCallableRequest& req,
                                      ReleaseCallableResponse* resp) {
  UpdateLastAccessTime();
  ReffedClientGraph* to_unref = nullptr;
  {
    mutex_lock l(mu_);
    auto iter = callables_.find(req.handle());
    if (iter != callables_.end()) {
      to_unref = iter->second;
      callables_.erase(iter);
    }
  }
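  // Unref outside the lock: dropping the last reference destroys the
  // ReffedClientGraph, which may do nontrivial cleanup.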
  if (to_unref != nullptr) {
    to_unref->Unref();
  }
  return OkStatus();
}

Status MasterSession::Close() {
  {
    mutex_lock l(mu_);
    closed_ = true;  // All subsequent calls to Run() or Extend() will fail.
  }
  cancellation_manager_.StartCancel();
  std::vector<ReffedClientGraph*> to_unref;
  {
    mutex_lock l(mu_);
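    // Wait for all in-flight steps to observe the cancellation and finish
    // before tearing down the cached graphs.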
    while (num_running_ != 0) {
      num_running_is_zero_.wait(l);
    }
    ClearRunsTable(&to_unref, &run_graphs_);
    ClearRunsTable(&to_unref, &partial_run_graphs_);
    ClearRunsTable(&to_unref, &callables_);
  }
  for (ReffedClientGraph* rcg : to_unref) rcg->Unref();
  if (should_delete_worker_sessions_) {
    Status s = DeleteWorkerSessions();
    if (!s.ok()) {
      LOG(WARNING) << s;
    }
  }
  return OkStatus();
}

void MasterSession::GarbageCollect() {
  {
    mutex_lock l(mu_);
    closed_ = true;
    garbage_collected_ = true;
  }
  cancellation_manager_.StartCancel();
  Unref();
}

MasterSession::RunState::RunState(const std::vector<string>& input_names,
                                  const std::vector<string>& output_names,
                                  ReffedClientGraph* rcg, const uint64 step_id,
                                  const int64_t count)
    : rcg(rcg), step_id(step_id), count(count) {
  // Initially all the feeds and fetches are pending.
  for (auto& name : input_names) {
    pending_inputs[name] = false;
  }
  for (auto& name : output_names) {
    pending_outputs[name] = false;
  }
}

MasterSession::RunState::~RunState() {
  if (rcg) rcg->Unref();
}

bool MasterSession::RunState::PendingDone() const {
  for (const auto& it : pending_inputs) {
    if (!it.second) return false;
  }
  for (const auto& it : pending_outputs) {
    if (!it.second) return false;
  }
  return true;
}

}  // end namespace tensorflow