xref: /aosp_15_r20/external/pytorch/c10/cuda/CUDAAllocatorConfig.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAAllocatorConfig.h>
2*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDACachingAllocator.h>
3*da0073e9SAndroid Build Coastguard Worker #include <c10/util/llvmMathExtras.h>
4*da0073e9SAndroid Build Coastguard Worker 
5*da0073e9SAndroid Build Coastguard Worker #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
6*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/driver_api.h>
7*da0073e9SAndroid Build Coastguard Worker #endif
8*da0073e9SAndroid Build Coastguard Worker 
9*da0073e9SAndroid Build Coastguard Worker namespace c10::cuda::CUDACachingAllocator {
10*da0073e9SAndroid Build Coastguard Worker 
// Number of power-of-two size buckets in m_roundup_power2_divisions: one
// per interval from 1MB (2^20) up to 64GB (2^36), i.e. 36 - 20 = 16.
constexpr size_t kRoundUpPowerOfTwoIntervals = 16;
12*da0073e9SAndroid Build Coastguard Worker 
// Initializes every option to its default:
//  - no split-size cap (SIZE_MAX disables it)
//  - garbage collection disabled (threshold 0)
//  - a single pinned-memory registration thread
//  - expandable segments, release-lock-on-cudaMalloc, and
//    cudaHostRegister for pinned memory all off
CUDAAllocatorConfig::CUDAAllocatorConfig()
    : m_max_split_size(std::numeric_limits<size_t>::max()),
      m_garbage_collection_threshold(0),
      m_pinned_num_register_threads(1),
      m_expandable_segments(false),
      m_release_lock_on_cudamalloc(false),
      m_pinned_use_cuda_host_register(false),
      m_last_allocator_settings("") {
  // One entry per power-of-two interval (1MB..64GB), all zero:
  // round-up is disabled by default.
  m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0);
}
23*da0073e9SAndroid Build Coastguard Worker 
roundup_power2_divisions(size_t size)24*da0073e9SAndroid Build Coastguard Worker size_t CUDAAllocatorConfig::roundup_power2_divisions(size_t size) {
25*da0073e9SAndroid Build Coastguard Worker   size_t log_size = (63 - llvm::countLeadingZeros(size));
26*da0073e9SAndroid Build Coastguard Worker 
27*da0073e9SAndroid Build Coastguard Worker   // Our intervals start at 1MB and end at 64GB
28*da0073e9SAndroid Build Coastguard Worker   const size_t interval_start =
29*da0073e9SAndroid Build Coastguard Worker       63 - llvm::countLeadingZeros(static_cast<size_t>(1048576));
30*da0073e9SAndroid Build Coastguard Worker   const size_t interval_end =
31*da0073e9SAndroid Build Coastguard Worker       63 - llvm::countLeadingZeros(static_cast<size_t>(68719476736));
32*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
33*da0073e9SAndroid Build Coastguard Worker       (interval_end - interval_start == kRoundUpPowerOfTwoIntervals),
34*da0073e9SAndroid Build Coastguard Worker       "kRoundUpPowerOfTwoIntervals mismatch");
35*da0073e9SAndroid Build Coastguard Worker 
36*da0073e9SAndroid Build Coastguard Worker   int index = static_cast<int>(log_size) - static_cast<int>(interval_start);
37*da0073e9SAndroid Build Coastguard Worker 
38*da0073e9SAndroid Build Coastguard Worker   index = std::max(0, index);
39*da0073e9SAndroid Build Coastguard Worker   index = std::min(index, static_cast<int>(kRoundUpPowerOfTwoIntervals) - 1);
40*da0073e9SAndroid Build Coastguard Worker   return instance().m_roundup_power2_divisions[index];
41*da0073e9SAndroid Build Coastguard Worker }
42*da0073e9SAndroid Build Coastguard Worker 
lexArgs(const char * env,std::vector<std::string> & config)43*da0073e9SAndroid Build Coastguard Worker void CUDAAllocatorConfig::lexArgs(
44*da0073e9SAndroid Build Coastguard Worker     const char* env,
45*da0073e9SAndroid Build Coastguard Worker     std::vector<std::string>& config) {
46*da0073e9SAndroid Build Coastguard Worker   std::vector<char> buf;
47*da0073e9SAndroid Build Coastguard Worker 
48*da0073e9SAndroid Build Coastguard Worker   size_t env_length = strlen(env);
49*da0073e9SAndroid Build Coastguard Worker   for (size_t i = 0; i < env_length; i++) {
50*da0073e9SAndroid Build Coastguard Worker     if (env[i] == ',' || env[i] == ':' || env[i] == '[' || env[i] == ']') {
51*da0073e9SAndroid Build Coastguard Worker       if (!buf.empty()) {
52*da0073e9SAndroid Build Coastguard Worker         config.emplace_back(buf.begin(), buf.end());
53*da0073e9SAndroid Build Coastguard Worker         buf.clear();
54*da0073e9SAndroid Build Coastguard Worker       }
55*da0073e9SAndroid Build Coastguard Worker       config.emplace_back(1, env[i]);
56*da0073e9SAndroid Build Coastguard Worker     } else if (env[i] != ' ') {
57*da0073e9SAndroid Build Coastguard Worker       buf.emplace_back(static_cast<char>(env[i]));
58*da0073e9SAndroid Build Coastguard Worker     }
59*da0073e9SAndroid Build Coastguard Worker   }
60*da0073e9SAndroid Build Coastguard Worker   if (!buf.empty()) {
61*da0073e9SAndroid Build Coastguard Worker     config.emplace_back(buf.begin(), buf.end());
62*da0073e9SAndroid Build Coastguard Worker   }
63*da0073e9SAndroid Build Coastguard Worker }
64*da0073e9SAndroid Build Coastguard Worker 
consumeToken(const std::vector<std::string> & config,size_t i,const char c)65*da0073e9SAndroid Build Coastguard Worker void CUDAAllocatorConfig::consumeToken(
66*da0073e9SAndroid Build Coastguard Worker     const std::vector<std::string>& config,
67*da0073e9SAndroid Build Coastguard Worker     size_t i,
68*da0073e9SAndroid Build Coastguard Worker     const char c) {
69*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
70*da0073e9SAndroid Build Coastguard Worker       i < config.size() && config[i] == std::string(1, c),
71*da0073e9SAndroid Build Coastguard Worker       "Error parsing CachingAllocator settings, expected ",
72*da0073e9SAndroid Build Coastguard Worker       c,
73*da0073e9SAndroid Build Coastguard Worker       "");
74*da0073e9SAndroid Build Coastguard Worker }
75*da0073e9SAndroid Build Coastguard Worker 
parseMaxSplitSize(const std::vector<std::string> & config,size_t i)76*da0073e9SAndroid Build Coastguard Worker size_t CUDAAllocatorConfig::parseMaxSplitSize(
77*da0073e9SAndroid Build Coastguard Worker     const std::vector<std::string>& config,
78*da0073e9SAndroid Build Coastguard Worker     size_t i) {
79*da0073e9SAndroid Build Coastguard Worker   consumeToken(config, ++i, ':');
80*da0073e9SAndroid Build Coastguard Worker   constexpr int mb = 1024 * 1024;
81*da0073e9SAndroid Build Coastguard Worker   if (++i < config.size()) {
82*da0073e9SAndroid Build Coastguard Worker     size_t val1 = stoi(config[i]);
83*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(
84*da0073e9SAndroid Build Coastguard Worker         val1 > kLargeBuffer / mb,
85*da0073e9SAndroid Build Coastguard Worker         "CachingAllocator option max_split_size_mb too small, must be > ",
86*da0073e9SAndroid Build Coastguard Worker         kLargeBuffer / mb,
87*da0073e9SAndroid Build Coastguard Worker         "");
88*da0073e9SAndroid Build Coastguard Worker     val1 = std::max(val1, kLargeBuffer / mb);
89*da0073e9SAndroid Build Coastguard Worker     val1 = std::min(val1, (std::numeric_limits<size_t>::max() / mb));
90*da0073e9SAndroid Build Coastguard Worker     m_max_split_size = val1 * 1024 * 1024;
91*da0073e9SAndroid Build Coastguard Worker   } else {
92*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(false, "Error, expecting max_split_size_mb value", "");
93*da0073e9SAndroid Build Coastguard Worker   }
94*da0073e9SAndroid Build Coastguard Worker   return i;
95*da0073e9SAndroid Build Coastguard Worker }
96*da0073e9SAndroid Build Coastguard Worker 
parseGarbageCollectionThreshold(const std::vector<std::string> & config,size_t i)97*da0073e9SAndroid Build Coastguard Worker size_t CUDAAllocatorConfig::parseGarbageCollectionThreshold(
98*da0073e9SAndroid Build Coastguard Worker     const std::vector<std::string>& config,
99*da0073e9SAndroid Build Coastguard Worker     size_t i) {
100*da0073e9SAndroid Build Coastguard Worker   consumeToken(config, ++i, ':');
101*da0073e9SAndroid Build Coastguard Worker   if (++i < config.size()) {
102*da0073e9SAndroid Build Coastguard Worker     double val1 = stod(config[i]);
103*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(
104*da0073e9SAndroid Build Coastguard Worker         val1 > 0, "garbage_collect_threshold too small, set it 0.0~1.0", "");
105*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(
106*da0073e9SAndroid Build Coastguard Worker         val1 < 1.0, "garbage_collect_threshold too big, set it 0.0~1.0", "");
107*da0073e9SAndroid Build Coastguard Worker     m_garbage_collection_threshold = val1;
108*da0073e9SAndroid Build Coastguard Worker   } else {
109*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(
110*da0073e9SAndroid Build Coastguard Worker         false, "Error, expecting garbage_collection_threshold value", "");
111*da0073e9SAndroid Build Coastguard Worker   }
112*da0073e9SAndroid Build Coastguard Worker   return i;
113*da0073e9SAndroid Build Coastguard Worker }
114*da0073e9SAndroid Build Coastguard Worker 
parseRoundUpPower2Divisions(const std::vector<std::string> & config,size_t i)115*da0073e9SAndroid Build Coastguard Worker size_t CUDAAllocatorConfig::parseRoundUpPower2Divisions(
116*da0073e9SAndroid Build Coastguard Worker     const std::vector<std::string>& config,
117*da0073e9SAndroid Build Coastguard Worker     size_t i) {
118*da0073e9SAndroid Build Coastguard Worker   consumeToken(config, ++i, ':');
119*da0073e9SAndroid Build Coastguard Worker   bool first_value = true;
120*da0073e9SAndroid Build Coastguard Worker 
121*da0073e9SAndroid Build Coastguard Worker   if (++i < config.size()) {
122*da0073e9SAndroid Build Coastguard Worker     if (std::string_view(config[i]) == "[") {
123*da0073e9SAndroid Build Coastguard Worker       size_t last_index = 0;
124*da0073e9SAndroid Build Coastguard Worker       while (++i < config.size() && std::string_view(config[i]) != "]") {
125*da0073e9SAndroid Build Coastguard Worker         const std::string& val1 = config[i];
126*da0073e9SAndroid Build Coastguard Worker         size_t val2 = 0;
127*da0073e9SAndroid Build Coastguard Worker 
128*da0073e9SAndroid Build Coastguard Worker         consumeToken(config, ++i, ':');
129*da0073e9SAndroid Build Coastguard Worker         if (++i < config.size()) {
130*da0073e9SAndroid Build Coastguard Worker           val2 = stoi(config[i]);
131*da0073e9SAndroid Build Coastguard Worker         } else {
132*da0073e9SAndroid Build Coastguard Worker           TORCH_CHECK(
133*da0073e9SAndroid Build Coastguard Worker               false, "Error parsing roundup_power2_divisions value", "");
134*da0073e9SAndroid Build Coastguard Worker         }
135*da0073e9SAndroid Build Coastguard Worker         TORCH_CHECK(
136*da0073e9SAndroid Build Coastguard Worker             val2 == 0 || llvm::isPowerOf2_64(val2),
137*da0073e9SAndroid Build Coastguard Worker             "For roundups, the divisons has to be power of 2 or 0 to disable roundup ",
138*da0073e9SAndroid Build Coastguard Worker             "");
139*da0073e9SAndroid Build Coastguard Worker 
140*da0073e9SAndroid Build Coastguard Worker         if (std::string_view(val1) == ">") {
141*da0073e9SAndroid Build Coastguard Worker           std::fill(
142*da0073e9SAndroid Build Coastguard Worker               std::next(
143*da0073e9SAndroid Build Coastguard Worker                   m_roundup_power2_divisions.begin(),
144*da0073e9SAndroid Build Coastguard Worker                   static_cast<std::vector<unsigned long>::difference_type>(
145*da0073e9SAndroid Build Coastguard Worker                       last_index)),
146*da0073e9SAndroid Build Coastguard Worker               m_roundup_power2_divisions.end(),
147*da0073e9SAndroid Build Coastguard Worker               val2);
148*da0073e9SAndroid Build Coastguard Worker         } else {
149*da0073e9SAndroid Build Coastguard Worker           size_t val1_long = stoul(val1);
150*da0073e9SAndroid Build Coastguard Worker           TORCH_CHECK(
151*da0073e9SAndroid Build Coastguard Worker               llvm::isPowerOf2_64(val1_long),
152*da0073e9SAndroid Build Coastguard Worker               "For roundups, the intervals have to be power of 2 ",
153*da0073e9SAndroid Build Coastguard Worker               "");
154*da0073e9SAndroid Build Coastguard Worker 
155*da0073e9SAndroid Build Coastguard Worker           size_t index = 63 - llvm::countLeadingZeros(val1_long);
156*da0073e9SAndroid Build Coastguard Worker           index = std::max((size_t)0, index);
157*da0073e9SAndroid Build Coastguard Worker           index = std::min(index, m_roundup_power2_divisions.size() - 1);
158*da0073e9SAndroid Build Coastguard Worker 
159*da0073e9SAndroid Build Coastguard Worker           if (first_value) {
160*da0073e9SAndroid Build Coastguard Worker             std::fill(
161*da0073e9SAndroid Build Coastguard Worker                 m_roundup_power2_divisions.begin(),
162*da0073e9SAndroid Build Coastguard Worker                 std::next(
163*da0073e9SAndroid Build Coastguard Worker                     m_roundup_power2_divisions.begin(),
164*da0073e9SAndroid Build Coastguard Worker                     static_cast<std::vector<unsigned long>::difference_type>(
165*da0073e9SAndroid Build Coastguard Worker                         index)),
166*da0073e9SAndroid Build Coastguard Worker                 val2);
167*da0073e9SAndroid Build Coastguard Worker             first_value = false;
168*da0073e9SAndroid Build Coastguard Worker           }
169*da0073e9SAndroid Build Coastguard Worker           if (index < m_roundup_power2_divisions.size()) {
170*da0073e9SAndroid Build Coastguard Worker             m_roundup_power2_divisions[index] = val2;
171*da0073e9SAndroid Build Coastguard Worker           }
172*da0073e9SAndroid Build Coastguard Worker           last_index = index;
173*da0073e9SAndroid Build Coastguard Worker         }
174*da0073e9SAndroid Build Coastguard Worker 
175*da0073e9SAndroid Build Coastguard Worker         if (std::string_view(config[i + 1]) != "]") {
176*da0073e9SAndroid Build Coastguard Worker           consumeToken(config, ++i, ',');
177*da0073e9SAndroid Build Coastguard Worker         }
178*da0073e9SAndroid Build Coastguard Worker       }
179*da0073e9SAndroid Build Coastguard Worker     } else { // Keep this for backwards compatibility
180*da0073e9SAndroid Build Coastguard Worker       size_t val1 = stoi(config[i]);
181*da0073e9SAndroid Build Coastguard Worker       TORCH_CHECK(
182*da0073e9SAndroid Build Coastguard Worker           llvm::isPowerOf2_64(val1),
183*da0073e9SAndroid Build Coastguard Worker           "For roundups, the divisons has to be power of 2 ",
184*da0073e9SAndroid Build Coastguard Worker           "");
185*da0073e9SAndroid Build Coastguard Worker       std::fill(
186*da0073e9SAndroid Build Coastguard Worker           m_roundup_power2_divisions.begin(),
187*da0073e9SAndroid Build Coastguard Worker           m_roundup_power2_divisions.end(),
188*da0073e9SAndroid Build Coastguard Worker           val1);
189*da0073e9SAndroid Build Coastguard Worker     }
190*da0073e9SAndroid Build Coastguard Worker   } else {
191*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(false, "Error, expecting roundup_power2_divisions value", "");
192*da0073e9SAndroid Build Coastguard Worker   }
193*da0073e9SAndroid Build Coastguard Worker   return i;
194*da0073e9SAndroid Build Coastguard Worker }
195*da0073e9SAndroid Build Coastguard Worker 
// Parses "backend:<native|cudaMallocAsync>". Sets used_cudaMallocAsync
// and returns the index of the last token consumed.
// NOTE(review): the backend selection itself appears to happen at load
// time elsewhere; this path only validates that the runtime-parsed value
// matches (see the TORCH_INTERNAL_ASSERT below) — confirm against the
// allocator registration code.
size_t CUDAAllocatorConfig::parseAllocatorConfig(
    const std::vector<std::string>& config,
    size_t i,
    bool& used_cudaMallocAsync) {
  consumeToken(config, ++i, ':');
  if (++i < config.size()) {
    TORCH_CHECK(
        ((config[i] == "native") || (config[i] == "cudaMallocAsync")),
        "Unknown allocator backend, "
        "options are native and cudaMallocAsync");
    used_cudaMallocAsync = (config[i] == "cudaMallocAsync");
#ifndef USE_ROCM
    // HIP supports hipMallocAsync and does not need to check versions
    if (used_cudaMallocAsync) {
#if CUDA_VERSION >= 11040
      // Built against CUDA >= 11.4: still verify the installed driver
      // supports cudaMallocAsync at runtime.
      int version = 0;
      C10_CUDA_CHECK(cudaDriverGetVersion(&version));
      TORCH_CHECK(
          version >= 11040,
          "backend:cudaMallocAsync requires CUDA runtime "
          "11.4 or newer, but cudaDriverGetVersion returned ",
          version);
#else
      // Built against an older CUDA: the backend is unavailable at all.
      TORCH_CHECK(
          false,
          "backend:cudaMallocAsync requires PyTorch to be built with "
          "CUDA 11.4 or newer, but CUDA_VERSION is ",
          CUDA_VERSION);
#endif
    }
#endif
    TORCH_INTERNAL_ASSERT(
        config[i] == get()->name(),
        "Allocator backend parsed at runtime != "
        "allocator backend parsed at load time");
  } else {
    TORCH_CHECK(false, "Error parsing backend value", "");
  }
  return i;
}
236*da0073e9SAndroid Build Coastguard Worker 
// Parses an allocator-configuration string (typically the
// PYTORCH_CUDA_ALLOC_CONF environment variable): a comma-separated list
// of option:value pairs. Native-allocator options are reset to their
// defaults first, so a null env restores default behavior. Unrecognized
// options are a hard error.
void CUDAAllocatorConfig::parseArgs(const char* env) {
  // If empty, set the default values
  m_max_split_size = std::numeric_limits<size_t>::max();
  m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0);
  m_garbage_collection_threshold = 0;
  bool used_cudaMallocAsync = false;
  bool used_native_specific_option = false;

  if (env == nullptr) {
    return;
  }
  {
    // Record the raw settings string for later retrieval.
    std::lock_guard<std::mutex> lock(m_last_allocator_settings_mutex);
    m_last_allocator_settings = env;
  }

  std::vector<std::string> config;
  lexArgs(env, config);

  for (size_t i = 0; i < config.size(); i++) {
    std::string_view config_item_view(config[i]);
    if (config_item_view == "max_split_size_mb") {
      i = parseMaxSplitSize(config, i);
      used_native_specific_option = true;
    } else if (config_item_view == "garbage_collection_threshold") {
      i = parseGarbageCollectionThreshold(config, i);
      used_native_specific_option = true;
    } else if (config_item_view == "roundup_power2_divisions") {
      i = parseRoundUpPower2Divisions(config, i);
      used_native_specific_option = true;
    } else if (config_item_view == "backend") {
      i = parseAllocatorConfig(config, i, used_cudaMallocAsync);
    } else if (config_item_view == "expandable_segments") {
      used_native_specific_option = true;
      consumeToken(config, ++i, ':');
      ++i;
      TORCH_CHECK(
          i < config.size() &&
              (std::string_view(config[i]) == "True" ||
               std::string_view(config[i]) == "False"),
          "Expected a single True/False argument for expandable_segments");
      config_item_view = config[i];
      m_expandable_segments = (config_item_view == "True");
    } else if (
        // ROCm build's hipify step will change "cuda" to "hip", but for ease of
        // use, accept both. We must break up the string to prevent hipify here.
        config_item_view == "release_lock_on_hipmalloc" ||
        config_item_view ==
            "release_lock_on_c"
            "udamalloc") {
      used_native_specific_option = true;
      consumeToken(config, ++i, ':');
      ++i;
      TORCH_CHECK(
          i < config.size() &&
              (std::string_view(config[i]) == "True" ||
               std::string_view(config[i]) == "False"),
          "Expected a single True/False argument for release_lock_on_cudamalloc");
      config_item_view = config[i];
      m_release_lock_on_cudamalloc = (config_item_view == "True");
    } else if (
        // ROCm build's hipify step will change "cuda" to "hip", but for ease of
        // use, accept both. We must break up the string to prevent hipify here.
        config_item_view == "pinned_use_hip_host_register" ||
        config_item_view ==
            "pinned_use_c"
            "uda_host_register") {
      i = parsePinnedUseCudaHostRegister(config, i);
      used_native_specific_option = true;
    } else if (config_item_view == "pinned_num_register_threads") {
      i = parsePinnedNumRegisterThreads(config, i);
      used_native_specific_option = true;
    } else {
      TORCH_CHECK(
          false, "Unrecognized CachingAllocator option: ", config_item_view);
    }

    // Options are comma-separated; consume the separator if more tokens
    // remain.
    if (i + 1 < config.size()) {
      consumeToken(config, ++i, ',');
    }
  }

  // The async backend ignores the native allocator's tuning knobs; warn
  // rather than fail so mixed configs still work.
  if (used_cudaMallocAsync && used_native_specific_option) {
    TORCH_WARN(
        "backend:cudaMallocAsync ignores max_split_size_mb,"
        "roundup_power2_divisions, and garbage_collect_threshold.");
  }
}
325*da0073e9SAndroid Build Coastguard Worker 
parsePinnedUseCudaHostRegister(const std::vector<std::string> & config,size_t i)326*da0073e9SAndroid Build Coastguard Worker size_t CUDAAllocatorConfig::parsePinnedUseCudaHostRegister(
327*da0073e9SAndroid Build Coastguard Worker     const std::vector<std::string>& config,
328*da0073e9SAndroid Build Coastguard Worker     size_t i) {
329*da0073e9SAndroid Build Coastguard Worker   consumeToken(config, ++i, ':');
330*da0073e9SAndroid Build Coastguard Worker   if (++i < config.size()) {
331*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(
332*da0073e9SAndroid Build Coastguard Worker         (config[i] == "True" || config[i] == "False"),
333*da0073e9SAndroid Build Coastguard Worker         "Expected a single True/False argument for pinned_use_cuda_host_register");
334*da0073e9SAndroid Build Coastguard Worker     m_pinned_use_cuda_host_register = (config[i] == "True");
335*da0073e9SAndroid Build Coastguard Worker   } else {
336*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(
337*da0073e9SAndroid Build Coastguard Worker         false, "Error, expecting pinned_use_cuda_host_register value", "");
338*da0073e9SAndroid Build Coastguard Worker   }
339*da0073e9SAndroid Build Coastguard Worker   return i;
340*da0073e9SAndroid Build Coastguard Worker }
341*da0073e9SAndroid Build Coastguard Worker 
parsePinnedNumRegisterThreads(const std::vector<std::string> & config,size_t i)342*da0073e9SAndroid Build Coastguard Worker size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads(
343*da0073e9SAndroid Build Coastguard Worker     const std::vector<std::string>& config,
344*da0073e9SAndroid Build Coastguard Worker     size_t i) {
345*da0073e9SAndroid Build Coastguard Worker   consumeToken(config, ++i, ':');
346*da0073e9SAndroid Build Coastguard Worker   if (++i < config.size()) {
347*da0073e9SAndroid Build Coastguard Worker     size_t val2 = stoi(config[i]);
348*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(
349*da0073e9SAndroid Build Coastguard Worker         llvm::isPowerOf2_64(val2),
350*da0073e9SAndroid Build Coastguard Worker         "Number of register threads has to be power of 2 ",
351*da0073e9SAndroid Build Coastguard Worker         "");
352*da0073e9SAndroid Build Coastguard Worker     auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads();
353*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(
354*da0073e9SAndroid Build Coastguard Worker         val2 <= maxThreads,
355*da0073e9SAndroid Build Coastguard Worker         "Number of register threads should be less than or equal to " +
356*da0073e9SAndroid Build Coastguard Worker             std::to_string(maxThreads),
357*da0073e9SAndroid Build Coastguard Worker         "");
358*da0073e9SAndroid Build Coastguard Worker     m_pinned_num_register_threads = val2;
359*da0073e9SAndroid Build Coastguard Worker   } else {
360*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(
361*da0073e9SAndroid Build Coastguard Worker         false, "Error, expecting pinned_num_register_threads value", "");
362*da0073e9SAndroid Build Coastguard Worker   }
363*da0073e9SAndroid Build Coastguard Worker   return i;
364*da0073e9SAndroid Build Coastguard Worker }
365*da0073e9SAndroid Build Coastguard Worker 
// General caching allocator utilities

// Applies an allocator-configuration string programmatically, using the
// same parser as the PYTORCH_CUDA_ALLOC_CONF environment variable.
void setAllocatorSettings(const std::string& env) {
  CUDACachingAllocator::CUDAAllocatorConfig::instance().parseArgs(env.c_str());
}
370*da0073e9SAndroid Build Coastguard Worker 
371*da0073e9SAndroid Build Coastguard Worker } // namespace c10::cuda::CUDACachingAllocator
372