1*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAAllocatorConfig.h>
2*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDACachingAllocator.h>
3*da0073e9SAndroid Build Coastguard Worker #include <c10/util/llvmMathExtras.h>
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Worker #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
6*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/driver_api.h>
7*da0073e9SAndroid Build Coastguard Worker #endif
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Worker namespace c10::cuda::CUDACachingAllocator {
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker constexpr size_t kRoundUpPowerOfTwoIntervals = 16;
12*da0073e9SAndroid Build Coastguard Worker
CUDAAllocatorConfig()13*da0073e9SAndroid Build Coastguard Worker CUDAAllocatorConfig::CUDAAllocatorConfig()
14*da0073e9SAndroid Build Coastguard Worker : m_max_split_size(std::numeric_limits<size_t>::max()),
15*da0073e9SAndroid Build Coastguard Worker m_garbage_collection_threshold(0),
16*da0073e9SAndroid Build Coastguard Worker m_pinned_num_register_threads(1),
17*da0073e9SAndroid Build Coastguard Worker m_expandable_segments(false),
18*da0073e9SAndroid Build Coastguard Worker m_release_lock_on_cudamalloc(false),
19*da0073e9SAndroid Build Coastguard Worker m_pinned_use_cuda_host_register(false),
20*da0073e9SAndroid Build Coastguard Worker m_last_allocator_settings("") {
21*da0073e9SAndroid Build Coastguard Worker m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0);
22*da0073e9SAndroid Build Coastguard Worker }
23*da0073e9SAndroid Build Coastguard Worker
roundup_power2_divisions(size_t size)24*da0073e9SAndroid Build Coastguard Worker size_t CUDAAllocatorConfig::roundup_power2_divisions(size_t size) {
25*da0073e9SAndroid Build Coastguard Worker size_t log_size = (63 - llvm::countLeadingZeros(size));
26*da0073e9SAndroid Build Coastguard Worker
27*da0073e9SAndroid Build Coastguard Worker // Our intervals start at 1MB and end at 64GB
28*da0073e9SAndroid Build Coastguard Worker const size_t interval_start =
29*da0073e9SAndroid Build Coastguard Worker 63 - llvm::countLeadingZeros(static_cast<size_t>(1048576));
30*da0073e9SAndroid Build Coastguard Worker const size_t interval_end =
31*da0073e9SAndroid Build Coastguard Worker 63 - llvm::countLeadingZeros(static_cast<size_t>(68719476736));
32*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
33*da0073e9SAndroid Build Coastguard Worker (interval_end - interval_start == kRoundUpPowerOfTwoIntervals),
34*da0073e9SAndroid Build Coastguard Worker "kRoundUpPowerOfTwoIntervals mismatch");
35*da0073e9SAndroid Build Coastguard Worker
36*da0073e9SAndroid Build Coastguard Worker int index = static_cast<int>(log_size) - static_cast<int>(interval_start);
37*da0073e9SAndroid Build Coastguard Worker
38*da0073e9SAndroid Build Coastguard Worker index = std::max(0, index);
39*da0073e9SAndroid Build Coastguard Worker index = std::min(index, static_cast<int>(kRoundUpPowerOfTwoIntervals) - 1);
40*da0073e9SAndroid Build Coastguard Worker return instance().m_roundup_power2_divisions[index];
41*da0073e9SAndroid Build Coastguard Worker }
42*da0073e9SAndroid Build Coastguard Worker
lexArgs(const char * env,std::vector<std::string> & config)43*da0073e9SAndroid Build Coastguard Worker void CUDAAllocatorConfig::lexArgs(
44*da0073e9SAndroid Build Coastguard Worker const char* env,
45*da0073e9SAndroid Build Coastguard Worker std::vector<std::string>& config) {
46*da0073e9SAndroid Build Coastguard Worker std::vector<char> buf;
47*da0073e9SAndroid Build Coastguard Worker
48*da0073e9SAndroid Build Coastguard Worker size_t env_length = strlen(env);
49*da0073e9SAndroid Build Coastguard Worker for (size_t i = 0; i < env_length; i++) {
50*da0073e9SAndroid Build Coastguard Worker if (env[i] == ',' || env[i] == ':' || env[i] == '[' || env[i] == ']') {
51*da0073e9SAndroid Build Coastguard Worker if (!buf.empty()) {
52*da0073e9SAndroid Build Coastguard Worker config.emplace_back(buf.begin(), buf.end());
53*da0073e9SAndroid Build Coastguard Worker buf.clear();
54*da0073e9SAndroid Build Coastguard Worker }
55*da0073e9SAndroid Build Coastguard Worker config.emplace_back(1, env[i]);
56*da0073e9SAndroid Build Coastguard Worker } else if (env[i] != ' ') {
57*da0073e9SAndroid Build Coastguard Worker buf.emplace_back(static_cast<char>(env[i]));
58*da0073e9SAndroid Build Coastguard Worker }
59*da0073e9SAndroid Build Coastguard Worker }
60*da0073e9SAndroid Build Coastguard Worker if (!buf.empty()) {
61*da0073e9SAndroid Build Coastguard Worker config.emplace_back(buf.begin(), buf.end());
62*da0073e9SAndroid Build Coastguard Worker }
63*da0073e9SAndroid Build Coastguard Worker }
64*da0073e9SAndroid Build Coastguard Worker
consumeToken(const std::vector<std::string> & config,size_t i,const char c)65*da0073e9SAndroid Build Coastguard Worker void CUDAAllocatorConfig::consumeToken(
66*da0073e9SAndroid Build Coastguard Worker const std::vector<std::string>& config,
67*da0073e9SAndroid Build Coastguard Worker size_t i,
68*da0073e9SAndroid Build Coastguard Worker const char c) {
69*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
70*da0073e9SAndroid Build Coastguard Worker i < config.size() && config[i] == std::string(1, c),
71*da0073e9SAndroid Build Coastguard Worker "Error parsing CachingAllocator settings, expected ",
72*da0073e9SAndroid Build Coastguard Worker c,
73*da0073e9SAndroid Build Coastguard Worker "");
74*da0073e9SAndroid Build Coastguard Worker }
75*da0073e9SAndroid Build Coastguard Worker
parseMaxSplitSize(const std::vector<std::string> & config,size_t i)76*da0073e9SAndroid Build Coastguard Worker size_t CUDAAllocatorConfig::parseMaxSplitSize(
77*da0073e9SAndroid Build Coastguard Worker const std::vector<std::string>& config,
78*da0073e9SAndroid Build Coastguard Worker size_t i) {
79*da0073e9SAndroid Build Coastguard Worker consumeToken(config, ++i, ':');
80*da0073e9SAndroid Build Coastguard Worker constexpr int mb = 1024 * 1024;
81*da0073e9SAndroid Build Coastguard Worker if (++i < config.size()) {
82*da0073e9SAndroid Build Coastguard Worker size_t val1 = stoi(config[i]);
83*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
84*da0073e9SAndroid Build Coastguard Worker val1 > kLargeBuffer / mb,
85*da0073e9SAndroid Build Coastguard Worker "CachingAllocator option max_split_size_mb too small, must be > ",
86*da0073e9SAndroid Build Coastguard Worker kLargeBuffer / mb,
87*da0073e9SAndroid Build Coastguard Worker "");
88*da0073e9SAndroid Build Coastguard Worker val1 = std::max(val1, kLargeBuffer / mb);
89*da0073e9SAndroid Build Coastguard Worker val1 = std::min(val1, (std::numeric_limits<size_t>::max() / mb));
90*da0073e9SAndroid Build Coastguard Worker m_max_split_size = val1 * 1024 * 1024;
91*da0073e9SAndroid Build Coastguard Worker } else {
92*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(false, "Error, expecting max_split_size_mb value", "");
93*da0073e9SAndroid Build Coastguard Worker }
94*da0073e9SAndroid Build Coastguard Worker return i;
95*da0073e9SAndroid Build Coastguard Worker }
96*da0073e9SAndroid Build Coastguard Worker
parseGarbageCollectionThreshold(const std::vector<std::string> & config,size_t i)97*da0073e9SAndroid Build Coastguard Worker size_t CUDAAllocatorConfig::parseGarbageCollectionThreshold(
98*da0073e9SAndroid Build Coastguard Worker const std::vector<std::string>& config,
99*da0073e9SAndroid Build Coastguard Worker size_t i) {
100*da0073e9SAndroid Build Coastguard Worker consumeToken(config, ++i, ':');
101*da0073e9SAndroid Build Coastguard Worker if (++i < config.size()) {
102*da0073e9SAndroid Build Coastguard Worker double val1 = stod(config[i]);
103*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
104*da0073e9SAndroid Build Coastguard Worker val1 > 0, "garbage_collect_threshold too small, set it 0.0~1.0", "");
105*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
106*da0073e9SAndroid Build Coastguard Worker val1 < 1.0, "garbage_collect_threshold too big, set it 0.0~1.0", "");
107*da0073e9SAndroid Build Coastguard Worker m_garbage_collection_threshold = val1;
108*da0073e9SAndroid Build Coastguard Worker } else {
109*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
110*da0073e9SAndroid Build Coastguard Worker false, "Error, expecting garbage_collection_threshold value", "");
111*da0073e9SAndroid Build Coastguard Worker }
112*da0073e9SAndroid Build Coastguard Worker return i;
113*da0073e9SAndroid Build Coastguard Worker }
114*da0073e9SAndroid Build Coastguard Worker
parseRoundUpPower2Divisions(const std::vector<std::string> & config,size_t i)115*da0073e9SAndroid Build Coastguard Worker size_t CUDAAllocatorConfig::parseRoundUpPower2Divisions(
116*da0073e9SAndroid Build Coastguard Worker const std::vector<std::string>& config,
117*da0073e9SAndroid Build Coastguard Worker size_t i) {
118*da0073e9SAndroid Build Coastguard Worker consumeToken(config, ++i, ':');
119*da0073e9SAndroid Build Coastguard Worker bool first_value = true;
120*da0073e9SAndroid Build Coastguard Worker
121*da0073e9SAndroid Build Coastguard Worker if (++i < config.size()) {
122*da0073e9SAndroid Build Coastguard Worker if (std::string_view(config[i]) == "[") {
123*da0073e9SAndroid Build Coastguard Worker size_t last_index = 0;
124*da0073e9SAndroid Build Coastguard Worker while (++i < config.size() && std::string_view(config[i]) != "]") {
125*da0073e9SAndroid Build Coastguard Worker const std::string& val1 = config[i];
126*da0073e9SAndroid Build Coastguard Worker size_t val2 = 0;
127*da0073e9SAndroid Build Coastguard Worker
128*da0073e9SAndroid Build Coastguard Worker consumeToken(config, ++i, ':');
129*da0073e9SAndroid Build Coastguard Worker if (++i < config.size()) {
130*da0073e9SAndroid Build Coastguard Worker val2 = stoi(config[i]);
131*da0073e9SAndroid Build Coastguard Worker } else {
132*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
133*da0073e9SAndroid Build Coastguard Worker false, "Error parsing roundup_power2_divisions value", "");
134*da0073e9SAndroid Build Coastguard Worker }
135*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
136*da0073e9SAndroid Build Coastguard Worker val2 == 0 || llvm::isPowerOf2_64(val2),
137*da0073e9SAndroid Build Coastguard Worker "For roundups, the divisons has to be power of 2 or 0 to disable roundup ",
138*da0073e9SAndroid Build Coastguard Worker "");
139*da0073e9SAndroid Build Coastguard Worker
140*da0073e9SAndroid Build Coastguard Worker if (std::string_view(val1) == ">") {
141*da0073e9SAndroid Build Coastguard Worker std::fill(
142*da0073e9SAndroid Build Coastguard Worker std::next(
143*da0073e9SAndroid Build Coastguard Worker m_roundup_power2_divisions.begin(),
144*da0073e9SAndroid Build Coastguard Worker static_cast<std::vector<unsigned long>::difference_type>(
145*da0073e9SAndroid Build Coastguard Worker last_index)),
146*da0073e9SAndroid Build Coastguard Worker m_roundup_power2_divisions.end(),
147*da0073e9SAndroid Build Coastguard Worker val2);
148*da0073e9SAndroid Build Coastguard Worker } else {
149*da0073e9SAndroid Build Coastguard Worker size_t val1_long = stoul(val1);
150*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
151*da0073e9SAndroid Build Coastguard Worker llvm::isPowerOf2_64(val1_long),
152*da0073e9SAndroid Build Coastguard Worker "For roundups, the intervals have to be power of 2 ",
153*da0073e9SAndroid Build Coastguard Worker "");
154*da0073e9SAndroid Build Coastguard Worker
155*da0073e9SAndroid Build Coastguard Worker size_t index = 63 - llvm::countLeadingZeros(val1_long);
156*da0073e9SAndroid Build Coastguard Worker index = std::max((size_t)0, index);
157*da0073e9SAndroid Build Coastguard Worker index = std::min(index, m_roundup_power2_divisions.size() - 1);
158*da0073e9SAndroid Build Coastguard Worker
159*da0073e9SAndroid Build Coastguard Worker if (first_value) {
160*da0073e9SAndroid Build Coastguard Worker std::fill(
161*da0073e9SAndroid Build Coastguard Worker m_roundup_power2_divisions.begin(),
162*da0073e9SAndroid Build Coastguard Worker std::next(
163*da0073e9SAndroid Build Coastguard Worker m_roundup_power2_divisions.begin(),
164*da0073e9SAndroid Build Coastguard Worker static_cast<std::vector<unsigned long>::difference_type>(
165*da0073e9SAndroid Build Coastguard Worker index)),
166*da0073e9SAndroid Build Coastguard Worker val2);
167*da0073e9SAndroid Build Coastguard Worker first_value = false;
168*da0073e9SAndroid Build Coastguard Worker }
169*da0073e9SAndroid Build Coastguard Worker if (index < m_roundup_power2_divisions.size()) {
170*da0073e9SAndroid Build Coastguard Worker m_roundup_power2_divisions[index] = val2;
171*da0073e9SAndroid Build Coastguard Worker }
172*da0073e9SAndroid Build Coastguard Worker last_index = index;
173*da0073e9SAndroid Build Coastguard Worker }
174*da0073e9SAndroid Build Coastguard Worker
175*da0073e9SAndroid Build Coastguard Worker if (std::string_view(config[i + 1]) != "]") {
176*da0073e9SAndroid Build Coastguard Worker consumeToken(config, ++i, ',');
177*da0073e9SAndroid Build Coastguard Worker }
178*da0073e9SAndroid Build Coastguard Worker }
179*da0073e9SAndroid Build Coastguard Worker } else { // Keep this for backwards compatibility
180*da0073e9SAndroid Build Coastguard Worker size_t val1 = stoi(config[i]);
181*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
182*da0073e9SAndroid Build Coastguard Worker llvm::isPowerOf2_64(val1),
183*da0073e9SAndroid Build Coastguard Worker "For roundups, the divisons has to be power of 2 ",
184*da0073e9SAndroid Build Coastguard Worker "");
185*da0073e9SAndroid Build Coastguard Worker std::fill(
186*da0073e9SAndroid Build Coastguard Worker m_roundup_power2_divisions.begin(),
187*da0073e9SAndroid Build Coastguard Worker m_roundup_power2_divisions.end(),
188*da0073e9SAndroid Build Coastguard Worker val1);
189*da0073e9SAndroid Build Coastguard Worker }
190*da0073e9SAndroid Build Coastguard Worker } else {
191*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(false, "Error, expecting roundup_power2_divisions value", "");
192*da0073e9SAndroid Build Coastguard Worker }
193*da0073e9SAndroid Build Coastguard Worker return i;
194*da0073e9SAndroid Build Coastguard Worker }
195*da0073e9SAndroid Build Coastguard Worker
parseAllocatorConfig(const std::vector<std::string> & config,size_t i,bool & used_cudaMallocAsync)196*da0073e9SAndroid Build Coastguard Worker size_t CUDAAllocatorConfig::parseAllocatorConfig(
197*da0073e9SAndroid Build Coastguard Worker const std::vector<std::string>& config,
198*da0073e9SAndroid Build Coastguard Worker size_t i,
199*da0073e9SAndroid Build Coastguard Worker bool& used_cudaMallocAsync) {
200*da0073e9SAndroid Build Coastguard Worker consumeToken(config, ++i, ':');
201*da0073e9SAndroid Build Coastguard Worker if (++i < config.size()) {
202*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
203*da0073e9SAndroid Build Coastguard Worker ((config[i] == "native") || (config[i] == "cudaMallocAsync")),
204*da0073e9SAndroid Build Coastguard Worker "Unknown allocator backend, "
205*da0073e9SAndroid Build Coastguard Worker "options are native and cudaMallocAsync");
206*da0073e9SAndroid Build Coastguard Worker used_cudaMallocAsync = (config[i] == "cudaMallocAsync");
207*da0073e9SAndroid Build Coastguard Worker #ifndef USE_ROCM
208*da0073e9SAndroid Build Coastguard Worker // HIP supports hipMallocAsync and does not need to check versions
209*da0073e9SAndroid Build Coastguard Worker if (used_cudaMallocAsync) {
210*da0073e9SAndroid Build Coastguard Worker #if CUDA_VERSION >= 11040
211*da0073e9SAndroid Build Coastguard Worker int version = 0;
212*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(cudaDriverGetVersion(&version));
213*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
214*da0073e9SAndroid Build Coastguard Worker version >= 11040,
215*da0073e9SAndroid Build Coastguard Worker "backend:cudaMallocAsync requires CUDA runtime "
216*da0073e9SAndroid Build Coastguard Worker "11.4 or newer, but cudaDriverGetVersion returned ",
217*da0073e9SAndroid Build Coastguard Worker version);
218*da0073e9SAndroid Build Coastguard Worker #else
219*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
220*da0073e9SAndroid Build Coastguard Worker false,
221*da0073e9SAndroid Build Coastguard Worker "backend:cudaMallocAsync requires PyTorch to be built with "
222*da0073e9SAndroid Build Coastguard Worker "CUDA 11.4 or newer, but CUDA_VERSION is ",
223*da0073e9SAndroid Build Coastguard Worker CUDA_VERSION);
224*da0073e9SAndroid Build Coastguard Worker #endif
225*da0073e9SAndroid Build Coastguard Worker }
226*da0073e9SAndroid Build Coastguard Worker #endif
227*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(
228*da0073e9SAndroid Build Coastguard Worker config[i] == get()->name(),
229*da0073e9SAndroid Build Coastguard Worker "Allocator backend parsed at runtime != "
230*da0073e9SAndroid Build Coastguard Worker "allocator backend parsed at load time");
231*da0073e9SAndroid Build Coastguard Worker } else {
232*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(false, "Error parsing backend value", "");
233*da0073e9SAndroid Build Coastguard Worker }
234*da0073e9SAndroid Build Coastguard Worker return i;
235*da0073e9SAndroid Build Coastguard Worker }
236*da0073e9SAndroid Build Coastguard Worker
parseArgs(const char * env)237*da0073e9SAndroid Build Coastguard Worker void CUDAAllocatorConfig::parseArgs(const char* env) {
238*da0073e9SAndroid Build Coastguard Worker // If empty, set the default values
239*da0073e9SAndroid Build Coastguard Worker m_max_split_size = std::numeric_limits<size_t>::max();
240*da0073e9SAndroid Build Coastguard Worker m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0);
241*da0073e9SAndroid Build Coastguard Worker m_garbage_collection_threshold = 0;
242*da0073e9SAndroid Build Coastguard Worker bool used_cudaMallocAsync = false;
243*da0073e9SAndroid Build Coastguard Worker bool used_native_specific_option = false;
244*da0073e9SAndroid Build Coastguard Worker
245*da0073e9SAndroid Build Coastguard Worker if (env == nullptr) {
246*da0073e9SAndroid Build Coastguard Worker return;
247*da0073e9SAndroid Build Coastguard Worker }
248*da0073e9SAndroid Build Coastguard Worker {
249*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::mutex> lock(m_last_allocator_settings_mutex);
250*da0073e9SAndroid Build Coastguard Worker m_last_allocator_settings = env;
251*da0073e9SAndroid Build Coastguard Worker }
252*da0073e9SAndroid Build Coastguard Worker
253*da0073e9SAndroid Build Coastguard Worker std::vector<std::string> config;
254*da0073e9SAndroid Build Coastguard Worker lexArgs(env, config);
255*da0073e9SAndroid Build Coastguard Worker
256*da0073e9SAndroid Build Coastguard Worker for (size_t i = 0; i < config.size(); i++) {
257*da0073e9SAndroid Build Coastguard Worker std::string_view config_item_view(config[i]);
258*da0073e9SAndroid Build Coastguard Worker if (config_item_view == "max_split_size_mb") {
259*da0073e9SAndroid Build Coastguard Worker i = parseMaxSplitSize(config, i);
260*da0073e9SAndroid Build Coastguard Worker used_native_specific_option = true;
261*da0073e9SAndroid Build Coastguard Worker } else if (config_item_view == "garbage_collection_threshold") {
262*da0073e9SAndroid Build Coastguard Worker i = parseGarbageCollectionThreshold(config, i);
263*da0073e9SAndroid Build Coastguard Worker used_native_specific_option = true;
264*da0073e9SAndroid Build Coastguard Worker } else if (config_item_view == "roundup_power2_divisions") {
265*da0073e9SAndroid Build Coastguard Worker i = parseRoundUpPower2Divisions(config, i);
266*da0073e9SAndroid Build Coastguard Worker used_native_specific_option = true;
267*da0073e9SAndroid Build Coastguard Worker } else if (config_item_view == "backend") {
268*da0073e9SAndroid Build Coastguard Worker i = parseAllocatorConfig(config, i, used_cudaMallocAsync);
269*da0073e9SAndroid Build Coastguard Worker } else if (config_item_view == "expandable_segments") {
270*da0073e9SAndroid Build Coastguard Worker used_native_specific_option = true;
271*da0073e9SAndroid Build Coastguard Worker consumeToken(config, ++i, ':');
272*da0073e9SAndroid Build Coastguard Worker ++i;
273*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
274*da0073e9SAndroid Build Coastguard Worker i < config.size() &&
275*da0073e9SAndroid Build Coastguard Worker (std::string_view(config[i]) == "True" ||
276*da0073e9SAndroid Build Coastguard Worker std::string_view(config[i]) == "False"),
277*da0073e9SAndroid Build Coastguard Worker "Expected a single True/False argument for expandable_segments");
278*da0073e9SAndroid Build Coastguard Worker config_item_view = config[i];
279*da0073e9SAndroid Build Coastguard Worker m_expandable_segments = (config_item_view == "True");
280*da0073e9SAndroid Build Coastguard Worker } else if (
281*da0073e9SAndroid Build Coastguard Worker // ROCm build's hipify step will change "cuda" to "hip", but for ease of
282*da0073e9SAndroid Build Coastguard Worker // use, accept both. We must break up the string to prevent hipify here.
283*da0073e9SAndroid Build Coastguard Worker config_item_view == "release_lock_on_hipmalloc" ||
284*da0073e9SAndroid Build Coastguard Worker config_item_view ==
285*da0073e9SAndroid Build Coastguard Worker "release_lock_on_c"
286*da0073e9SAndroid Build Coastguard Worker "udamalloc") {
287*da0073e9SAndroid Build Coastguard Worker used_native_specific_option = true;
288*da0073e9SAndroid Build Coastguard Worker consumeToken(config, ++i, ':');
289*da0073e9SAndroid Build Coastguard Worker ++i;
290*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
291*da0073e9SAndroid Build Coastguard Worker i < config.size() &&
292*da0073e9SAndroid Build Coastguard Worker (std::string_view(config[i]) == "True" ||
293*da0073e9SAndroid Build Coastguard Worker std::string_view(config[i]) == "False"),
294*da0073e9SAndroid Build Coastguard Worker "Expected a single True/False argument for release_lock_on_cudamalloc");
295*da0073e9SAndroid Build Coastguard Worker config_item_view = config[i];
296*da0073e9SAndroid Build Coastguard Worker m_release_lock_on_cudamalloc = (config_item_view == "True");
297*da0073e9SAndroid Build Coastguard Worker } else if (
298*da0073e9SAndroid Build Coastguard Worker // ROCm build's hipify step will change "cuda" to "hip", but for ease of
299*da0073e9SAndroid Build Coastguard Worker // use, accept both. We must break up the string to prevent hipify here.
300*da0073e9SAndroid Build Coastguard Worker config_item_view == "pinned_use_hip_host_register" ||
301*da0073e9SAndroid Build Coastguard Worker config_item_view ==
302*da0073e9SAndroid Build Coastguard Worker "pinned_use_c"
303*da0073e9SAndroid Build Coastguard Worker "uda_host_register") {
304*da0073e9SAndroid Build Coastguard Worker i = parsePinnedUseCudaHostRegister(config, i);
305*da0073e9SAndroid Build Coastguard Worker used_native_specific_option = true;
306*da0073e9SAndroid Build Coastguard Worker } else if (config_item_view == "pinned_num_register_threads") {
307*da0073e9SAndroid Build Coastguard Worker i = parsePinnedNumRegisterThreads(config, i);
308*da0073e9SAndroid Build Coastguard Worker used_native_specific_option = true;
309*da0073e9SAndroid Build Coastguard Worker } else {
310*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
311*da0073e9SAndroid Build Coastguard Worker false, "Unrecognized CachingAllocator option: ", config_item_view);
312*da0073e9SAndroid Build Coastguard Worker }
313*da0073e9SAndroid Build Coastguard Worker
314*da0073e9SAndroid Build Coastguard Worker if (i + 1 < config.size()) {
315*da0073e9SAndroid Build Coastguard Worker consumeToken(config, ++i, ',');
316*da0073e9SAndroid Build Coastguard Worker }
317*da0073e9SAndroid Build Coastguard Worker }
318*da0073e9SAndroid Build Coastguard Worker
319*da0073e9SAndroid Build Coastguard Worker if (used_cudaMallocAsync && used_native_specific_option) {
320*da0073e9SAndroid Build Coastguard Worker TORCH_WARN(
321*da0073e9SAndroid Build Coastguard Worker "backend:cudaMallocAsync ignores max_split_size_mb,"
322*da0073e9SAndroid Build Coastguard Worker "roundup_power2_divisions, and garbage_collect_threshold.");
323*da0073e9SAndroid Build Coastguard Worker }
324*da0073e9SAndroid Build Coastguard Worker }
325*da0073e9SAndroid Build Coastguard Worker
parsePinnedUseCudaHostRegister(const std::vector<std::string> & config,size_t i)326*da0073e9SAndroid Build Coastguard Worker size_t CUDAAllocatorConfig::parsePinnedUseCudaHostRegister(
327*da0073e9SAndroid Build Coastguard Worker const std::vector<std::string>& config,
328*da0073e9SAndroid Build Coastguard Worker size_t i) {
329*da0073e9SAndroid Build Coastguard Worker consumeToken(config, ++i, ':');
330*da0073e9SAndroid Build Coastguard Worker if (++i < config.size()) {
331*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
332*da0073e9SAndroid Build Coastguard Worker (config[i] == "True" || config[i] == "False"),
333*da0073e9SAndroid Build Coastguard Worker "Expected a single True/False argument for pinned_use_cuda_host_register");
334*da0073e9SAndroid Build Coastguard Worker m_pinned_use_cuda_host_register = (config[i] == "True");
335*da0073e9SAndroid Build Coastguard Worker } else {
336*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
337*da0073e9SAndroid Build Coastguard Worker false, "Error, expecting pinned_use_cuda_host_register value", "");
338*da0073e9SAndroid Build Coastguard Worker }
339*da0073e9SAndroid Build Coastguard Worker return i;
340*da0073e9SAndroid Build Coastguard Worker }
341*da0073e9SAndroid Build Coastguard Worker
parsePinnedNumRegisterThreads(const std::vector<std::string> & config,size_t i)342*da0073e9SAndroid Build Coastguard Worker size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads(
343*da0073e9SAndroid Build Coastguard Worker const std::vector<std::string>& config,
344*da0073e9SAndroid Build Coastguard Worker size_t i) {
345*da0073e9SAndroid Build Coastguard Worker consumeToken(config, ++i, ':');
346*da0073e9SAndroid Build Coastguard Worker if (++i < config.size()) {
347*da0073e9SAndroid Build Coastguard Worker size_t val2 = stoi(config[i]);
348*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
349*da0073e9SAndroid Build Coastguard Worker llvm::isPowerOf2_64(val2),
350*da0073e9SAndroid Build Coastguard Worker "Number of register threads has to be power of 2 ",
351*da0073e9SAndroid Build Coastguard Worker "");
352*da0073e9SAndroid Build Coastguard Worker auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads();
353*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
354*da0073e9SAndroid Build Coastguard Worker val2 <= maxThreads,
355*da0073e9SAndroid Build Coastguard Worker "Number of register threads should be less than or equal to " +
356*da0073e9SAndroid Build Coastguard Worker std::to_string(maxThreads),
357*da0073e9SAndroid Build Coastguard Worker "");
358*da0073e9SAndroid Build Coastguard Worker m_pinned_num_register_threads = val2;
359*da0073e9SAndroid Build Coastguard Worker } else {
360*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
361*da0073e9SAndroid Build Coastguard Worker false, "Error, expecting pinned_num_register_threads value", "");
362*da0073e9SAndroid Build Coastguard Worker }
363*da0073e9SAndroid Build Coastguard Worker return i;
364*da0073e9SAndroid Build Coastguard Worker }
365*da0073e9SAndroid Build Coastguard Worker
366*da0073e9SAndroid Build Coastguard Worker // General caching allocator utilities
setAllocatorSettings(const std::string & env)367*da0073e9SAndroid Build Coastguard Worker void setAllocatorSettings(const std::string& env) {
368*da0073e9SAndroid Build Coastguard Worker CUDACachingAllocator::CUDAAllocatorConfig::instance().parseArgs(env.c_str());
369*da0073e9SAndroid Build Coastguard Worker }
370*da0073e9SAndroid Build Coastguard Worker
371*da0073e9SAndroid Build Coastguard Worker } // namespace c10::cuda::CUDACachingAllocator
372