1*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h>
2*da0073e9SAndroid Build Coastguard Worker #include <c10/util/numa.h>
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Worker C10_DEFINE_bool(caffe2_cpu_numa_enabled, false, "Use NUMA whenever possible.");
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Worker #if defined(__linux__) && defined(C10_USE_NUMA) && !defined(C10_MOBILE)
7*da0073e9SAndroid Build Coastguard Worker #include <numa.h>
8*da0073e9SAndroid Build Coastguard Worker #include <numaif.h>
9*da0073e9SAndroid Build Coastguard Worker #include <unistd.h>
10*da0073e9SAndroid Build Coastguard Worker #define C10_ENABLE_NUMA
11*da0073e9SAndroid Build Coastguard Worker #endif
12*da0073e9SAndroid Build Coastguard Worker
// This code used to have a lot of VLOGs. However, because allocation might be
// triggered during static initialization, it's unsafe to invoke VLOG here.
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Worker namespace c10 {
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard Worker #ifdef C10_ENABLE_NUMA
IsNUMAEnabled()19*da0073e9SAndroid Build Coastguard Worker bool IsNUMAEnabled() {
20*da0073e9SAndroid Build Coastguard Worker return FLAGS_caffe2_cpu_numa_enabled && numa_available() >= 0;
21*da0073e9SAndroid Build Coastguard Worker }
22*da0073e9SAndroid Build Coastguard Worker
NUMABind(int numa_node_id)23*da0073e9SAndroid Build Coastguard Worker void NUMABind(int numa_node_id) {
24*da0073e9SAndroid Build Coastguard Worker if (numa_node_id < 0) {
25*da0073e9SAndroid Build Coastguard Worker return;
26*da0073e9SAndroid Build Coastguard Worker }
27*da0073e9SAndroid Build Coastguard Worker if (!IsNUMAEnabled()) {
28*da0073e9SAndroid Build Coastguard Worker return;
29*da0073e9SAndroid Build Coastguard Worker }
30*da0073e9SAndroid Build Coastguard Worker
31*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
32*da0073e9SAndroid Build Coastguard Worker numa_node_id <= numa_max_node(),
33*da0073e9SAndroid Build Coastguard Worker "NUMA node id ",
34*da0073e9SAndroid Build Coastguard Worker numa_node_id,
35*da0073e9SAndroid Build Coastguard Worker " is unavailable");
36*da0073e9SAndroid Build Coastguard Worker
37*da0073e9SAndroid Build Coastguard Worker auto bm = numa_allocate_nodemask();
38*da0073e9SAndroid Build Coastguard Worker numa_bitmask_setbit(bm, numa_node_id);
39*da0073e9SAndroid Build Coastguard Worker numa_bind(bm);
40*da0073e9SAndroid Build Coastguard Worker numa_bitmask_free(bm);
41*da0073e9SAndroid Build Coastguard Worker }
42*da0073e9SAndroid Build Coastguard Worker
GetNUMANode(const void * ptr)43*da0073e9SAndroid Build Coastguard Worker int GetNUMANode(const void* ptr) {
44*da0073e9SAndroid Build Coastguard Worker if (!IsNUMAEnabled()) {
45*da0073e9SAndroid Build Coastguard Worker return -1;
46*da0073e9SAndroid Build Coastguard Worker }
47*da0073e9SAndroid Build Coastguard Worker AT_ASSERT(ptr);
48*da0073e9SAndroid Build Coastguard Worker
49*da0073e9SAndroid Build Coastguard Worker int numa_node = -1;
50*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
51*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
52*da0073e9SAndroid Build Coastguard Worker get_mempolicy(
53*da0073e9SAndroid Build Coastguard Worker &numa_node,
54*da0073e9SAndroid Build Coastguard Worker nullptr,
55*da0073e9SAndroid Build Coastguard Worker 0,
56*da0073e9SAndroid Build Coastguard Worker const_cast<void*>(ptr),
57*da0073e9SAndroid Build Coastguard Worker MPOL_F_NODE | MPOL_F_ADDR) == 0,
58*da0073e9SAndroid Build Coastguard Worker "Unable to get memory policy, errno:",
59*da0073e9SAndroid Build Coastguard Worker errno);
60*da0073e9SAndroid Build Coastguard Worker return numa_node;
61*da0073e9SAndroid Build Coastguard Worker }
62*da0073e9SAndroid Build Coastguard Worker
GetNumNUMANodes()63*da0073e9SAndroid Build Coastguard Worker int GetNumNUMANodes() {
64*da0073e9SAndroid Build Coastguard Worker if (!IsNUMAEnabled()) {
65*da0073e9SAndroid Build Coastguard Worker return -1;
66*da0073e9SAndroid Build Coastguard Worker }
67*da0073e9SAndroid Build Coastguard Worker
68*da0073e9SAndroid Build Coastguard Worker return numa_num_configured_nodes();
69*da0073e9SAndroid Build Coastguard Worker }
70*da0073e9SAndroid Build Coastguard Worker
NUMAMove(void * ptr,size_t size,int numa_node_id)71*da0073e9SAndroid Build Coastguard Worker void NUMAMove(void* ptr, size_t size, int numa_node_id) {
72*da0073e9SAndroid Build Coastguard Worker if (numa_node_id < 0) {
73*da0073e9SAndroid Build Coastguard Worker return;
74*da0073e9SAndroid Build Coastguard Worker }
75*da0073e9SAndroid Build Coastguard Worker if (!IsNUMAEnabled()) {
76*da0073e9SAndroid Build Coastguard Worker return;
77*da0073e9SAndroid Build Coastguard Worker }
78*da0073e9SAndroid Build Coastguard Worker AT_ASSERT(ptr);
79*da0073e9SAndroid Build Coastguard Worker
80*da0073e9SAndroid Build Coastguard Worker uintptr_t page_start_ptr =
81*da0073e9SAndroid Build Coastguard Worker ((reinterpret_cast<uintptr_t>(ptr)) & ~(getpagesize() - 1));
82*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(*-conversions)
83*da0073e9SAndroid Build Coastguard Worker ptrdiff_t offset = reinterpret_cast<uintptr_t>(ptr) - page_start_ptr;
84*da0073e9SAndroid Build Coastguard Worker // Avoid extra dynamic allocation and NUMA api calls
85*da0073e9SAndroid Build Coastguard Worker AT_ASSERT(
86*da0073e9SAndroid Build Coastguard Worker numa_node_id >= 0 &&
87*da0073e9SAndroid Build Coastguard Worker static_cast<unsigned>(numa_node_id) < sizeof(unsigned long) * 8);
88*da0073e9SAndroid Build Coastguard Worker unsigned long mask = 1UL << numa_node_id;
89*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(performance-no-int-to-ptr)
90*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
91*da0073e9SAndroid Build Coastguard Worker mbind(
92*da0073e9SAndroid Build Coastguard Worker reinterpret_cast<void*>(page_start_ptr),
93*da0073e9SAndroid Build Coastguard Worker size + offset,
94*da0073e9SAndroid Build Coastguard Worker MPOL_BIND,
95*da0073e9SAndroid Build Coastguard Worker &mask,
96*da0073e9SAndroid Build Coastguard Worker sizeof(mask) * 8,
97*da0073e9SAndroid Build Coastguard Worker MPOL_MF_MOVE | MPOL_MF_STRICT) == 0,
98*da0073e9SAndroid Build Coastguard Worker "Could not move memory to a NUMA node");
99*da0073e9SAndroid Build Coastguard Worker }
100*da0073e9SAndroid Build Coastguard Worker
GetCurrentNUMANode()101*da0073e9SAndroid Build Coastguard Worker int GetCurrentNUMANode() {
102*da0073e9SAndroid Build Coastguard Worker if (!IsNUMAEnabled()) {
103*da0073e9SAndroid Build Coastguard Worker return -1;
104*da0073e9SAndroid Build Coastguard Worker }
105*da0073e9SAndroid Build Coastguard Worker
106*da0073e9SAndroid Build Coastguard Worker auto n = numa_node_of_cpu(sched_getcpu());
107*da0073e9SAndroid Build Coastguard Worker return n;
108*da0073e9SAndroid Build Coastguard Worker }
109*da0073e9SAndroid Build Coastguard Worker
110*da0073e9SAndroid Build Coastguard Worker #else // C10_ENABLE_NUMA
111*da0073e9SAndroid Build Coastguard Worker
/// Fallback used when C10_ENABLE_NUMA is not defined.
///
/// @return Always false.
bool IsNUMAEnabled() {
  constexpr bool kNumaCompiledIn = false;
  return kNumaCompiledIn;
}
115*da0073e9SAndroid Build Coastguard Worker
/// Fallback used when C10_ENABLE_NUMA is not defined: a no-op.
///
/// @param numa_node_id Ignored.
void NUMABind(int numa_node_id) {
  static_cast<void>(numa_node_id); // intentionally unused
}
117*da0073e9SAndroid Build Coastguard Worker
/// Fallback used when C10_ENABLE_NUMA is not defined.
///
/// @param ptr Ignored.
/// @return Always -1.
int GetNUMANode(const void* ptr) {
  static_cast<void>(ptr); // intentionally unused
  constexpr int kNoNumaNode = -1;
  return kNoNumaNode;
}
121*da0073e9SAndroid Build Coastguard Worker
/// Fallback used when C10_ENABLE_NUMA is not defined.
///
/// @return Always -1.
int GetNumNUMANodes() {
  constexpr int kNodeCountUnavailable = -1;
  return kNodeCountUnavailable;
}
125*da0073e9SAndroid Build Coastguard Worker
/// Fallback used when C10_ENABLE_NUMA is not defined: a no-op.
///
/// @param ptr Ignored.
/// @param size Ignored.
/// @param numa_node_id Ignored.
void NUMAMove(void* ptr, size_t size, int numa_node_id) {
  // All arguments are intentionally unused.
  static_cast<void>(ptr);
  static_cast<void>(size);
  static_cast<void>(numa_node_id);
}
127*da0073e9SAndroid Build Coastguard Worker
/// Fallback used when C10_ENABLE_NUMA is not defined.
///
/// @return Always -1.
int GetCurrentNUMANode() {
  constexpr int kNoNumaNode = -1;
  return kNoNumaNode;
}
131*da0073e9SAndroid Build Coastguard Worker
#endif // C10_ENABLE_NUMA
133*da0073e9SAndroid Build Coastguard Worker
134*da0073e9SAndroid Build Coastguard Worker } // namespace c10
135