xref: /aosp_15_r20/external/tensorflow/tensorflow/core/distributed_runtime/message_wrappers.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/message_wrappers.h"
17 
18 #include "tensorflow/core/framework/cost_graph.pb.h"
19 #include "tensorflow/core/framework/step_stats.pb.h"
20 #include "tensorflow/core/framework/tensor.pb.h"
21 #include "tensorflow/core/protobuf/config.pb.h"
22 #include "tensorflow/core/protobuf/named_tensor.pb.h"
23 
24 namespace tensorflow {
25 
ParseTensorProtoToTensor(const TensorProto & tensor_proto,Tensor * out_tensor)26 bool ParseTensorProtoToTensor(const TensorProto& tensor_proto,
27                               Tensor* out_tensor) {
28   if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) {
29     Tensor parsed(tensor_proto.dtype());
30     if (parsed.FromProto(cpu_allocator(), tensor_proto)) {
31       *out_tensor = parsed;
32       return true;
33     }
34   }
35   return false;
36 }
37 
session_handle() const38 const string& InMemoryRunStepRequest::session_handle() const {
39   return session_handle_;
40 }
41 
set_session_handle(const string & handle)42 void InMemoryRunStepRequest::set_session_handle(const string& handle) {
43   session_handle_ = handle;
44 }
45 
partial_run_handle() const46 const string& InMemoryRunStepRequest::partial_run_handle() const {
47   return partial_run_handle_;
48 }
49 
set_partial_run_handle(const string & handle)50 void InMemoryRunStepRequest::set_partial_run_handle(const string& handle) {
51   partial_run_handle_ = handle;
52 }
53 
num_feeds() const54 size_t InMemoryRunStepRequest::num_feeds() const { return feeds_.size(); }
feed_name(size_t i) const55 const string& InMemoryRunStepRequest::feed_name(size_t i) const {
56   return feeds_[i].first;
57 }
58 
FeedValue(size_t i,Tensor * out_tensor) const59 Status InMemoryRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const {
60   *out_tensor = feeds_[i].second;
61   return OkStatus();
62 }
63 
FeedValue(size_t i,TensorProto * out_tensor) const64 Status InMemoryRunStepRequest::FeedValue(size_t i,
65                                          TensorProto* out_tensor) const {
66   feeds_[i].second.AsProtoTensorContent(out_tensor);
67   return OkStatus();
68 }
69 
add_feed(const string & name,const Tensor & value)70 void InMemoryRunStepRequest::add_feed(const string& name, const Tensor& value) {
71   feeds_.emplace_back(name, value);
72 }
73 
num_fetches() const74 size_t InMemoryRunStepRequest::num_fetches() const { return fetches_.size(); }
fetch_name(size_t i) const75 const string& InMemoryRunStepRequest::fetch_name(size_t i) const {
76   return fetches_[i];
77 }
add_fetch(const string & name)78 void InMemoryRunStepRequest::add_fetch(const string& name) {
79   fetches_.push_back(name);
80 }
81 
num_targets() const82 size_t InMemoryRunStepRequest::num_targets() const { return targets_.size(); }
target_name(size_t i) const83 const string& InMemoryRunStepRequest::target_name(size_t i) const {
84   return targets_[i];
85 }
add_target(const string & name)86 void InMemoryRunStepRequest::add_target(const string& name) {
87   targets_.push_back(name);
88 }
89 
options() const90 const RunOptions& InMemoryRunStepRequest::options() const { return options_; }
91 
mutable_options()92 RunOptions* InMemoryRunStepRequest::mutable_options() { return &options_; }
93 
store_errors_in_response_body() const94 bool InMemoryRunStepRequest::store_errors_in_response_body() const {
95   return store_errors_in_response_body_;
96 }
97 
request_id() const98 int64_t InMemoryRunStepRequest::request_id() const {
99   return 0;  // no need to track request id for local version.
100 }
101 
set_store_errors_in_response_body(bool store_errors)102 void InMemoryRunStepRequest::set_store_errors_in_response_body(
103     bool store_errors) {
104   store_errors_in_response_body_ = store_errors;
105 }
106 
DebugString() const107 string InMemoryRunStepRequest::DebugString() const {
108   return ToProto().DebugString();
109 }
110 
ToProto() const111 const RunStepRequest& InMemoryRunStepRequest::ToProto() const {
112   if (!proto_version_) {
113     proto_version_.reset(new RunStepRequest);
114     proto_version_->set_session_handle(session_handle());
115     proto_version_->set_partial_run_handle(partial_run_handle());
116     for (size_t i = 0; i < num_feeds(); ++i) {
117       auto feed = proto_version_->add_feed();
118       feed->set_name(feed_name(i));
119       feeds_[i].second.AsProtoTensorContent(feed->mutable_tensor());
120     }
121     for (size_t i = 0; i < num_fetches(); ++i) {
122       proto_version_->add_fetch(fetch_name(i));
123     }
124     for (size_t i = 0; i < num_targets(); ++i) {
125       proto_version_->add_target(target_name(i));
126     }
127     *proto_version_->mutable_options() = options();
128   }
129   return *proto_version_;
130 }
131 
session_handle() const132 const string& MutableProtoRunStepRequest::session_handle() const {
133   return request_.session_handle();
134 }
set_session_handle(const string & handle)135 void MutableProtoRunStepRequest::set_session_handle(const string& handle) {
136   request_.set_session_handle(handle);
137 }
138 
partial_run_handle() const139 const string& MutableProtoRunStepRequest::partial_run_handle() const {
140   return request_.partial_run_handle();
141 }
set_partial_run_handle(const string & handle)142 void MutableProtoRunStepRequest::set_partial_run_handle(const string& handle) {
143   request_.set_partial_run_handle(handle);
144 }
145 
num_feeds() const146 size_t MutableProtoRunStepRequest::num_feeds() const {
147   return request_.feed_size();
148 }
feed_name(size_t i) const149 const string& MutableProtoRunStepRequest::feed_name(size_t i) const {
150   return request_.feed(i).name();
151 }
FeedValue(size_t i,Tensor * out_tensor) const152 Status MutableProtoRunStepRequest::FeedValue(size_t i,
153                                              Tensor* out_tensor) const {
154   if (!ParseTensorProtoToTensor(request_.feed(i).tensor(), out_tensor)) {
155     return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
156   } else {
157     return OkStatus();
158   }
159 }
160 
FeedValue(size_t i,TensorProto * out_tensor) const161 Status MutableProtoRunStepRequest::FeedValue(size_t i,
162                                              TensorProto* out_tensor) const {
163   *out_tensor = request_.feed(i).tensor();
164   return OkStatus();
165 }
166 
add_feed(const string & name,const Tensor & value)167 void MutableProtoRunStepRequest::add_feed(const string& name,
168                                           const Tensor& value) {
169   NamedTensorProto* feed = request_.add_feed();
170   feed->set_name(name);
171   TensorProto* value_proto = feed->mutable_tensor();
172   value.AsProtoTensorContent(value_proto);
173 }
174 
num_fetches() const175 size_t MutableProtoRunStepRequest::num_fetches() const {
176   return request_.fetch_size();
177 }
178 
fetch_name(size_t i) const179 const string& MutableProtoRunStepRequest::fetch_name(size_t i) const {
180   return request_.fetch(i);
181 }
add_fetch(const string & name)182 void MutableProtoRunStepRequest::add_fetch(const string& name) {
183   request_.add_fetch(name);
184 }
185 
num_targets() const186 size_t MutableProtoRunStepRequest::num_targets() const {
187   return request_.target_size();
188 }
189 
target_name(size_t i) const190 const string& MutableProtoRunStepRequest::target_name(size_t i) const {
191   return request_.target(i);
192 }
193 
add_target(const string & name)194 void MutableProtoRunStepRequest::add_target(const string& name) {
195   request_.add_target(name);
196 }
197 
options() const198 const RunOptions& MutableProtoRunStepRequest::options() const {
199   return request_.options();
200 }
201 
mutable_options()202 RunOptions* MutableProtoRunStepRequest::mutable_options() {
203   return request_.mutable_options();
204 }
205 
store_errors_in_response_body() const206 bool MutableProtoRunStepRequest::store_errors_in_response_body() const {
207   return request_.store_errors_in_response_body();
208 }
209 
set_store_errors_in_response_body(bool store_errors)210 void MutableProtoRunStepRequest::set_store_errors_in_response_body(
211     bool store_errors) {
212   request_.set_store_errors_in_response_body(store_errors);
213 }
214 
request_id() const215 int64_t MutableProtoRunStepRequest::request_id() const {
216   return request_.request_id();
217 }
218 
DebugString() const219 string MutableProtoRunStepRequest::DebugString() const {
220   return request_.DebugString();
221 }
222 
ToProto() const223 const RunStepRequest& MutableProtoRunStepRequest::ToProto() const {
224   return request_;
225 }
226 
ProtoRunStepRequest(const RunStepRequest * request)227 ProtoRunStepRequest::ProtoRunStepRequest(const RunStepRequest* request)
228     : request_(request) {}
229 
session_handle() const230 const string& ProtoRunStepRequest::session_handle() const {
231   return request_->session_handle();
232 }
233 
partial_run_handle() const234 const string& ProtoRunStepRequest::partial_run_handle() const {
235   return request_->partial_run_handle();
236 }
237 
num_feeds() const238 size_t ProtoRunStepRequest::num_feeds() const { return request_->feed_size(); }
239 
feed_name(size_t i) const240 const string& ProtoRunStepRequest::feed_name(size_t i) const {
241   return request_->feed(i).name();
242 }
243 
FeedValue(size_t i,Tensor * out_tensor) const244 Status ProtoRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const {
245   if (!ParseTensorProtoToTensor(request_->feed(i).tensor(), out_tensor)) {
246     return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
247   } else {
248     return OkStatus();
249   }
250 }
251 
FeedValue(size_t i,TensorProto * out_tensor) const252 Status ProtoRunStepRequest::FeedValue(size_t i, TensorProto* out_tensor) const {
253   *out_tensor = request_->feed(i).tensor();
254   return OkStatus();
255 }
256 
num_fetches() const257 size_t ProtoRunStepRequest::num_fetches() const {
258   return request_->fetch_size();
259 }
260 
fetch_name(size_t i) const261 const string& ProtoRunStepRequest::fetch_name(size_t i) const {
262   return request_->fetch(i);
263 }
264 
num_targets() const265 size_t ProtoRunStepRequest::num_targets() const {
266   return request_->target_size();
267 }
268 
target_name(size_t i) const269 const string& ProtoRunStepRequest::target_name(size_t i) const {
270   return request_->target(i);
271 }
272 
options() const273 const RunOptions& ProtoRunStepRequest::options() const {
274   return request_->options();
275 }
276 
store_errors_in_response_body() const277 bool ProtoRunStepRequest::store_errors_in_response_body() const {
278   return request_->store_errors_in_response_body();
279 }
280 
request_id() const281 int64_t ProtoRunStepRequest::request_id() const {
282   return request_->request_id();
283 }
284 
DebugString() const285 string ProtoRunStepRequest::DebugString() const {
286   return request_->DebugString();
287 }
288 
ToProto() const289 const RunStepRequest& ProtoRunStepRequest::ToProto() const { return *request_; }
290 
session_handle() const291 const string& InMemoryRunGraphRequest::session_handle() const {
292   return session_handle_;
293 }
294 
create_worker_session_called() const295 bool InMemoryRunGraphRequest::create_worker_session_called() const {
296   return create_worker_session_called_;
297 }
298 
set_session_handle(const string & handle)299 void InMemoryRunGraphRequest::set_session_handle(const string& handle) {
300   session_handle_ = handle;
301 }
302 
set_create_worker_session_called(bool called)303 void InMemoryRunGraphRequest::set_create_worker_session_called(bool called) {
304   create_worker_session_called_ = called;
305 }
306 
graph_handle() const307 const string& InMemoryRunGraphRequest::graph_handle() const {
308   return graph_handle_;
309 }
310 
set_graph_handle(const string & handle)311 void InMemoryRunGraphRequest::set_graph_handle(const string& handle) {
312   graph_handle_ = handle;
313 }
314 
step_id() const315 int64_t InMemoryRunGraphRequest::step_id() const { return step_id_; }
316 
set_step_id(int64_t step_id)317 void InMemoryRunGraphRequest::set_step_id(int64_t step_id) {
318   step_id_ = step_id;
319 }
320 
exec_opts() const321 const ExecutorOpts& InMemoryRunGraphRequest::exec_opts() const {
322   return exec_opts_;
323 }
324 
mutable_exec_opts()325 ExecutorOpts* InMemoryRunGraphRequest::mutable_exec_opts() {
326   return &exec_opts_;
327 }
328 
num_sends() const329 size_t InMemoryRunGraphRequest::num_sends() const { return sends_.size(); }
330 
send_key(size_t i) const331 const string& InMemoryRunGraphRequest::send_key(size_t i) const {
332   return sends_[i].first;
333 }
334 
SendValue(size_t i,Tensor * out_tensor) const335 Status InMemoryRunGraphRequest::SendValue(size_t i, Tensor* out_tensor) const {
336   *out_tensor = sends_[i].second;
337   return OkStatus();
338 }
339 
AddSendFromRunStepRequest(const RunStepRequestWrapper & run_step_request,size_t i,const string & send_key)340 Status InMemoryRunGraphRequest::AddSendFromRunStepRequest(
341     const RunStepRequestWrapper& run_step_request, size_t i,
342     const string& send_key) {
343   Tensor tensor;
344   TF_RETURN_IF_ERROR(run_step_request.FeedValue(i, &tensor));
345   sends_.emplace_back(send_key, std::move(tensor));
346   return OkStatus();
347 }
348 
AddSendFromRunCallableRequest(const RunCallableRequest & run_callable_request,size_t i,const string & send_key)349 Status InMemoryRunGraphRequest::AddSendFromRunCallableRequest(
350     const RunCallableRequest& run_callable_request, size_t i,
351     const string& send_key) {
352   Tensor tensor;
353   if (!ParseTensorProtoToTensor(run_callable_request.feed(i), &tensor)) {
354     return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
355   }
356   sends_.emplace_back(send_key, std::move(tensor));
357   return OkStatus();
358 }
359 
num_recvs() const360 size_t InMemoryRunGraphRequest::num_recvs() const { return recvs_.size(); }
361 
recv_key(size_t i) const362 const string& InMemoryRunGraphRequest::recv_key(size_t i) const {
363   return recvs_[i];
364 }
365 
add_recv_key(const string & recv_key)366 void InMemoryRunGraphRequest::add_recv_key(const string& recv_key) {
367   recvs_.push_back(recv_key);
368 }
369 
is_partial() const370 bool InMemoryRunGraphRequest::is_partial() const { return is_partial_; }
371 
set_is_partial(bool is_partial)372 void InMemoryRunGraphRequest::set_is_partial(bool is_partial) {
373   is_partial_ = is_partial;
374 }
375 
is_last_partial_run() const376 bool InMemoryRunGraphRequest::is_last_partial_run() const {
377   return is_last_partial_run_;
378 }
379 
set_is_last_partial_run(bool is_last_partial_run)380 void InMemoryRunGraphRequest::set_is_last_partial_run(
381     bool is_last_partial_run) {
382   is_last_partial_run_ = is_last_partial_run;
383 }
384 
store_errors_in_response_body() const385 bool InMemoryRunGraphRequest::store_errors_in_response_body() const {
386   return store_errors_in_response_body_;
387 }
388 
set_store_errors_in_response_body(bool store_errors)389 void InMemoryRunGraphRequest::set_store_errors_in_response_body(
390     bool store_errors) {
391   store_errors_in_response_body_ = store_errors;
392 }
393 
request_id() const394 int64_t InMemoryRunGraphRequest::request_id() const { return request_id_; }
395 
set_request_id(int64_t request_id)396 void InMemoryRunGraphRequest::set_request_id(int64_t request_id) {
397   request_id_ = request_id;
398 }
399 
ToProto() const400 const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
401   if (!proto_version_) {
402     proto_version_.reset(new RunGraphRequest);
403     proto_version_->set_session_handle(session_handle());
404     proto_version_->set_create_worker_session_called(
405         create_worker_session_called());
406     proto_version_->set_graph_handle(graph_handle());
407     proto_version_->set_step_id(step_id());
408     *proto_version_->mutable_exec_opts() = exec_opts();
409     for (size_t i = 0; i < num_sends(); ++i) {
410       auto send = proto_version_->add_send();
411       send->set_name(send_key(i));
412       sends_[i].second.AsProtoTensorContent(send->mutable_tensor());
413     }
414     for (size_t i = 0; i < num_recvs(); ++i) {
415       proto_version_->add_recv_key(recv_key(i));
416     }
417     proto_version_->set_is_partial(is_partial());
418     proto_version_->set_is_last_partial_run(is_last_partial_run());
419   }
420   proto_version_->set_store_errors_in_response_body(
421       store_errors_in_response_body_);
422   proto_version_->set_request_id(request_id_);
423   return *proto_version_;
424 }
425 
session_handle() const426 const string& MutableProtoRunGraphRequest::session_handle() const {
427   return request_.session_handle();
428 }
429 
set_session_handle(const string & handle)430 void MutableProtoRunGraphRequest::set_session_handle(const string& handle) {
431   request_.set_session_handle(handle);
432 }
433 
create_worker_session_called() const434 bool MutableProtoRunGraphRequest::create_worker_session_called() const {
435   return request_.create_worker_session_called();
436 }
437 
set_create_worker_session_called(bool called)438 void MutableProtoRunGraphRequest::set_create_worker_session_called(
439     bool called) {
440   request_.set_create_worker_session_called(called);
441 }
442 
graph_handle() const443 const string& MutableProtoRunGraphRequest::graph_handle() const {
444   return request_.graph_handle();
445 }
446 
set_graph_handle(const string & handle)447 void MutableProtoRunGraphRequest::set_graph_handle(const string& handle) {
448   request_.set_graph_handle(handle);
449 }
450 
step_id() const451 int64_t MutableProtoRunGraphRequest::step_id() const {
452   return request_.step_id();
453 }
454 
set_step_id(int64_t step_id)455 void MutableProtoRunGraphRequest::set_step_id(int64_t step_id) {
456   request_.set_step_id(step_id);
457 }
458 
exec_opts() const459 const ExecutorOpts& MutableProtoRunGraphRequest::exec_opts() const {
460   return request_.exec_opts();
461 }
462 
mutable_exec_opts()463 ExecutorOpts* MutableProtoRunGraphRequest::mutable_exec_opts() {
464   return request_.mutable_exec_opts();
465 }
466 
num_sends() const467 size_t MutableProtoRunGraphRequest::num_sends() const {
468   return request_.send_size();
469 }
470 
send_key(size_t i) const471 const string& MutableProtoRunGraphRequest::send_key(size_t i) const {
472   return request_.send(i).name();
473 }
474 
SendValue(size_t i,Tensor * out_tensor) const475 Status MutableProtoRunGraphRequest::SendValue(size_t i,
476                                               Tensor* out_tensor) const {
477   if (!ParseTensorProtoToTensor(request_.send(i).tensor(), out_tensor)) {
478     return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
479   } else {
480     return OkStatus();
481   }
482 }
483 
AddSendFromRunStepRequest(const RunStepRequestWrapper & run_step_request,size_t i,const string & send_key)484 Status MutableProtoRunGraphRequest::AddSendFromRunStepRequest(
485     const RunStepRequestWrapper& run_step_request, size_t i,
486     const string& send_key) {
487   NamedTensorProto* send = request_.add_send();
488   send->set_name(send_key);
489   TF_RETURN_IF_ERROR(run_step_request.FeedValue(i, send->mutable_tensor()));
490   return OkStatus();
491 }
492 
AddSendFromRunCallableRequest(const RunCallableRequest & run_callable_request,size_t i,const string & send_key)493 Status MutableProtoRunGraphRequest::AddSendFromRunCallableRequest(
494     const RunCallableRequest& run_callable_request, size_t i,
495     const string& send_key) {
496   NamedTensorProto* send = request_.add_send();
497   send->set_name(send_key);
498   *send->mutable_tensor() = run_callable_request.feed(i);
499   return OkStatus();
500 }
501 
num_recvs() const502 size_t MutableProtoRunGraphRequest::num_recvs() const {
503   return request_.recv_key_size();
504 }
505 
recv_key(size_t i) const506 const string& MutableProtoRunGraphRequest::recv_key(size_t i) const {
507   return request_.recv_key(i);
508 }
509 
add_recv_key(const string & recv_key)510 void MutableProtoRunGraphRequest::add_recv_key(const string& recv_key) {
511   request_.add_recv_key(recv_key);
512 }
513 
is_partial() const514 bool MutableProtoRunGraphRequest::is_partial() const {
515   return request_.is_partial();
516 }
517 
set_is_partial(bool is_partial)518 void MutableProtoRunGraphRequest::set_is_partial(bool is_partial) {
519   request_.set_is_partial(is_partial);
520 }
521 
is_last_partial_run() const522 bool MutableProtoRunGraphRequest::is_last_partial_run() const {
523   return request_.is_last_partial_run();
524 }
525 
set_is_last_partial_run(bool is_last_partial_run)526 void MutableProtoRunGraphRequest::set_is_last_partial_run(
527     bool is_last_partial_run) {
528   request_.set_is_last_partial_run(is_last_partial_run);
529 }
530 
store_errors_in_response_body() const531 bool MutableProtoRunGraphRequest::store_errors_in_response_body() const {
532   return request_.store_errors_in_response_body();
533 }
534 
set_store_errors_in_response_body(bool store_errors)535 void MutableProtoRunGraphRequest::set_store_errors_in_response_body(
536     bool store_errors) {
537   request_.set_store_errors_in_response_body(store_errors);
538 }
539 
request_id() const540 int64_t MutableProtoRunGraphRequest::request_id() const {
541   return request_.request_id();
542 }
543 
set_request_id(int64_t request_id)544 void MutableProtoRunGraphRequest::set_request_id(int64_t request_id) {
545   request_.set_request_id(request_id);
546 }
547 
ToProto() const548 const RunGraphRequest& MutableProtoRunGraphRequest::ToProto() const {
549   return request_;
550 }
551 
ProtoRunGraphRequest(const RunGraphRequest * request)552 ProtoRunGraphRequest::ProtoRunGraphRequest(const RunGraphRequest* request)
553     : request_(request) {}
554 
session_handle() const555 const string& ProtoRunGraphRequest::session_handle() const {
556   return request_->session_handle();
557 }
558 
create_worker_session_called() const559 bool ProtoRunGraphRequest::create_worker_session_called() const {
560   return request_->create_worker_session_called();
561 }
562 
graph_handle() const563 const string& ProtoRunGraphRequest::graph_handle() const {
564   return request_->graph_handle();
565 }
566 
step_id() const567 int64_t ProtoRunGraphRequest::step_id() const { return request_->step_id(); }
568 
exec_opts() const569 const ExecutorOpts& ProtoRunGraphRequest::exec_opts() const {
570   return request_->exec_opts();
571 }
572 
num_sends() const573 size_t ProtoRunGraphRequest::num_sends() const { return request_->send_size(); }
574 
send_key(size_t i) const575 const string& ProtoRunGraphRequest::send_key(size_t i) const {
576   return request_->send(i).name();
577 }
578 
SendValue(size_t i,Tensor * out_tensor) const579 Status ProtoRunGraphRequest::SendValue(size_t i, Tensor* out_tensor) const {
580   if (!ParseTensorProtoToTensor(request_->send(i).tensor(), out_tensor)) {
581     return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
582   } else {
583     return OkStatus();
584   }
585 }
586 
num_recvs() const587 size_t ProtoRunGraphRequest::num_recvs() const {
588   return request_->recv_key_size();
589 }
590 
recv_key(size_t i) const591 const string& ProtoRunGraphRequest::recv_key(size_t i) const {
592   return request_->recv_key(i);
593 }
594 
is_partial() const595 bool ProtoRunGraphRequest::is_partial() const { return request_->is_partial(); }
596 
is_last_partial_run() const597 bool ProtoRunGraphRequest::is_last_partial_run() const {
598   return request_->is_last_partial_run();
599 }
600 
store_errors_in_response_body() const601 bool ProtoRunGraphRequest::store_errors_in_response_body() const {
602   return request_->store_errors_in_response_body();
603 }
604 
request_id() const605 int64_t ProtoRunGraphRequest::request_id() const {
606   return request_->request_id();
607 }
608 
ToProto() const609 const RunGraphRequest& ProtoRunGraphRequest::ToProto() const {
610   return *request_;
611 }
612 
num_recvs() const613 size_t InMemoryRunGraphResponse::num_recvs() const { return recvs_.size(); }
614 
recv_key(size_t i) const615 const string& InMemoryRunGraphResponse::recv_key(size_t i) const {
616   return recvs_[i].first;
617 }
618 
RecvValue(size_t i,TensorProto * out_tensor)619 Status InMemoryRunGraphResponse::RecvValue(size_t i, TensorProto* out_tensor) {
620   recvs_[i].second.AsProtoTensorContent(out_tensor);
621   return OkStatus();
622 }
623 
RecvValue(size_t i,Tensor * out_tensor)624 Status InMemoryRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) {
625   *out_tensor = recvs_[i].second;
626   return OkStatus();
627 }
628 
AddRecv(const string & key,const Tensor & value)629 void InMemoryRunGraphResponse::AddRecv(const string& key, const Tensor& value) {
630   recvs_.emplace_back(key, value);
631 }
632 
mutable_step_stats()633 StepStats* InMemoryRunGraphResponse::mutable_step_stats() {
634   return &step_stats_;
635 }
636 
mutable_cost_graph()637 CostGraphDef* InMemoryRunGraphResponse::mutable_cost_graph() {
638   return &cost_graph_;
639 }
640 
status() const641 Status InMemoryRunGraphResponse::status() const { return status_; }
642 
status_code() const643 errors::Code InMemoryRunGraphResponse::status_code() const {
644   return status_.code();
645 }
646 
status_error_message() const647 const string& InMemoryRunGraphResponse::status_error_message() const {
648   return status_.error_message();
649 }
650 
set_status(const Status & status)651 void InMemoryRunGraphResponse::set_status(const Status& status) {
652   status_ = status;
653 }
654 
get_proto()655 RunGraphResponse* InMemoryRunGraphResponse::get_proto() {
656   LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunGraphResponse";
657   return nullptr;
658 }
659 
num_partition_graphs() const660 size_t InMemoryRunGraphResponse::num_partition_graphs() const {
661   return partition_graphs_.size();
662 }
663 
mutable_partition_graph(size_t i)664 GraphDef* InMemoryRunGraphResponse::mutable_partition_graph(size_t i) {
665   return &partition_graphs_[i];
666 }
667 
AddPartitionGraph(const GraphDef & partition_graph)668 void InMemoryRunGraphResponse::AddPartitionGraph(
669     const GraphDef& partition_graph) {
670   partition_graphs_.push_back(partition_graph);
671 }
672 
num_recvs() const673 size_t OwnedProtoRunGraphResponse::num_recvs() const {
674   return response_.recv_size();
675 }
676 
recv_key(size_t i) const677 const string& OwnedProtoRunGraphResponse::recv_key(size_t i) const {
678   return response_.recv(i).name();
679 }
680 
RecvValue(size_t i,TensorProto * out_tensor)681 Status OwnedProtoRunGraphResponse::RecvValue(size_t i,
682                                              TensorProto* out_tensor) {
683   out_tensor->Swap(response_.mutable_recv(i)->mutable_tensor());
684   return OkStatus();
685 }
686 
RecvValue(size_t i,Tensor * out_tensor)687 Status OwnedProtoRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) {
688   if (!ParseTensorProtoToTensor(response_.recv(i).tensor(), out_tensor)) {
689     return errors::InvalidArgument("Invalid TensorProto for recv value ", i);
690   } else {
691     return OkStatus();
692   }
693 }
694 
AddRecv(const string & key,const Tensor & value)695 void OwnedProtoRunGraphResponse::AddRecv(const string& key,
696                                          const Tensor& value) {
697   NamedTensorProto* recv = response_.add_recv();
698   recv->set_name(key);
699   TensorProto* value_proto = recv->mutable_tensor();
700   value.AsProtoTensorContent(value_proto);
701 }
702 
mutable_step_stats()703 StepStats* OwnedProtoRunGraphResponse::mutable_step_stats() {
704   return response_.mutable_step_stats();
705 }
706 
mutable_cost_graph()707 CostGraphDef* OwnedProtoRunGraphResponse::mutable_cost_graph() {
708   return response_.mutable_cost_graph();
709 }
710 
status() const711 Status OwnedProtoRunGraphResponse::status() const {
712   return Status(response_.status_code(), response_.status_error_message());
713 }
714 
status_code() const715 errors::Code OwnedProtoRunGraphResponse::status_code() const {
716   return response_.status_code();
717 }
718 
status_error_message() const719 const string& OwnedProtoRunGraphResponse::status_error_message() const {
720   return response_.status_error_message();
721 }
722 
set_status(const Status & status)723 void OwnedProtoRunGraphResponse::set_status(const Status& status) {
724   response_.set_status_code(status.code());
725   response_.set_status_error_message(status.error_message());
726 }
727 
get_proto()728 RunGraphResponse* OwnedProtoRunGraphResponse::get_proto() { return &response_; }
729 
num_partition_graphs() const730 size_t OwnedProtoRunGraphResponse::num_partition_graphs() const {
731   return response_.partition_graph_size();
732 }
733 
mutable_partition_graph(size_t i)734 GraphDef* OwnedProtoRunGraphResponse::mutable_partition_graph(size_t i) {
735   return response_.mutable_partition_graph(i);
736 }
737 
AddPartitionGraph(const GraphDef & partition_graph)738 void OwnedProtoRunGraphResponse::AddPartitionGraph(
739     const GraphDef& partition_graph) {
740   GraphDef* graph_def = response_.mutable_partition_graph()->Add();
741   *graph_def = partition_graph;
742 }
743 
NonOwnedProtoRunGraphResponse(RunGraphResponse * response)744 NonOwnedProtoRunGraphResponse::NonOwnedProtoRunGraphResponse(
745     RunGraphResponse* response)
746     : response_(response) {}
747 
num_recvs() const748 size_t NonOwnedProtoRunGraphResponse::num_recvs() const {
749   return response_->recv_size();
750 }
751 
recv_key(size_t i) const752 const string& NonOwnedProtoRunGraphResponse::recv_key(size_t i) const {
753   return response_->recv(i).name();
754 }
755 
RecvValue(size_t i,TensorProto * out_tensor)756 Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i,
757                                                 TensorProto* out_tensor) {
758   out_tensor->Swap(response_->mutable_recv(i)->mutable_tensor());
759   return OkStatus();
760 }
761 
RecvValue(size_t i,Tensor * out_tensor)762 Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) {
763   if (!ParseTensorProtoToTensor(response_->recv(i).tensor(), out_tensor)) {
764     return errors::InvalidArgument("Invalid TensorProto for recv value ", i);
765   } else {
766     return OkStatus();
767   }
768 }
769 
AddRecv(const string & key,const Tensor & value)770 void NonOwnedProtoRunGraphResponse::AddRecv(const string& key,
771                                             const Tensor& value) {
772   NamedTensorProto* recv = response_->add_recv();
773   recv->set_name(key);
774   TensorProto* value_proto = recv->mutable_tensor();
775   value.AsProtoTensorContent(value_proto);
776 }
777 
mutable_step_stats()778 StepStats* NonOwnedProtoRunGraphResponse::mutable_step_stats() {
779   return response_->mutable_step_stats();
780 }
781 
mutable_cost_graph()782 CostGraphDef* NonOwnedProtoRunGraphResponse::mutable_cost_graph() {
783   return response_->mutable_cost_graph();
784 }
785 
status() const786 Status NonOwnedProtoRunGraphResponse::status() const {
787   return Status(response_->status_code(), response_->status_error_message());
788 }
789 
status_code() const790 errors::Code NonOwnedProtoRunGraphResponse::status_code() const {
791   return response_->status_code();
792 }
793 
status_error_message() const794 const string& NonOwnedProtoRunGraphResponse::status_error_message() const {
795   return response_->status_error_message();
796 }
797 
set_status(const Status & status)798 void NonOwnedProtoRunGraphResponse::set_status(const Status& status) {
799   response_->set_status_code(status.code());
800   response_->set_status_error_message(status.error_message());
801 }
802 
get_proto()803 RunGraphResponse* NonOwnedProtoRunGraphResponse::get_proto() {
804   return response_;
805 }
806 
num_partition_graphs() const807 size_t NonOwnedProtoRunGraphResponse::num_partition_graphs() const {
808   return response_->partition_graph_size();
809 }
810 
mutable_partition_graph(size_t i)811 GraphDef* NonOwnedProtoRunGraphResponse::mutable_partition_graph(size_t i) {
812   return response_->mutable_partition_graph(i);
813 }
814 
AddPartitionGraph(const GraphDef & partition_graph)815 void NonOwnedProtoRunGraphResponse::AddPartitionGraph(
816     const GraphDef& partition_graph) {
817   GraphDef* graph_def = response_->add_partition_graph();
818   *graph_def = partition_graph;
819 }
820 
~MutableRunStepResponseWrapper()821 MutableRunStepResponseWrapper::~MutableRunStepResponseWrapper() {}
822 
num_tensors() const823 size_t InMemoryRunStepResponse::num_tensors() const { return tensors_.size(); }
824 
tensor_name(size_t i) const825 const string& InMemoryRunStepResponse::tensor_name(size_t i) const {
826   return tensors_[i].first;
827 }
828 
TensorValue(size_t i,Tensor * out_tensor) const829 Status InMemoryRunStepResponse::TensorValue(size_t i,
830                                             Tensor* out_tensor) const {
831   *out_tensor = tensors_[i].second;
832   return OkStatus();
833 }
834 
metadata() const835 const RunMetadata& InMemoryRunStepResponse::metadata() const {
836   return metadata_;
837 }
838 
AddTensorFromRunGraphResponse(const string & name,MutableRunGraphResponseWrapper * wrapper,size_t i)839 Status InMemoryRunStepResponse::AddTensorFromRunGraphResponse(
840     const string& name, MutableRunGraphResponseWrapper* wrapper, size_t i) {
841   Tensor tensor;
842   TF_RETURN_IF_ERROR(wrapper->RecvValue(i, &tensor));
843   tensors_.emplace_back(name, tensor);
844   return OkStatus();
845 }
846 
mutable_metadata()847 RunMetadata* InMemoryRunStepResponse::mutable_metadata() { return &metadata_; }
848 
status() const849 Status InMemoryRunStepResponse::status() const { return status_; }
850 
status_code() const851 errors::Code InMemoryRunStepResponse::status_code() const {
852   return status_.code();
853 }
854 
status_error_message() const855 const string& InMemoryRunStepResponse::status_error_message() const {
856   return status_.error_message();
857 }
858 
set_status(const Status & status)859 void InMemoryRunStepResponse::set_status(const Status& status) {
860   status_ = status;
861 }
862 
get_proto()863 RunStepResponse* InMemoryRunStepResponse::get_proto() {
864   LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunStepResponse";
865   return nullptr;
866 }
867 
num_tensors() const868 size_t OwnedProtoRunStepResponse::num_tensors() const {
869   return response_.tensor_size();
870 }
871 
tensor_name(size_t i) const872 const string& OwnedProtoRunStepResponse::tensor_name(size_t i) const {
873   return response_.tensor(i).name();
874 }
875 
TensorValue(size_t i,Tensor * out_tensor) const876 Status OwnedProtoRunStepResponse::TensorValue(size_t i,
877                                               Tensor* out_tensor) const {
878   if (!ParseTensorProtoToTensor(response_.tensor(i).tensor(), out_tensor)) {
879     return errors::InvalidArgument("Invalid TensorProto for fetch value ", i);
880   } else {
881     return OkStatus();
882   }
883 }
884 
metadata() const885 const RunMetadata& OwnedProtoRunStepResponse::metadata() const {
886   return response_.metadata();
887 }
888 
AddTensorFromRunGraphResponse(const string & name,MutableRunGraphResponseWrapper * run_graph_response,size_t i)889 Status OwnedProtoRunStepResponse::AddTensorFromRunGraphResponse(
890     const string& name, MutableRunGraphResponseWrapper* run_graph_response,
891     size_t i) {
892   NamedTensorProto* response_tensor = response_.add_tensor();
893   response_tensor->set_name(name);
894   return run_graph_response->RecvValue(i, response_tensor->mutable_tensor());
895 }
896 
mutable_metadata()897 RunMetadata* OwnedProtoRunStepResponse::mutable_metadata() {
898   return response_.mutable_metadata();
899 }
900 
status() const901 Status OwnedProtoRunStepResponse::status() const {
902   return Status(response_.status_code(), response_.status_error_message());
903 }
904 
status_code() const905 errors::Code OwnedProtoRunStepResponse::status_code() const {
906   return response_.status_code();
907 }
908 
status_error_message() const909 const string& OwnedProtoRunStepResponse::status_error_message() const {
910   return response_.status_error_message();
911 }
912 
set_status(const Status & status)913 void OwnedProtoRunStepResponse::set_status(const Status& status) {
914   response_.set_status_code(status.code());
915   response_.set_status_error_message(status.error_message());
916 }
917 
get_proto()918 RunStepResponse* OwnedProtoRunStepResponse::get_proto() { return &response_; }
919 
NonOwnedProtoRunStepResponse(RunStepResponse * response)920 NonOwnedProtoRunStepResponse::NonOwnedProtoRunStepResponse(
921     RunStepResponse* response)
922     : response_(response) {}
923 
num_tensors() const924 size_t NonOwnedProtoRunStepResponse::num_tensors() const {
925   return response_->tensor_size();
926 }
927 
tensor_name(size_t i) const928 const string& NonOwnedProtoRunStepResponse::tensor_name(size_t i) const {
929   return response_->tensor(i).name();
930 }
931 
TensorValue(size_t i,Tensor * out_tensor) const932 Status NonOwnedProtoRunStepResponse::TensorValue(size_t i,
933                                                  Tensor* out_tensor) const {
934   if (!ParseTensorProtoToTensor(response_->tensor(i).tensor(), out_tensor)) {
935     return errors::InvalidArgument("Invalid TensorProto for fetch value ", i);
936   } else {
937     return OkStatus();
938   }
939 }
940 
metadata() const941 const RunMetadata& NonOwnedProtoRunStepResponse::metadata() const {
942   return response_->metadata();
943 }
944 
AddTensorFromRunGraphResponse(const string & name,MutableRunGraphResponseWrapper * run_graph_response,size_t i)945 Status NonOwnedProtoRunStepResponse::AddTensorFromRunGraphResponse(
946     const string& name, MutableRunGraphResponseWrapper* run_graph_response,
947     size_t i) {
948   NamedTensorProto* response_tensor = response_->add_tensor();
949   response_tensor->set_name(name);
950   return run_graph_response->RecvValue(i, response_tensor->mutable_tensor());
951 }
952 
mutable_metadata()953 RunMetadata* NonOwnedProtoRunStepResponse::mutable_metadata() {
954   return response_->mutable_metadata();
955 }
956 
status() const957 Status NonOwnedProtoRunStepResponse::status() const {
958   return Status(response_->status_code(), response_->status_error_message());
959 }
960 
status_code() const961 errors::Code NonOwnedProtoRunStepResponse::status_code() const {
962   return response_->status_code();
963 }
964 
status_error_message() const965 const string& NonOwnedProtoRunStepResponse::status_error_message() const {
966   return response_->status_error_message();
967 }
968 
set_status(const Status & status)969 void NonOwnedProtoRunStepResponse::set_status(const Status& status) {
970   response_->set_status_code(status.code());
971   response_->set_status_error_message(status.error_message());
972 }
973 
get_proto()974 RunStepResponse* NonOwnedProtoRunStepResponse::get_proto() { return response_; }
975 
976 }  // namespace tensorflow
977