xref: /aosp_15_r20/external/pytorch/c10/util/numa.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h>
2*da0073e9SAndroid Build Coastguard Worker #include <c10/util/numa.h>
3*da0073e9SAndroid Build Coastguard Worker 
4*da0073e9SAndroid Build Coastguard Worker C10_DEFINE_bool(caffe2_cpu_numa_enabled, false, "Use NUMA whenever possible.");
5*da0073e9SAndroid Build Coastguard Worker 
6*da0073e9SAndroid Build Coastguard Worker #if defined(__linux__) && defined(C10_USE_NUMA) && !defined(C10_MOBILE)
7*da0073e9SAndroid Build Coastguard Worker #include <numa.h>
8*da0073e9SAndroid Build Coastguard Worker #include <numaif.h>
9*da0073e9SAndroid Build Coastguard Worker #include <unistd.h>
10*da0073e9SAndroid Build Coastguard Worker #define C10_ENABLE_NUMA
11*da0073e9SAndroid Build Coastguard Worker #endif
12*da0073e9SAndroid Build Coastguard Worker 
// This code used to have a lot of VLOGs. However, because allocation might be
// triggered during static initialization, it is unsafe to invoke VLOG here.
15*da0073e9SAndroid Build Coastguard Worker 
16*da0073e9SAndroid Build Coastguard Worker namespace c10 {
17*da0073e9SAndroid Build Coastguard Worker 
18*da0073e9SAndroid Build Coastguard Worker #ifdef C10_ENABLE_NUMA
IsNUMAEnabled()19*da0073e9SAndroid Build Coastguard Worker bool IsNUMAEnabled() {
20*da0073e9SAndroid Build Coastguard Worker   return FLAGS_caffe2_cpu_numa_enabled && numa_available() >= 0;
21*da0073e9SAndroid Build Coastguard Worker }
22*da0073e9SAndroid Build Coastguard Worker 
NUMABind(int numa_node_id)23*da0073e9SAndroid Build Coastguard Worker void NUMABind(int numa_node_id) {
24*da0073e9SAndroid Build Coastguard Worker   if (numa_node_id < 0) {
25*da0073e9SAndroid Build Coastguard Worker     return;
26*da0073e9SAndroid Build Coastguard Worker   }
27*da0073e9SAndroid Build Coastguard Worker   if (!IsNUMAEnabled()) {
28*da0073e9SAndroid Build Coastguard Worker     return;
29*da0073e9SAndroid Build Coastguard Worker   }
30*da0073e9SAndroid Build Coastguard Worker 
31*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
32*da0073e9SAndroid Build Coastguard Worker       numa_node_id <= numa_max_node(),
33*da0073e9SAndroid Build Coastguard Worker       "NUMA node id ",
34*da0073e9SAndroid Build Coastguard Worker       numa_node_id,
35*da0073e9SAndroid Build Coastguard Worker       " is unavailable");
36*da0073e9SAndroid Build Coastguard Worker 
37*da0073e9SAndroid Build Coastguard Worker   auto bm = numa_allocate_nodemask();
38*da0073e9SAndroid Build Coastguard Worker   numa_bitmask_setbit(bm, numa_node_id);
39*da0073e9SAndroid Build Coastguard Worker   numa_bind(bm);
40*da0073e9SAndroid Build Coastguard Worker   numa_bitmask_free(bm);
41*da0073e9SAndroid Build Coastguard Worker }
42*da0073e9SAndroid Build Coastguard Worker 
GetNUMANode(const void * ptr)43*da0073e9SAndroid Build Coastguard Worker int GetNUMANode(const void* ptr) {
44*da0073e9SAndroid Build Coastguard Worker   if (!IsNUMAEnabled()) {
45*da0073e9SAndroid Build Coastguard Worker     return -1;
46*da0073e9SAndroid Build Coastguard Worker   }
47*da0073e9SAndroid Build Coastguard Worker   AT_ASSERT(ptr);
48*da0073e9SAndroid Build Coastguard Worker 
49*da0073e9SAndroid Build Coastguard Worker   int numa_node = -1;
50*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
51*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
52*da0073e9SAndroid Build Coastguard Worker       get_mempolicy(
53*da0073e9SAndroid Build Coastguard Worker           &numa_node,
54*da0073e9SAndroid Build Coastguard Worker           nullptr,
55*da0073e9SAndroid Build Coastguard Worker           0,
56*da0073e9SAndroid Build Coastguard Worker           const_cast<void*>(ptr),
57*da0073e9SAndroid Build Coastguard Worker           MPOL_F_NODE | MPOL_F_ADDR) == 0,
58*da0073e9SAndroid Build Coastguard Worker       "Unable to get memory policy, errno:",
59*da0073e9SAndroid Build Coastguard Worker       errno);
60*da0073e9SAndroid Build Coastguard Worker   return numa_node;
61*da0073e9SAndroid Build Coastguard Worker }
62*da0073e9SAndroid Build Coastguard Worker 
GetNumNUMANodes()63*da0073e9SAndroid Build Coastguard Worker int GetNumNUMANodes() {
64*da0073e9SAndroid Build Coastguard Worker   if (!IsNUMAEnabled()) {
65*da0073e9SAndroid Build Coastguard Worker     return -1;
66*da0073e9SAndroid Build Coastguard Worker   }
67*da0073e9SAndroid Build Coastguard Worker 
68*da0073e9SAndroid Build Coastguard Worker   return numa_num_configured_nodes();
69*da0073e9SAndroid Build Coastguard Worker }
70*da0073e9SAndroid Build Coastguard Worker 
NUMAMove(void * ptr,size_t size,int numa_node_id)71*da0073e9SAndroid Build Coastguard Worker void NUMAMove(void* ptr, size_t size, int numa_node_id) {
72*da0073e9SAndroid Build Coastguard Worker   if (numa_node_id < 0) {
73*da0073e9SAndroid Build Coastguard Worker     return;
74*da0073e9SAndroid Build Coastguard Worker   }
75*da0073e9SAndroid Build Coastguard Worker   if (!IsNUMAEnabled()) {
76*da0073e9SAndroid Build Coastguard Worker     return;
77*da0073e9SAndroid Build Coastguard Worker   }
78*da0073e9SAndroid Build Coastguard Worker   AT_ASSERT(ptr);
79*da0073e9SAndroid Build Coastguard Worker 
80*da0073e9SAndroid Build Coastguard Worker   uintptr_t page_start_ptr =
81*da0073e9SAndroid Build Coastguard Worker       ((reinterpret_cast<uintptr_t>(ptr)) & ~(getpagesize() - 1));
82*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(*-conversions)
83*da0073e9SAndroid Build Coastguard Worker   ptrdiff_t offset = reinterpret_cast<uintptr_t>(ptr) - page_start_ptr;
84*da0073e9SAndroid Build Coastguard Worker   // Avoid extra dynamic allocation and NUMA api calls
85*da0073e9SAndroid Build Coastguard Worker   AT_ASSERT(
86*da0073e9SAndroid Build Coastguard Worker       numa_node_id >= 0 &&
87*da0073e9SAndroid Build Coastguard Worker       static_cast<unsigned>(numa_node_id) < sizeof(unsigned long) * 8);
88*da0073e9SAndroid Build Coastguard Worker   unsigned long mask = 1UL << numa_node_id;
89*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(performance-no-int-to-ptr)
90*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
91*da0073e9SAndroid Build Coastguard Worker       mbind(
92*da0073e9SAndroid Build Coastguard Worker           reinterpret_cast<void*>(page_start_ptr),
93*da0073e9SAndroid Build Coastguard Worker           size + offset,
94*da0073e9SAndroid Build Coastguard Worker           MPOL_BIND,
95*da0073e9SAndroid Build Coastguard Worker           &mask,
96*da0073e9SAndroid Build Coastguard Worker           sizeof(mask) * 8,
97*da0073e9SAndroid Build Coastguard Worker           MPOL_MF_MOVE | MPOL_MF_STRICT) == 0,
98*da0073e9SAndroid Build Coastguard Worker       "Could not move memory to a NUMA node");
99*da0073e9SAndroid Build Coastguard Worker }
100*da0073e9SAndroid Build Coastguard Worker 
GetCurrentNUMANode()101*da0073e9SAndroid Build Coastguard Worker int GetCurrentNUMANode() {
102*da0073e9SAndroid Build Coastguard Worker   if (!IsNUMAEnabled()) {
103*da0073e9SAndroid Build Coastguard Worker     return -1;
104*da0073e9SAndroid Build Coastguard Worker   }
105*da0073e9SAndroid Build Coastguard Worker 
106*da0073e9SAndroid Build Coastguard Worker   auto n = numa_node_of_cpu(sched_getcpu());
107*da0073e9SAndroid Build Coastguard Worker   return n;
108*da0073e9SAndroid Build Coastguard Worker }
109*da0073e9SAndroid Build Coastguard Worker 
110*da0073e9SAndroid Build Coastguard Worker #else // C10_ENABLE_NUMA
111*da0073e9SAndroid Build Coastguard Worker 
// Fallback for builds without C10_ENABLE_NUMA: NUMA is never enabled.
bool IsNUMAEnabled() { return false; }
115*da0073e9SAndroid Build Coastguard Worker 
// Fallback for builds without C10_ENABLE_NUMA: binding is a no-op.
void NUMABind(int numa_node_id) {
  static_cast<void>(numa_node_id);
}
117*da0073e9SAndroid Build Coastguard Worker 
// Fallback for builds without C10_ENABLE_NUMA: always reports -1,
// regardless of `ptr`.
int GetNUMANode(const void* ptr) {
  static_cast<void>(ptr);
  return -1;
}
121*da0073e9SAndroid Build Coastguard Worker 
// Fallback for builds without C10_ENABLE_NUMA: always reports -1.
int GetNumNUMANodes() { return -1; }
125*da0073e9SAndroid Build Coastguard Worker 
// Fallback for builds without C10_ENABLE_NUMA: moving memory is a no-op and
// the buffer is left untouched.
void NUMAMove(void* ptr, size_t size, int numa_node_id) {
  static_cast<void>(ptr);
  static_cast<void>(size);
  static_cast<void>(numa_node_id);
}
127*da0073e9SAndroid Build Coastguard Worker 
// Fallback for builds without C10_ENABLE_NUMA: always reports -1.
int GetCurrentNUMANode() { return -1; }
131*da0073e9SAndroid Build Coastguard Worker 
#endif // C10_ENABLE_NUMA
133*da0073e9SAndroid Build Coastguard Worker 
134*da0073e9SAndroid Build Coastguard Worker } // namespace c10
135