/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RESPONSE_CACHE_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RESPONSE_CACHE_H_

#include <cstdint>
#include <functional>
#include <memory>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/mutex.h"

// gRPC response caching.  Most WorkerService methods cannot be retried directly
// as they will fail or deadlock.  To enable retrying, we can instead cache
// responses and reply to duplicate requests from the cache. The cache will be
// cleaned when the MarkRecvFinishedRequest is received from the receiver or the
// session step is completed.
namespace tensorflow {

// Track and cache the state of worker service RPCs.
An RPC can be in 3 states: 36 // 37 // * PENDING: this is the first call of the RPC, and it will transition to 38 // * ACTIVE: another thread is active processing this RPC 39 // * FINISHED: the worker has finished processing the method 40 41 class GrpcResponseCache { 42 public: 43 using FinishResponseCB = std::function<void( 44 const Tensor& tensor, bool is_dead, const Status& status)>; 45 46 // Add the given request to the cache. 47 // If the request is in the cache, 48 // If it is finished, invoke `cb` immediately 49 // If active, cb will be invoked when the current call completes. 50 // In either case, return true. 51 // Otherwise, store the request and cb in the cache, and return false. 52 // Note FinishResponseCB is assumed to be thread-safe. 53 bool QueueRequest(int64_t request_id, int64_t step_id, 54 const FinishResponseCB& cb); 55 56 // Fill the response cache for the given request_id and respond to all 57 // pending request. 58 void OnRequestFinished(int64_t request_id, const Tensor& tensor, bool is_dead, 59 const Status& status); 60 61 // Erase the cache entry with the given request_id 62 void EraseRequestId(int64_t request_id); 63 64 // Erase cache entries with the given step_id 65 void CleanEntriesForStep(int64_t step_id); 66 67 private: 68 struct ResponseCacheEntry { 69 enum class State { 70 PENDING = 0, 71 ACTIVE = 1, 72 FINISHED = 2, 73 }; 74 75 State state = State::PENDING; 76 int64_t step_id = -1; 77 Tensor tensor; 78 bool is_dead = false; 79 Status response_status; 80 FinishResponseResponseCacheEntry81 void FinishResponse(const FinishResponseCB& cb) const { 82 cb(tensor, is_dead, response_status); 83 } 84 std::vector<FinishResponseCB> callbacks; 85 }; 86 87 mutex mu_; 88 // response_cache_ is expected to be small, as entries are cleared immediately 89 // on ack from the receiver. 
90 gtl::FlatMap<int64_t, ResponseCacheEntry> response_cache_ TF_GUARDED_BY(mu_); 91 }; 92 93 } // namespace tensorflow 94 95 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RESPONSE_CACHE_H_ 96