1 #pragma once 2 3 #include <mutex> 4 #include <unordered_map> 5 6 #include <torch/csrc/distributed/autograd/context/context.h> 7 8 namespace torch { 9 namespace distributed { 10 namespace autograd { 11 12 // Singleton class per worker which is responsible for storing the distributed 13 // autograd context for each autograd pass and also cleans up data for an 14 // autograd pass once its done. 15 // 16 // Each autograd pass is assigned a unique autograd_context_id and all data for 17 // that pass (DistAutogradContext) is stored in this container indexed by the 18 // autograd_context_id. The autograd_context_id itself is a 64 bit globally 19 // unique id. The first 16 bits is the worker_id and the next 48 bits is an 20 // auto-incrementing id for each worker. 21 // 22 // This container is also responsible for maintaining a globally unique message 23 // id, which is used to associate send/recv autograd function pairs. The format 24 // is similar to the autograd_context_id where we have a 64 bit integer with 25 // first 16 bits being the worker id and next 48 bits are auto-incrementing. 26 class TORCH_API DistAutogradContainer { 27 public: 28 explicit DistAutogradContainer(uint32_t num_shards); 29 30 // One time initialization of the container. 31 static DistAutogradContainer& init(int64_t worker_id); 32 33 // Retrieve the singleton instance of the container, ensures we have 34 // initialized the container. 35 static DistAutogradContainer& getInstance(); 36 37 // Create a new context for a distributed autograd pass. 38 const ContextPtr newContext(); 39 40 // Clean up resources for a given context_id once the autograd pass is done. 41 // Sends RPC to other workers this worker knows about, telling them to clean 42 // up their context as well. Throws an exception if the context_id does not 43 // exist. 44 void releaseContext(int64_t context_id); 45 46 // Releases an autograd context if it is present on this node. Also sends RPC 47 // to other workers this worker knows about, telling them to clean up their 48 // context. Does nothing if it is not present. 49 void releaseContextIfPresent(int64_t context_id); 50 51 // Checks if the passed in context_id is valid. 52 void isValidContext(int64_t context_id); 53 54 // Retrieve the autograd context for a given context_id. 55 ContextPtr retrieveContext(int64_t context_id); 56 57 // Retrieves the currently active autograd context for the current thread. 58 ContextPtr currentContext(); 59 60 // Checks whether or not the current thread has a valid autograd context. 61 bool hasValidContext() const; 62 63 // Generate a new autograd_message_id for send/recv autograd functions. 64 int64_t newAutogradMessageId(); 65 66 // Creates a new autograd context with the provided context_id. If a context 67 // already exists with the provided context_id, we just return it. 68 // This does not set the current context for the current thread. 69 ContextPtr getOrCreateContext(int64_t context_id); 70 71 // Retrieves the maximum possible autograd_context_id/autograd_message_id that 72 // can be generated by this worker. 73 int64_t getMaxId(); 74 75 // Retrieves the worker ID for this node 76 rpc::worker_id_t getWorkerId() const; 77 78 // Can set current context id if there is no valid context yet 79 static void setCurrentContextId(int64_t contextId); 80 81 // Forcibly sets the thread local current context id. Should only be used in 82 // cases where you know what you're doing and need to override the thread 83 // local. Otherwise, use setCurrentContextId instead. 84 static void forceCurrentContextId(int64_t contextId); 85 86 // Clear current context id 87 void clearCurrentContext(); 88 89 // Returns the number of autograd contexts in the container. 90 size_t numAutogradContexts() const; 91 92 // Returns the current thread local context id for this thread. 93 static int64_t currentContextId(); 94 95 DistAutogradContainer(const DistAutogradContainer&) = delete; 96 DistAutogradContainer& operator=(const DistAutogradContainer&) = delete; 97 DistAutogradContainer(DistAutogradContainer&&) = delete; 98 DistAutogradContainer& operator=(DistAutogradContainer&&) = delete; 99 100 private: 101 // Number of shards for the map storing autograd contexts. We'd like this 102 // to be a power of 2 and we don't expect a value much higher than the 103 // number of cores would provide much benefit. 104 static constexpr uint32_t kNumDefaultShards = 128; 105 106 // Use cache line size for alignment. 107 static constexpr int kCacheLineSize = 64; 108 109 // Structure holding one shard of the sharded autograd context map with its 110 // associated lock. Align to cache line size to avoid contention between 111 // adjacent entries. 112 struct alignas(kCacheLineSize) ContextsShard { 113 // Lock for this shard. 114 mutable std::mutex lock; 115 116 // Map storing autograd contexts for this shard. 117 std::unordered_map<int64_t, ContextPtr> contexts; 118 }; 119 120 DistAutogradContainer() = delete; 121 ~DistAutogradContainer() = default; 122 123 static DistAutogradContainer& getInstanceInternal(); 124 125 // Retrieve the shard for given context_id. 126 ContextsShard& getShard(int64_t context_id); 127 128 // Sends an RPC to the workers that have a context corresponding to passed in 129 // context_id. This function should be called with the lock. 130 void sendReleaseContextRpc( 131 const std::unordered_set<rpc::worker_id_t>& workerIds, 132 int64_t context_id); 133 134 // Erase context_id from the autograd context map, and reset the thread local 135 // current context id if it corresponds to the passed in context id. This 136 // function should be called with the lock. 137 void eraseContextIdAndReset(ContextsShard& shard, int64_t context_id); 138 139 // Compute the number of shards for the autograd_contexts_ map. 140 static uint32_t computeNumShards(); 141 142 // Auto incrementing context id used to identify unique autograd passes. 143 // Initialized with the first 16 bits being the worker_id. 144 std::atomic<int64_t> next_context_id_; 145 146 // Unique id to identify a worker in the distributed setting. 147 int16_t worker_id_; 148 149 // Whether or not the container has been initialized appropriately. 150 bool initialized_; 151 152 // Sharded autograd context map. 153 std::vector<ContextsShard> autograd_contexts_; 154 155 // Number of shards for the sharded autograd_contexts_ map. 156 uint32_t num_shards_; 157 158 // Autograd message id to identify unique send/recv autograd function pairs. 159 std::atomic<int64_t> next_autograd_message_id_; 160 161 // Maximum allowed value for autograd_context_id or autograd_message_id. 162 int64_t max_id_; 163 }; 164 165 } // namespace autograd 166 } // namespace distributed 167 } // namespace torch 168