/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/graph_mgr.h"

#include <chrono>  // NOLINT(build/c++11)
#include <vector>

#include "tensorflow/core/common_runtime/build_graph_options.h"
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/rendezvous_util.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {

GraphMgr::GraphMgr(const WorkerEnv* worker_env, const DeviceMgr* device_mgr)
    : worker_env_(worker_env), device_mgr_(device_mgr), table_(5) {
  // The default value of sync_on_finish will be flipped soon and this
  // environment variable will be removed as well.
  Status status =
      ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
  if (!status.ok()) {
    LOG(ERROR) << status.error_message();
  }
}
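
// For example (an illustrative sketch, not part of this file): until the
// flag flip lands, a process can opt out of synchronizing on step completion
// by exporting the variable before the worker starts:
//
//   setenv("TF_SYNC_ON_FINISH", "false", /*overwrite=*/1);
//   // or, in a shell: $ TF_SYNC_ON_FINISH=false ./worker_binary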

GraphMgr::~GraphMgr() {
  for (const auto& p : table_) p.second->Unref();
}

GraphMgr::Item::~Item() {
  for (const auto& unit : this->units) {
    CHECK_NOTNULL(unit.device);
    if (!graph_mgr->skip_cost_models_) {
      graph_mgr->cost_model_manager_.RemoveCostModelForGraph(unit.graph.get());
    }
    delete unit.root;
    unit.device->op_segment()->RemoveHold(this->session);
  }
}

// NOTE: node->device_name() is not set by GraphConstructor. We expect
// that each NodeDef in the GraphDef given to workers fully specifies a
// device name.
static string SplitByDevice(const Node* node) {
  return node->assigned_device_name();
}
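
// For instance (an illustrative sketch, not taken from this file), a node
// arriving at a worker is expected to carry a fully qualified device:
//
//   node {
//     name: "w"
//     op: "VariableV2"
//     device: "/job:worker/replica:0/task:0/device:CPU:0"
//   }
//
// A node with an empty device field fails the validation below.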

// Validates "gdef" device specifications.
static Status ValidateGraphDefForDevices(const GraphDef& gdef) {
  DeviceNameUtils::ParsedName parsed;
  for (const auto& ndef : gdef.node()) {
    if (!DeviceNameUtils::ParseFullName(ndef.device(), &parsed)) {
      return errors::InvalidArgument("Missing device name in: ",
                                     FormatNodeDefForError(ndef));
    }
  }
  return OkStatus();
}

Status GraphMgr::DecorateAndPublishGraphForDebug(
    const DebugOptions& debug_options, Graph* graph, Device* device) {
  std::unique_ptr<DebugGraphDecoratorInterface> decorator;
  TF_RETURN_IF_ERROR(
      DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
  TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
  TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
  return OkStatus();
}
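
// A minimal sketch of debug options that would trigger this path (assuming
// the standard tfdbg proto fields; illustrative only):
//
//   DebugOptions debug_options;
//   DebugTensorWatch* watch = debug_options.add_debug_tensor_watch_opts();
//   watch->set_node_name("logits");
//   watch->set_output_slot(0);
//   watch->add_debug_ops("DebugIdentity");
//   watch->add_debug_urls("file:///tmp/tfdbg_dump");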

// Creates executors given a graph definition "gdef" of a "session".
// If a node in "gdef" is shared by other graphs in "session", the
// same op kernel is reused. E.g., typically a params node is shared
// by multiple graphs in a session.
//
// If "gdef" is assigned to multiple devices, extra nodes (e.g.,
// send/recv nodes) may be added. The extra nodes' names are generated
// by calling "new_name(old_name)".
//
// On success, "executors" is filled with one executor per device, and
// the caller takes ownership of the returned executors.
Status GraphMgr::InitItem(const string& handle, const GraphDef& gdef,
                          const GraphOptions& graph_options,
                          const DebugOptions& debug_options,
                          const ConfigProto& config_proto,
                          int64_t collective_graph_key, WorkerSession* session,
                          DistributedFunctionLibraryRuntime* cluster_flr,
                          Item* item) {
  item->session = handle;
  item->collective_graph_key = collective_graph_key;
  item->lib_def.reset(
      new FunctionLibraryDefinition(OpRegistry::Global(), gdef.library()));

  TF_RETURN_IF_ERROR(ValidateGraphDefForDevices(gdef));

  // We don't explicitly Validate the graph def because ConvertGraphDefToGraph
  // does that below.
  item->proc_flr.reset(new ProcessFunctionLibraryRuntime(
      device_mgr_, worker_env_->env, /*config=*/&config_proto,
      gdef.versions().producer(), item->lib_def.get(),
      graph_options.optimizer_options(), worker_env_->compute_pool, cluster_flr,
      /*session_metadata=*/nullptr,
      Rendezvous::Factory{
          [this, session](const int64_t step_id, const DeviceMgr*,
                          Rendezvous** r) -> Status {
            auto* remote_r = this->worker_env_->rendezvous_mgr->Find(step_id);
            TF_RETURN_IF_ERROR(remote_r->Initialize(session));
            *r = remote_r;
            return OkStatus();
          },
          [this](const int64_t step_id) {
            this->worker_env_->rendezvous_mgr->Cleanup(step_id);
            return OkStatus();
          }}));

  // Constructs the graph out of "gdef".
  Graph graph(OpRegistry::Global());
  GraphConstructorOptions opts;
  opts.allow_internal_ops = true;
  opts.expect_device_spec = true;
  opts.validate_nodes = true;
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, gdef, &graph));

  // Splits "graph" into multiple subgraphs by device names.
  std::unordered_map<string, GraphDef> partitions;
  PartitionOptions popts;
  popts.node_to_loc = SplitByDevice;
  popts.new_name = [this](const string& prefix) {
    mutex_lock l(mu_);
    return strings::StrCat(prefix, "_G", next_id_++);
  };
  popts.get_incarnation = [this](const string& name) -> int64 {
    Device* device = nullptr;
    Status s = device_mgr_->LookupDevice(name, &device);
    if (s.ok()) {
      return device->attributes().incarnation();
    } else {
      return PartitionOptions::kIllegalIncarnation;
    }
  };
  popts.flib_def = item->lib_def.get();
  popts.control_flow_added = true;
  popts.scheduling_for_recvs = graph_options.enable_recv_scheduling();
  TF_RETURN_IF_ERROR(Partition(popts, &graph, &partitions));
  if (popts.scheduling_for_recvs) {
    TF_RETURN_IF_ERROR(AddControlEdges(popts, &partitions));
  }
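
  // To illustrate (a hypothetical example): a graph with an edge a -> b,
  // where "a" is assigned to CPU:0 and "b" to GPU:0, partitions into two
  // GraphDefs. The CPU partition gains a _Send node and the GPU partition a
  // matching _Recv node, with names minted by new_name above (e.g. "a_G0").
  // The incarnation numbers returned by get_incarnation let a receiver
  // detect that the sending device restarted mid-step.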

  std::unordered_map<string, std::unique_ptr<Graph>> partition_graphs;
  for (auto& partition : partitions) {
    std::unique_ptr<Graph> device_graph(new Graph(OpRegistry::Global()));
    GraphConstructorOptions device_opts;
    // There are internal operations (e.g., send/recv) that we now allow.
    device_opts.allow_internal_ops = true;
    device_opts.expect_device_spec = true;
    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
        device_opts, std::move(partition.second), device_graph.get()));
    partition_graphs.emplace(partition.first, std::move(device_graph));
  }

  GraphOptimizationPassOptions optimization_options;
  optimization_options.flib_def = item->lib_def.get();
  optimization_options.partition_graphs = &partition_graphs;
  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_PARTITIONING, optimization_options));

  LocalExecutorParams params;

  item->units.reserve(partitions.size());
  item->graph_mgr = this;
  const auto& optimizer_opts = graph_options.optimizer_options();
  GraphOptimizer optimizer(optimizer_opts);
  for (auto& p : partition_graphs) {
    const string& device_name = p.first;
    std::unique_ptr<Graph>& subgraph = p.second;
    item->units.resize(item->units.size() + 1);
    ExecutionUnit* unit = &(item->units.back());

    // Find the device.
    Status s = device_mgr_->LookupDevice(device_name, &unit->device);
    if (!s.ok()) {
      // Remove the empty unit from the item as the item destructor wants all
      // units to have valid devices.
      item->units.pop_back();
      return s;
    }

    // Give the device an opportunity to rewrite its subgraph.
    TF_RETURN_IF_ERROR(unit->device->MaybeRewriteGraph(&subgraph));

    // Top-level nodes in the graph use the op segment to cache
    // kernels. Therefore, as long as the executor is alive, we need
    // to ensure the kernels cached for the session are alive.
    auto opseg = unit->device->op_segment();
    opseg->AddHold(handle);

    // Function library runtime.
    FunctionLibraryRuntime* lib = item->proc_flr->GetFLR(unit->device->name());
    if (lib == nullptr) {
      return errors::InvalidArgument("Cannot find FLR for device: ",
                                     unit->device->name());
    }

    // Construct the root executor for the subgraph.
    params.device = unit->device;
    params.function_library = lib;
    params.create_kernel =
        [handle, lib, opseg](const std::shared_ptr<const NodeProperties>& props,
                             OpKernel** kernel) {
          // NOTE(mrry): We must not share function kernels (implemented
          // using `CallOp`) between subgraphs, because `CallOp::handle_`
          // is tied to a particular subgraph. Even if the function itself
          // is stateful, the `CallOp` that invokes it is not.
          if (!OpSegment::ShouldOwnKernel(lib, props->node_def.op())) {
            return lib->CreateKernel(props, kernel);
          }
          auto create_fn = [lib, &props](OpKernel** kernel) {
            return lib->CreateKernel(props, kernel);
          };
          // Kernels created for subgraph nodes need to be cached.  On
          // cache miss, create_fn() is invoked to create a kernel based
          // on the function library here + global op registry.
          return opseg->FindOrCreate(handle, props->node_def.name(), kernel,
                                     create_fn);
        };
    params.delete_kernel = [lib](OpKernel* kernel) {
      if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string())) {
        delete kernel;
      }
    };

    optimizer.Optimize(lib, worker_env_->env, params.device, &subgraph,
                       GraphOptimizer::Options());

    // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
    if (!debug_options.debug_tensor_watch_opts().empty()) {
      TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
          debug_options, subgraph.get(), params.device));
    }

    TF_RETURN_IF_ERROR(
        EnsureMemoryTypes(DeviceType(unit->device->device_type()),
                          unit->device->name(), subgraph.get()));
    unit->graph = std::move(subgraph);
    unit->build_cost_model = graph_options.build_cost_model();
    if (unit->build_cost_model > 0) {
      skip_cost_models_ = false;
    }
    TF_RETURN_IF_ERROR(NewLocalExecutor(params, *unit->graph, &unit->root));
  }
  return OkStatus();
}
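
// To make the kernel-sharing contract above concrete (an illustrative sketch
// using the same OpSegment API this file already uses; the names "sess" and
// "create_fn" are hypothetical): two graphs registered under one session
// handle that contain the same top-level stateful node share one kernel.
//
//   OpSegment* seg = device->op_segment();
//   seg->AddHold("sess");  // one hold per registered graph
//   OpKernel* kernel = nullptr;
//   TF_CHECK_OK(seg->FindOrCreate("sess", "params", &kernel, create_fn));
//   // A later FindOrCreate("sess", "params", ...) returns the cached kernel
//   // without invoking create_fn again; the cache is dropped once every
//   // hold is released via RemoveHold("sess").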

Status GraphMgr::Register(const string& handle, const GraphDef& gdef,
                          const GraphOptions& graph_options,
                          const DebugOptions& debug_options,
                          const ConfigProto& config_proto,
                          int64_t collective_graph_key, WorkerSession* session,
                          DistributedFunctionLibraryRuntime* cluster_flr,
                          string* graph_handle) {
  Item* item = new Item;
  Status s = InitItem(handle, gdef, graph_options, debug_options, config_proto,
                      collective_graph_key, session, cluster_flr, item);
  if (!s.ok()) {
    item->Unref();
    return s;
  }

  // Inserts one item into table_.
  {
    mutex_lock l(mu_);
    *graph_handle =
        strings::Printf("%016llx", static_cast<long long>(++next_id_));
    item->handle = *graph_handle;
    CHECK(table_.insert({*graph_handle, item}).second);
  }
  return OkStatus();
}
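
// A minimal usage sketch (hypothetical caller; "graph_mgr", "session_handle",
// "gdef", and "session" are assumed to come from the surrounding worker):
//
//   string graph_handle;
//   TF_RETURN_IF_ERROR(graph_mgr->Register(
//       session_handle, gdef, GraphOptions(), DebugOptions(), ConfigProto(),
//       BuildGraphOptions::kNoCollectiveGraphKey, session,
//       /*cluster_flr=*/nullptr, &graph_handle));
//   // graph_handle (e.g. "0000000000000001") names this registration in
//   // later ExecuteAsync/Deregister calls.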

Status GraphMgr::Deregister(const string& handle) {
  Item* item = nullptr;
  // Removes one item from table_.
  {
    mutex_lock l(mu_);
    auto iter = table_.find(handle);
    if (iter == table_.end()) {
      return errors::Aborted("Graph handle is not found: ", handle,
                             ". Possibly, this worker just restarted.");
    }
    item = iter->second;
    table_.erase(iter);
  }
  item->Unref();
  return OkStatus();
}

Status GraphMgr::DeregisterAll() {
  std::vector<Item*> items;
  // Removes all items from table_.
  {
    mutex_lock l(mu_);
    for (const auto& entry : table_) {
      items.push_back(entry.second);
    }
    table_.clear();
  }
  for (auto item : items) {
    item->Unref();
  }
  return OkStatus();
}

Status GraphMgr::SendInputs(const int64_t step_id, const NamedTensors& in) {
  Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  std::vector<string> keys;
  std::vector<Tensor> tensors_to_send;
  keys.reserve(in.size());
  tensors_to_send.reserve(in.size());
  size_t input_size = 0;
  for (const auto& p : in) {
    keys.push_back(p.first);
    tensors_to_send.push_back(p.second);
    input_size += p.second.AllocatedBytes();
  }
  metrics::RecordGraphInputTensors(input_size);
  Status s =
      SendTensorsToRendezvous(rendezvous, nullptr, {}, keys, tensors_to_send);
  rendezvous->Unref();
  return s;
}
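
// Illustrative caller-side sketch (hypothetical; the map keys are whatever
// rendezvous keys the master put in the RunGraph request, typically built
// with Rendezvous::CreateKey):
//
//   GraphMgr::NamedTensors in;
//   in[Rendezvous::CreateKey("/job:master/replica:0/task:0/device:CPU:0",
//                            /*src_incarnation=*/1,
//                            "/job:worker/replica:0/task:0/device:CPU:0",
//                            "x", FrameAndIter(0, 0))] = x_tensor;
//   TF_RETURN_IF_ERROR(graph_mgr->SendInputs(step_id, in));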

Status GraphMgr::RecvOutputs(const int64_t step_id, NamedTensors* out) {
  Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  Status s = RecvOutputsFromRendezvous(rendezvous, out, Rendezvous::Args());
  rendezvous->Unref();
  if (!s.ok()) {
    // Failing to fetch the outputs should not be possible, so rewrite the
    // error status to an INTERNAL error.
    s = errors::Internal("Failed to fetch outputs for step ", step_id,
                         ". (Original error message: ", s.error_message(), ")");
  }
  size_t output_size = 0;
  for (auto& p : *out) {
    output_size += p.second.AllocatedBytes();
  }
  metrics::RecordGraphOutputTensors(output_size);
  return s;
}
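
// The receiving counterpart of the sketch above (hypothetical): the caller
// seeds the map with the rendezvous keys it expects and the tensor values
// are filled in.
//
//   GraphMgr::NamedTensors out;
//   out[fetch_key] = Tensor();  // fetch_key as minted by the master
//   TF_RETURN_IF_ERROR(graph_mgr->RecvOutputs(step_id, &out));
//   const Tensor& fetched = out[fetch_key];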

void GraphMgr::RecvOutputsAsync(const int64_t step_id, NamedTensors* out,
                                StatusCallback done) {
  Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  std::vector<string> keys;
  std::vector<Tensor>* received_keys = new std::vector<Tensor>;
  keys.reserve(out->size());
  received_keys->reserve(out->size());
  for (const auto& p : *out) {
    keys.push_back(p.first);
    received_keys->push_back(p.second);
  }
  RecvOutputsFromRendezvousAsync(
      rendezvous, nullptr, {}, keys, received_keys,
      [done, rendezvous, received_keys, out, keys](const Status s) {
        rendezvous->Unref();
        size_t output_size = 0;
        for (int i = 0, end = keys.size(); i < end; ++i) {
          (*out)[keys[i]] = (*received_keys)[i];
          output_size += (*out)[keys[i]].AllocatedBytes();
        }
        metrics::RecordGraphOutputTensors(output_size);
        delete received_keys;
        done(s);
      });
}

void GraphMgr::ExecuteAsync(
    const string& handle, const int64_t step_id, const ExecutorOpts& opts,
    const NamedTensors& in, WorkerSession* session,
    StepStatsCollector* collector, MutableRunGraphResponseWrapper* response,
    CancellationManager* cancellation_manager,
    CoordinationServiceAgent* coordination_service_agent, StatusCallback done) {
  const uint64 start_time_usecs = Env::Default()->NowMicros();
  profiler::TraceMeProducer activity(
      // To TraceMeConsumers in ExecutorState::Process/Finish or RunGraphDone.
      [step_id] {
        return profiler::TraceMeEncode(
            "RunGraph", {{"id", step_id}, {"_r", 1} /*root_event*/});
      },
      profiler::ContextType::kTfExecutor, step_id,
      profiler::TraceMeLevel::kInfo);
  // Lookup an item. Holds one ref while executing.
  Item* item = nullptr;
  {
    mutex_lock l(mu_);
    auto iter = table_.find(handle);
    if (iter != table_.end()) {
      item = iter->second;
      item->Ref();
    }
  }

  if (item == nullptr) {
    done(errors::Aborted("Graph handle is not found: ", handle));
    return;
  }

  CostGraphDef* cost_graph = nullptr;
  if (response != nullptr) {
    cost_graph = response->mutable_cost_graph();
    if (opts.record_partition_graphs()) {
      for (const ExecutionUnit& unit : item->units) {
        GraphDef graph_def;
        unit.graph->ToGraphDef(&graph_def);
        response->AddPartitionGraph(graph_def);
      }
    }
  }

  RemoteRendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  Status s = rendezvous->Initialize(session);
  CollectiveExecutor::Handle* ce_handle =
      item->collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey
          ? new CollectiveExecutor::Handle(
                worker_env_->collective_executor_mgr->FindOrCreate(step_id),
                true)
          : nullptr;
  // Sends values specified by the caller.
  size_t input_size = 0;
  if (s.ok()) {
    std::vector<string> keys;
    std::vector<Tensor> tensors_to_send;
    keys.reserve(in.size());
    tensors_to_send.reserve(in.size());
    for (auto& p : in) {
      keys.push_back(p.first);
      tensors_to_send.push_back(p.second);
      input_size += p.second.AllocatedBytes();
    }
    s = SendTensorsToRendezvous(rendezvous, nullptr, {}, keys, tensors_to_send);
  }

  if (!s.ok()) {
    done(s);
    delete ce_handle;
    item->Unref();
    rendezvous->Unref();
    return;
  }

  StartParallelExecutors(
      handle, step_id, item, rendezvous, ce_handle, collector, cost_graph,
      cancellation_manager, session, start_time_usecs,
      coordination_service_agent,
      [item, rendezvous, ce_handle, done, start_time_usecs, input_size,
       step_id](const Status& s) {
        profiler::TraceMeConsumer activity(
            // From TraceMeProducer in GraphMgr::ExecuteAsync.
            [step_id] {
              return profiler::TraceMeEncode("RunGraphDone", {{"id", step_id}});
            },
            profiler::ContextType::kTfExecutor, step_id,
            profiler::TraceMeLevel::kInfo);
        done(s);
        metrics::RecordGraphInputTensors(input_size);
        metrics::UpdateGraphExecTime(Env::Default()->NowMicros() -
                                     start_time_usecs);
        rendezvous->Unref();
        item->Unref();
        delete ce_handle;
      });
}
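
// End-to-end usage sketch (hypothetical caller; error handling elided):
//
//   Notification n;
//   Status run_status;
//   graph_mgr->ExecuteAsync(graph_handle, step_id, ExecutorOpts(), in,
//                           session, /*collector=*/nullptr,
//                           /*response=*/nullptr,
//                           /*cancellation_manager=*/nullptr,
//                           /*coordination_service_agent=*/nullptr,
//                           [&](const Status& s) {
//                             run_status = s;
//                             n.Notify();
//                           });
//   n.WaitForNotification();
//   GraphMgr::NamedTensors out;  // seeded with the fetch keys
//   if (run_status.ok()) run_status = graph_mgr->RecvOutputs(step_id, &out);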

void GraphMgr::StartParallelExecutors(
    const string& handle, int64_t step_id, Item* item, Rendezvous* rendezvous,
    CollectiveExecutor::Handle* ce_handle, StepStatsCollector* collector,
    CostGraphDef* cost_graph, CancellationManager* cancellation_manager,
    WorkerSession* session, int64_t start_time_usecs,
    CoordinationServiceAgent* coordination_service_agent, StatusCallback done) {
  const int num_units = item->units.size();
  CHECK_GE(num_units, 1);
  ScopedStepContainer* step_container = new ScopedStepContainer(
      step_id,
      [this](const string& name) { device_mgr_->ClearContainers({name}); });
  // NOTE: Transfer one ref of rendezvous and item.
  ExecutorBarrier* barrier =
      new ExecutorBarrier(num_units, rendezvous,
                          [this, item, collector, cost_graph, step_container,
                           done](const Status& s) {
                            BuildCostModel(item, collector, cost_graph);
                            done(s);
                            delete step_container;
                          });
  Executor::Args args;
  args.step_id = step_id;
  args.rendezvous = rendezvous;
  args.collective_executor = ce_handle ? ce_handle->get() : nullptr;
  args.cancellation_manager = cancellation_manager;
  args.stats_collector = collector;
  args.step_container = step_container;
  args.sync_on_finish = sync_on_finish_;
  args.start_time_usecs = start_time_usecs;
  args.coordination_service_agent = coordination_service_agent;

  if (LogMemory::IsEnabled()) {
    LogMemory::RecordStep(args.step_id, handle);
  }
  thread::ThreadPool* pool = worker_env_->compute_pool;
  using std::placeholders::_1;
  // Line below is equivalent to this code, but does one less indirect call:
  //  args.runner = [pool](std::function<void()> fn) { pool->Schedule(fn); };
  auto default_runner = std::bind(&thread::ThreadPool::Schedule, pool, _1);
  for (const auto& unit : item->units) {
    // TODO(zhengxq): if the device picks its own threadpool, we need to assign
    //     fewer threads to the main compute pool by default.
    thread::ThreadPool* device_thread_pool =
        unit.device->tensorflow_device_thread_pool();
    if (!device_thread_pool) {
      args.runner = default_runner;
    } else {
      args.runner =
          std::bind(&thread::ThreadPool::Schedule, device_thread_pool, _1);
    }
    unit.root->RunAsync(args, barrier->Get());
  }
}

void GraphMgr::BuildCostModel(Item* item, StepStatsCollector* collector,
                              CostGraphDef* cost_graph) {
  if (collector && !skip_cost_models_) {
    // Build the cost model.
    std::unordered_map<string, const Graph*> device_to_graph;
    for (const auto& unit : item->units) {
      if (unit.build_cost_model > 0) {
        device_to_graph[unit.device->name()] = unit.graph.get();
      }
    }
    collector->BuildCostModel(&cost_model_manager_, device_to_graph);

    if (cost_graph != nullptr) {
      for (const auto& unit : item->units) {
        cost_model_manager_.AddToCostGraphDef(unit.graph.get(), cost_graph)
            .IgnoreError();
      }
    }
  }
}

}  // end namespace tensorflow