xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/runtime/default/async_values_cache.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef XLA_RUNTIME_DEFAULT_ASYNC_VALUES_CACHE_H_
17 #define XLA_RUNTIME_DEFAULT_ASYNC_VALUES_CACHE_H_
18 
19 #include "absl/synchronization/mutex.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
22 #include "tfrt/host_context/chain.h"  // from @tf_runtime
23 
24 namespace xla {
25 namespace runtime {
26 
27 using tfrt::AsyncValue;
28 using tfrt::AsyncValuePtr;
29 using tfrt::AsyncValueRef;
30 using tfrt::Chain;
31 using tfrt::MakeConstructedAsyncValueRef;
32 using tfrt::MakeUnconstructedAsyncValueRef;
33 
34 template <typename Key, typename Value>
35 class AsyncValuesCache {
36  public:
37   struct Entry;
38 
39   AsyncValuesCache() = default;
40 
41   // Returns a pointer to the cached value if it exists, otherwise returns
42   // nullptr. It is the caller's responsibility to form an async reference and
43   // extend its lifetime if the lifetime of the cached async value can be
44   // larger than the lifetime of the cache.
45   AsyncValuePtr<Value> Find(Key key) const;
46 
47   // Allocates an async value in the unconstructed state to store the cached
48   // value with the given key.
49   //
50   // The `entry.allocated` value is `true` if the new async value was allocated,
51   // and the caller is responsible for eventually setting the error or emplacing
52   // the value. If it is false, then it means that the storage was already
53   // allocated, and someone else will eventually update it.
54   //
55   // The returned `entry.size` value is equal to the size of the cache. If the
56   // new async value was allocated, it will be reflected in the size.
57   Entry Allocate(Key key);
58 
59   // Returns an async value that becomes available once all entries added to
60   // the cache are available.
61   AsyncValueRef<Chain> AllAvailable() const;
62 
63   struct Entry {
64     AsyncValuePtr<Value> ptr;
65     bool allocated;
66     size_t size;
67   };
68 
69  private:
70   mutable absl::Mutex mu_;
71   llvm::DenseMap<Key, AsyncValueRef<Value>> cache_ ABSL_GUARDED_BY(mu_);
72 };
73 
74 template <typename Key, typename Value>
Find(Key key)75 AsyncValuePtr<Value> AsyncValuesCache<Key, Value>::Find(Key key) const {
76   absl::MutexLock lock(&mu_);
77   auto it = cache_.find(key);
78   return it != cache_.end() ? it->getSecond().AsPtr() : AsyncValuePtr<Value>();
79 }
80 
81 template <typename Key, typename Value>
82 auto AsyncValuesCache<Key, Value>::Allocate(Key key) -> Entry {
83   absl::MutexLock lock(&mu_);
84   auto it = cache_.find(key);
85   if (it != cache_.end())
86     return {it->getSecond().AsPtr(), false, cache_.size()};
87 
88   AsyncValueRef<Value> allocated = MakeUnconstructedAsyncValueRef<Value>();
89 
90   auto emplaced = cache_.try_emplace(key, std::move(allocated));
91   assert(emplaced.second && "emplace must be successful");
92   return {emplaced.first->getSecond().AsPtr(), true, cache_.size()};
93 }
94 
95 template <typename Key, typename Value>
AllAvailable()96 AsyncValueRef<Chain> AsyncValuesCache<Key, Value>::AllAvailable() const {
97   absl::MutexLock lock(&mu_);
98 
99   llvm::SmallVector<AsyncValue*> avs;
100   for (auto& it : cache_) avs.push_back(it.getSecond().GetAsyncValue());
101 
102   AsyncValueRef<Chain> chain = MakeConstructedAsyncValueRef<Chain>();
103   RunWhenReady(avs, [chain]() { chain.SetStateConcrete(); });
104   return chain;
105 }
106 
107 }  // namespace runtime
108 }  // namespace xla
109 
110 #endif  // XLA_RUNTIME_DEFAULT_ASYNC_VALUES_CACHE_H_
111