xref: /aosp_15_r20/external/tensorflow/tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RESPONSE_CACHE_H_
16 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RESPONSE_CACHE_H_
17 
#include <functional>
#include <memory>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/mutex.h"
27 
28 // gRPC response caching.  Most WorkerService methods cannot be retried directly
29 // as they will fail or deadlock.  To enable retrying, we can instead cache
30 // responses and reply to duplicate requests from the cache. The cache will be
31 // cleaned when the MarkRecvFinishedRequest is received from the receiver or the
32 // session step is completed.
33 namespace tensorflow {
34 
// Track and cache the state of worker service RPCs.  An RPC can be in one of
// three states:
//
// * PENDING: this is the first call of the RPC, and it will transition to
//   ACTIVE.
// * ACTIVE: another thread is actively processing this RPC.
// * FINISHED: the worker has finished processing the method.
40 
41 class GrpcResponseCache {
42  public:
43   using FinishResponseCB = std::function<void(
44       const Tensor& tensor, bool is_dead, const Status& status)>;
45 
46   // Add the given request to the cache.
47   // If the request is in the cache,
48   //    If it is finished, invoke `cb` immediately
49   //    If active, cb will be invoked when the current call completes.
50   //    In either case, return true.
51   // Otherwise, store the request and cb in the cache, and return false.
52   // Note FinishResponseCB is assumed to be thread-safe.
53   bool QueueRequest(int64_t request_id, int64_t step_id,
54                     const FinishResponseCB& cb);
55 
56   // Fill the response cache for the given request_id and respond to all
57   // pending request.
58   void OnRequestFinished(int64_t request_id, const Tensor& tensor, bool is_dead,
59                          const Status& status);
60 
61   // Erase the cache entry with the given request_id
62   void EraseRequestId(int64_t request_id);
63 
64   // Erase cache entries with the given step_id
65   void CleanEntriesForStep(int64_t step_id);
66 
67  private:
68   struct ResponseCacheEntry {
69     enum class State {
70       PENDING = 0,
71       ACTIVE = 1,
72       FINISHED = 2,
73     };
74 
75     State state = State::PENDING;
76     int64_t step_id = -1;
77     Tensor tensor;
78     bool is_dead = false;
79     Status response_status;
80 
FinishResponseResponseCacheEntry81     void FinishResponse(const FinishResponseCB& cb) const {
82       cb(tensor, is_dead, response_status);
83     }
84     std::vector<FinishResponseCB> callbacks;
85   };
86 
87   mutex mu_;
88   // response_cache_ is expected to be small, as entries are cleared immediately
89   // on ack from the receiver.
90   gtl::FlatMap<int64_t, ResponseCacheEntry> response_cache_ TF_GUARDED_BY(mu_);
91 };
92 
93 }  // namespace tensorflow
94 
95 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RESPONSE_CACHE_H_
96