1 #pragma once 2 3 #include <torch/csrc/distributed/c10d/Store.hpp> 4 5 namespace c10d { 6 7 class TORCH_API PrefixStore : public Store { 8 public: 9 explicit PrefixStore(std::string prefix, c10::intrusive_ptr<Store> store); 10 11 using Store::set; 12 void set(const std::string& key, const std::vector<uint8_t>& value) override; 13 14 using Store::compareSet; 15 std::vector<uint8_t> compareSet( 16 const std::string& key, 17 const std::vector<uint8_t>& expectedValue, 18 const std::vector<uint8_t>& desiredValue) override; 19 20 std::vector<uint8_t> get(const std::string& key) override; 21 22 int64_t add(const std::string& key, int64_t value) override; 23 24 bool deleteKey(const std::string& key) override; 25 26 int64_t getNumKeys() override; 27 28 bool check(const std::vector<std::string>& keys) override; 29 30 void wait(const std::vector<std::string>& keys) override; 31 32 void wait( 33 const std::vector<std::string>& keys, 34 const std::chrono::milliseconds& timeout) override; 35 36 const std::chrono::milliseconds& getTimeout() const noexcept override; 37 38 void setTimeout(const std::chrono::milliseconds& timeout) override; 39 40 void append(const std::string& key, const std::vector<uint8_t>& value) 41 override; 42 43 std::vector<std::vector<uint8_t>> multiGet( 44 const std::vector<std::string>& keys) override; 45 46 void multiSet( 47 const std::vector<std::string>& keys, 48 const std::vector<std::vector<uint8_t>>& values) override; 49 50 // Returns true if this store support append, multiGet and multiSet 51 bool hasExtendedApi() const override; 52 53 c10::intrusive_ptr<Store> getUnderlyingStore(); 54 55 // Recursively to fetch the store before layers of wrapping with PrefixStore. 56 c10::intrusive_ptr<Store> getUnderlyingNonPrefixStore(); 57 58 protected: 59 std::string prefix_; 60 c10::intrusive_ptr<Store> store_; 61 62 std::string joinKey(const std::string& key); 63 std::vector<std::string> joinKeys(const std::vector<std::string>& keys); 64 }; 65 66 } // namespace c10d 67