xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/PrefixStore.hpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/distributed/c10d/Store.hpp>
4 
5 namespace c10d {
6 
7 class TORCH_API PrefixStore : public Store {
8  public:
9   explicit PrefixStore(std::string prefix, c10::intrusive_ptr<Store> store);
10 
11   using Store::set;
12   void set(const std::string& key, const std::vector<uint8_t>& value) override;
13 
14   using Store::compareSet;
15   std::vector<uint8_t> compareSet(
16       const std::string& key,
17       const std::vector<uint8_t>& expectedValue,
18       const std::vector<uint8_t>& desiredValue) override;
19 
20   std::vector<uint8_t> get(const std::string& key) override;
21 
22   int64_t add(const std::string& key, int64_t value) override;
23 
24   bool deleteKey(const std::string& key) override;
25 
26   int64_t getNumKeys() override;
27 
28   bool check(const std::vector<std::string>& keys) override;
29 
30   void wait(const std::vector<std::string>& keys) override;
31 
32   void wait(
33       const std::vector<std::string>& keys,
34       const std::chrono::milliseconds& timeout) override;
35 
36   const std::chrono::milliseconds& getTimeout() const noexcept override;
37 
38   void setTimeout(const std::chrono::milliseconds& timeout) override;
39 
40   void append(const std::string& key, const std::vector<uint8_t>& value)
41       override;
42 
43   std::vector<std::vector<uint8_t>> multiGet(
44       const std::vector<std::string>& keys) override;
45 
46   void multiSet(
47       const std::vector<std::string>& keys,
48       const std::vector<std::vector<uint8_t>>& values) override;
49 
50   // Returns true if this store support append, multiGet and multiSet
51   bool hasExtendedApi() const override;
52 
53   c10::intrusive_ptr<Store> getUnderlyingStore();
54 
55   // Recursively to fetch the store before layers of wrapping with PrefixStore.
56   c10::intrusive_ptr<Store> getUnderlyingNonPrefixStore();
57 
58  protected:
59   std::string prefix_;
60   c10::intrusive_ptr<Store> store_;
61 
62   std::string joinKey(const std::string& key);
63   std::vector<std::string> joinKeys(const std::vector<std::string>& keys);
64 };
65 
66 } // namespace c10d
67