xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/autograd/context/container.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <mutex>
4 #include <unordered_map>
5 
6 #include <torch/csrc/distributed/autograd/context/context.h>
7 
8 namespace torch {
9 namespace distributed {
10 namespace autograd {
11 
12 // Singleton class per worker which is responsible for storing the distributed
13 // autograd context for each autograd pass and also cleans up data for an
14 // autograd pass once its done.
15 //
16 // Each autograd pass is assigned a unique autograd_context_id and all data for
17 // that pass (DistAutogradContext) is stored in this container indexed by the
18 // autograd_context_id. The autograd_context_id itself is a 64 bit globally
19 // unique id. The first 16 bits is the worker_id and the next 48 bits is an
20 // auto-incrementing id for each worker.
21 //
22 // This container is also responsible for maintaining a globally unique message
23 // id, which is used to associate send/recv autograd function pairs. The format
24 // is similar to the autograd_context_id where we have a 64 bit integer with
25 // first 16 bits being the worker id and next 48 bits are auto-incrementing.
26 class TORCH_API DistAutogradContainer {
27  public:
28   explicit DistAutogradContainer(uint32_t num_shards);
29 
30   // One time initialization of the container.
31   static DistAutogradContainer& init(int64_t worker_id);
32 
33   // Retrieve the singleton instance of the container, ensures we have
34   // initialized the container.
35   static DistAutogradContainer& getInstance();
36 
37   // Create a new context for a distributed autograd pass.
38   const ContextPtr newContext();
39 
40   // Clean up resources for a given context_id once the autograd pass is done.
41   // Sends RPC to other workers this worker knows about, telling them to clean
42   // up their context as well. Throws an exception if the context_id does not
43   // exist.
44   void releaseContext(int64_t context_id);
45 
46   // Releases an autograd context if it is present on this node. Also sends RPC
47   // to other workers this worker knows about, telling them to clean up their
48   // context. Does nothing if it is not present.
49   void releaseContextIfPresent(int64_t context_id);
50 
51   // Checks if the passed in context_id is valid.
52   void isValidContext(int64_t context_id);
53 
54   // Retrieve the autograd context for a given context_id.
55   ContextPtr retrieveContext(int64_t context_id);
56 
57   // Retrieves the currently active autograd context for the current thread.
58   ContextPtr currentContext();
59 
60   // Checks whether or not the current thread has a valid autograd context.
61   bool hasValidContext() const;
62 
63   // Generate a new autograd_message_id for send/recv autograd functions.
64   int64_t newAutogradMessageId();
65 
66   // Creates a new autograd context with the provided context_id. If a context
67   // already exists with the provided context_id, we just return it.
68   // This does not set the current context for the current thread.
69   ContextPtr getOrCreateContext(int64_t context_id);
70 
71   // Retrieves the maximum possible autograd_context_id/autograd_message_id that
72   // can be generated by this worker.
73   int64_t getMaxId();
74 
75   // Retrieves the worker ID for this node
76   rpc::worker_id_t getWorkerId() const;
77 
78   // Can set current context id if there is no valid context yet
79   static void setCurrentContextId(int64_t contextId);
80 
81   // Forcibly sets the thread local current context id. Should only be used in
82   // cases where you know what you're doing and need to override the thread
83   // local. Otherwise, use setCurrentContextId instead.
84   static void forceCurrentContextId(int64_t contextId);
85 
86   // Clear current context id
87   void clearCurrentContext();
88 
89   // Returns the number of autograd contexts in the container.
90   size_t numAutogradContexts() const;
91 
92   // Returns the current thread local context id for this thread.
93   static int64_t currentContextId();
94 
95   DistAutogradContainer(const DistAutogradContainer&) = delete;
96   DistAutogradContainer& operator=(const DistAutogradContainer&) = delete;
97   DistAutogradContainer(DistAutogradContainer&&) = delete;
98   DistAutogradContainer& operator=(DistAutogradContainer&&) = delete;
99 
100  private:
101   // Number of shards for the map storing autograd contexts. We'd like this
102   // to be a power of 2 and we don't expect a value much higher than the
103   // number of cores would provide much benefit.
104   static constexpr uint32_t kNumDefaultShards = 128;
105 
106   // Use cache line size for alignment.
107   static constexpr int kCacheLineSize = 64;
108 
109   // Structure holding one shard of the sharded autograd context map with its
110   // associated lock. Align to cache line size to avoid contention between
111   // adjacent entries.
112   struct alignas(kCacheLineSize) ContextsShard {
113     // Lock for this shard.
114     mutable std::mutex lock;
115 
116     // Map storing autograd contexts for this shard.
117     std::unordered_map<int64_t, ContextPtr> contexts;
118   };
119 
120   DistAutogradContainer() = delete;
121   ~DistAutogradContainer() = default;
122 
123   static DistAutogradContainer& getInstanceInternal();
124 
125   // Retrieve the shard for given context_id.
126   ContextsShard& getShard(int64_t context_id);
127 
128   // Sends an RPC to the workers that have a context corresponding to passed in
129   // context_id. This function should be called with the lock.
130   void sendReleaseContextRpc(
131       const std::unordered_set<rpc::worker_id_t>& workerIds,
132       int64_t context_id);
133 
134   // Erase context_id from the autograd context map, and reset the thread local
135   // current context id if it corresponds to the passed in context id. This
136   // function should be called with the lock.
137   void eraseContextIdAndReset(ContextsShard& shard, int64_t context_id);
138 
139   // Compute the number of shards for the autograd_contexts_ map.
140   static uint32_t computeNumShards();
141 
142   // Auto incrementing context id used to identify unique autograd passes.
143   // Initialized with the first 16 bits being the worker_id.
144   std::atomic<int64_t> next_context_id_;
145 
146   // Unique id to identify a worker in the distributed setting.
147   int16_t worker_id_;
148 
149   // Whether or not the container has been initialized appropriately.
150   bool initialized_;
151 
152   // Sharded autograd context map.
153   std::vector<ContextsShard> autograd_contexts_;
154 
155   // Number of shards for the sharded autograd_contexts_ map.
156   uint32_t num_shards_;
157 
158   // Autograd message id to identify unique send/recv autograd function pairs.
159   std::atomic<int64_t> next_autograd_message_id_;
160 
161   // Maximum allowed value for autograd_context_id or autograd_message_id.
162   int64_t max_id_;
163 };
164 
165 } // namespace autograd
166 } // namespace distributed
167 } // namespace torch
168