1*da0073e9SAndroid Build Coastguard Worker /* Standard C headers */
2*da0073e9SAndroid Build Coastguard Worker #include <stdint.h>
3*da0073e9SAndroid Build Coastguard Worker #include <stdbool.h>
4*da0073e9SAndroid Build Coastguard Worker #include <stdlib.h>
5*da0073e9SAndroid Build Coastguard Worker #include <string.h>
6*da0073e9SAndroid Build Coastguard Worker #include <assert.h>
7*da0073e9SAndroid Build Coastguard Worker #include <limits>
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Worker #ifdef _MSC_VER
10*da0073e9SAndroid Build Coastguard Worker #include <cstdio>
11*da0073e9SAndroid Build Coastguard Worker #undef min
12*da0073e9SAndroid Build Coastguard Worker #else
13*da0073e9SAndroid Build Coastguard Worker /* POSIX headers */
14*da0073e9SAndroid Build Coastguard Worker #include <unistd.h>
15*da0073e9SAndroid Build Coastguard Worker #endif
16*da0073e9SAndroid Build Coastguard Worker
17*da0073e9SAndroid Build Coastguard Worker /* Library header */
18*da0073e9SAndroid Build Coastguard Worker #include "caffe2/utils/fixed_divisor.h"
19*da0073e9SAndroid Build Coastguard Worker #include "caffe2/utils/threadpool/pthreadpool.h"
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Logging.h>
22*da0073e9SAndroid Build Coastguard Worker
/*
 * Ceiling division for sizes: how many `divisor`-sized chunks are needed to
 * cover `dividend` items. `divisor` must be non-zero (division by zero is UB).
 */
static inline size_t divide_round_up(size_t dividend, size_t divisor) {
  const size_t whole_chunks = dividend / divisor;
  const bool has_partial_chunk = (dividend % divisor) != 0;
  return has_partial_chunk ? whole_chunks + 1 : whole_chunks;
}

/* Returns the smaller of two sizes. */
static inline size_t min(size_t a, size_t b) {
  if (b < a) {
    return b;
  }
  return a;
}

35*da0073e9SAndroid Build Coastguard Worker struct compute_1d_tiled_context {
36*da0073e9SAndroid Build Coastguard Worker legacy_pthreadpool_function_1d_tiled_t function;
37*da0073e9SAndroid Build Coastguard Worker void* argument;
38*da0073e9SAndroid Build Coastguard Worker size_t range;
39*da0073e9SAndroid Build Coastguard Worker size_t tile;
40*da0073e9SAndroid Build Coastguard Worker };
41*da0073e9SAndroid Build Coastguard Worker
compute_1d_tiled(void * context_,size_t linear_index)42*da0073e9SAndroid Build Coastguard Worker static void compute_1d_tiled(void* context_, size_t linear_index) {
43*da0073e9SAndroid Build Coastguard Worker const struct compute_1d_tiled_context* context = (compute_1d_tiled_context*) context_;
44*da0073e9SAndroid Build Coastguard Worker const size_t tile_index = linear_index;
45*da0073e9SAndroid Build Coastguard Worker const size_t index = tile_index * context->tile;
46*da0073e9SAndroid Build Coastguard Worker const size_t tile = min(context->tile, context->range - index);
47*da0073e9SAndroid Build Coastguard Worker context->function(context->argument, index, tile);
48*da0073e9SAndroid Build Coastguard Worker }
49*da0073e9SAndroid Build Coastguard Worker
legacy_pthreadpool_compute_1d_tiled(legacy_pthreadpool_t threadpool,legacy_pthreadpool_function_1d_tiled_t function,void * argument,size_t range,size_t tile)50*da0073e9SAndroid Build Coastguard Worker void legacy_pthreadpool_compute_1d_tiled(
51*da0073e9SAndroid Build Coastguard Worker legacy_pthreadpool_t threadpool,
52*da0073e9SAndroid Build Coastguard Worker legacy_pthreadpool_function_1d_tiled_t function,
53*da0073e9SAndroid Build Coastguard Worker void* argument,
54*da0073e9SAndroid Build Coastguard Worker size_t range,
55*da0073e9SAndroid Build Coastguard Worker size_t tile)
56*da0073e9SAndroid Build Coastguard Worker {
57*da0073e9SAndroid Build Coastguard Worker if (threadpool == nullptr) {
58*da0073e9SAndroid Build Coastguard Worker /* No thread pool provided: execute function sequentially on the calling thread */
59*da0073e9SAndroid Build Coastguard Worker for (size_t i = 0; i < range; i += tile) {
60*da0073e9SAndroid Build Coastguard Worker function(argument, i, min(range - i, tile));
61*da0073e9SAndroid Build Coastguard Worker }
62*da0073e9SAndroid Build Coastguard Worker } else {
63*da0073e9SAndroid Build Coastguard Worker /* Execute in parallel on the thread pool using linearized index */
64*da0073e9SAndroid Build Coastguard Worker const size_t tile_range = divide_round_up(range, tile);
65*da0073e9SAndroid Build Coastguard Worker struct compute_1d_tiled_context context = {/*.function = */ function,
66*da0073e9SAndroid Build Coastguard Worker /*.argument = */ argument,
67*da0073e9SAndroid Build Coastguard Worker /*.range = */ range,
68*da0073e9SAndroid Build Coastguard Worker /*.tile = */ tile};
69*da0073e9SAndroid Build Coastguard Worker legacy_pthreadpool_compute_1d(threadpool, (legacy_pthreadpool_function_1d_t) compute_1d_tiled, &context, tile_range);
70*da0073e9SAndroid Build Coastguard Worker }
71*da0073e9SAndroid Build Coastguard Worker }
72*da0073e9SAndroid Build Coastguard Worker
73*da0073e9SAndroid Build Coastguard Worker struct compute_2d_context {
74*da0073e9SAndroid Build Coastguard Worker legacy_pthreadpool_function_2d_t function;
75*da0073e9SAndroid Build Coastguard Worker void* argument;
76*da0073e9SAndroid Build Coastguard Worker caffe2::FixedDivisor<int32_t> range_j;
77*da0073e9SAndroid Build Coastguard Worker };
78*da0073e9SAndroid Build Coastguard Worker
compute_2d(void * context_,size_t linear_index)79*da0073e9SAndroid Build Coastguard Worker static void compute_2d(void* context_, size_t linear_index) {
80*da0073e9SAndroid Build Coastguard Worker TORCH_DCHECK_LE(linear_index, std::numeric_limits<int32_t>::max());
81*da0073e9SAndroid Build Coastguard Worker
82*da0073e9SAndroid Build Coastguard Worker const struct compute_2d_context* context = static_cast<compute_2d_context*>(context_);
83*da0073e9SAndroid Build Coastguard Worker int32_t q;
84*da0073e9SAndroid Build Coastguard Worker int32_t r;
85*da0073e9SAndroid Build Coastguard Worker context->range_j.DivMod(static_cast<int32_t>(linear_index), &q, &r);
86*da0073e9SAndroid Build Coastguard Worker context->function(context->argument, q, r);
87*da0073e9SAndroid Build Coastguard Worker }
88*da0073e9SAndroid Build Coastguard Worker
legacy_pthreadpool_compute_2d(legacy_pthreadpool_t threadpool,legacy_pthreadpool_function_2d_t function,void * argument,size_t range_i,size_t range_j)89*da0073e9SAndroid Build Coastguard Worker void legacy_pthreadpool_compute_2d(
90*da0073e9SAndroid Build Coastguard Worker legacy_pthreadpool_t threadpool,
91*da0073e9SAndroid Build Coastguard Worker legacy_pthreadpool_function_2d_t function,
92*da0073e9SAndroid Build Coastguard Worker void* argument,
93*da0073e9SAndroid Build Coastguard Worker size_t range_i,
94*da0073e9SAndroid Build Coastguard Worker size_t range_j)
95*da0073e9SAndroid Build Coastguard Worker {
96*da0073e9SAndroid Build Coastguard Worker if (threadpool == nullptr) {
97*da0073e9SAndroid Build Coastguard Worker /* No thread pool provided: execute function sequentially on the calling thread */
98*da0073e9SAndroid Build Coastguard Worker for (size_t i = 0; i < range_i; i++) {
99*da0073e9SAndroid Build Coastguard Worker for (size_t j = 0; j < range_j; j++) {
100*da0073e9SAndroid Build Coastguard Worker function(argument, i, j);
101*da0073e9SAndroid Build Coastguard Worker }
102*da0073e9SAndroid Build Coastguard Worker }
103*da0073e9SAndroid Build Coastguard Worker } else {
104*da0073e9SAndroid Build Coastguard Worker TORCH_DCHECK_LE(range_i * range_j, (size_t)std::numeric_limits<int32_t>::max());
105*da0073e9SAndroid Build Coastguard Worker /* Execute in parallel on the thread pool using linearized index */
106*da0073e9SAndroid Build Coastguard Worker struct compute_2d_context context = {
107*da0073e9SAndroid Build Coastguard Worker /*.function = */ function,
108*da0073e9SAndroid Build Coastguard Worker /*.argument = */ argument,
109*da0073e9SAndroid Build Coastguard Worker /*.range_j = */ caffe2::FixedDivisor<int32_t>(range_j)};
110*da0073e9SAndroid Build Coastguard Worker legacy_pthreadpool_compute_1d(threadpool, (legacy_pthreadpool_function_1d_t) compute_2d, &context, range_i * range_j);
111*da0073e9SAndroid Build Coastguard Worker }
112*da0073e9SAndroid Build Coastguard Worker }
113*da0073e9SAndroid Build Coastguard Worker
114*da0073e9SAndroid Build Coastguard Worker struct compute_2d_tiled_context {
115*da0073e9SAndroid Build Coastguard Worker legacy_pthreadpool_function_2d_tiled_t function;
116*da0073e9SAndroid Build Coastguard Worker void* argument;
117*da0073e9SAndroid Build Coastguard Worker caffe2::FixedDivisor<int32_t> tile_range_j;
118*da0073e9SAndroid Build Coastguard Worker size_t range_i;
119*da0073e9SAndroid Build Coastguard Worker size_t range_j;
120*da0073e9SAndroid Build Coastguard Worker size_t tile_i;
121*da0073e9SAndroid Build Coastguard Worker size_t tile_j;
122*da0073e9SAndroid Build Coastguard Worker };
123*da0073e9SAndroid Build Coastguard Worker
compute_2d_tiled(void * context_,size_t linear_index)124*da0073e9SAndroid Build Coastguard Worker static void compute_2d_tiled(void* context_, size_t linear_index) {
125*da0073e9SAndroid Build Coastguard Worker int32_t q;
126*da0073e9SAndroid Build Coastguard Worker int32_t r;
127*da0073e9SAndroid Build Coastguard Worker
128*da0073e9SAndroid Build Coastguard Worker const struct compute_2d_tiled_context* context = static_cast<compute_2d_tiled_context*>(context_);
129*da0073e9SAndroid Build Coastguard Worker context->tile_range_j.DivMod(linear_index, &q, &r);
130*da0073e9SAndroid Build Coastguard Worker const size_t max_tile_i = context->tile_i;
131*da0073e9SAndroid Build Coastguard Worker const size_t max_tile_j = context->tile_j;
132*da0073e9SAndroid Build Coastguard Worker const size_t index_i = q * max_tile_i;
133*da0073e9SAndroid Build Coastguard Worker const size_t index_j = r * max_tile_j;
134*da0073e9SAndroid Build Coastguard Worker const size_t tile_i = min(max_tile_i, context->range_i - index_i);
135*da0073e9SAndroid Build Coastguard Worker const size_t tile_j = min(max_tile_j, context->range_j - index_j);
136*da0073e9SAndroid Build Coastguard Worker context->function(context->argument, index_i, index_j, tile_i, tile_j);
137*da0073e9SAndroid Build Coastguard Worker }
138*da0073e9SAndroid Build Coastguard Worker
legacy_pthreadpool_compute_2d_tiled(legacy_pthreadpool_t threadpool,legacy_pthreadpool_function_2d_tiled_t function,void * argument,size_t range_i,size_t range_j,size_t tile_i,size_t tile_j)139*da0073e9SAndroid Build Coastguard Worker void legacy_pthreadpool_compute_2d_tiled(
140*da0073e9SAndroid Build Coastguard Worker legacy_pthreadpool_t threadpool,
141*da0073e9SAndroid Build Coastguard Worker legacy_pthreadpool_function_2d_tiled_t function,
142*da0073e9SAndroid Build Coastguard Worker void* argument,
143*da0073e9SAndroid Build Coastguard Worker size_t range_i,
144*da0073e9SAndroid Build Coastguard Worker size_t range_j,
145*da0073e9SAndroid Build Coastguard Worker size_t tile_i,
146*da0073e9SAndroid Build Coastguard Worker size_t tile_j)
147*da0073e9SAndroid Build Coastguard Worker {
148*da0073e9SAndroid Build Coastguard Worker if (threadpool == nullptr) {
149*da0073e9SAndroid Build Coastguard Worker /* No thread pool provided: execute function sequentially on the calling thread */
150*da0073e9SAndroid Build Coastguard Worker for (size_t i = 0; i < range_i; i += tile_i) {
151*da0073e9SAndroid Build Coastguard Worker for (size_t j = 0; j < range_j; j += tile_j) {
152*da0073e9SAndroid Build Coastguard Worker function(argument, i, j, min(range_i - i, tile_i), min(range_j - j, tile_j));
153*da0073e9SAndroid Build Coastguard Worker }
154*da0073e9SAndroid Build Coastguard Worker }
155*da0073e9SAndroid Build Coastguard Worker } else {
156*da0073e9SAndroid Build Coastguard Worker /* Execute in parallel on the thread pool using linearized index */
157*da0073e9SAndroid Build Coastguard Worker const size_t tile_range_i = divide_round_up(range_i, tile_i);
158*da0073e9SAndroid Build Coastguard Worker const size_t tile_range_j = divide_round_up(range_j, tile_j);
159*da0073e9SAndroid Build Coastguard Worker TORCH_DCHECK_LE(
160*da0073e9SAndroid Build Coastguard Worker tile_range_i * tile_range_j,
161*da0073e9SAndroid Build Coastguard Worker (size_t)std::numeric_limits<int32_t>::max());
162*da0073e9SAndroid Build Coastguard Worker struct compute_2d_tiled_context context = {
163*da0073e9SAndroid Build Coastguard Worker /*.function = */ function,
164*da0073e9SAndroid Build Coastguard Worker /*.argument = */ argument,
165*da0073e9SAndroid Build Coastguard Worker /*.tile_range_j = */ caffe2::FixedDivisor<int32_t>(tile_range_j),
166*da0073e9SAndroid Build Coastguard Worker /*.range_i = */ range_i,
167*da0073e9SAndroid Build Coastguard Worker /*.range_j = */ range_j,
168*da0073e9SAndroid Build Coastguard Worker /*.tile_i = */ tile_i,
169*da0073e9SAndroid Build Coastguard Worker /*.tile_j = */ tile_j};
170*da0073e9SAndroid Build Coastguard Worker legacy_pthreadpool_compute_1d(threadpool, (legacy_pthreadpool_function_1d_t) compute_2d_tiled, &context, tile_range_i * tile_range_j);
171*da0073e9SAndroid Build Coastguard Worker }
172*da0073e9SAndroid Build Coastguard Worker }
173*da0073e9SAndroid Build Coastguard Worker
174*da0073e9SAndroid Build Coastguard Worker struct compute_3d_tiled_context {
175*da0073e9SAndroid Build Coastguard Worker legacy_pthreadpool_function_3d_tiled_t function;
176*da0073e9SAndroid Build Coastguard Worker void* argument;
177*da0073e9SAndroid Build Coastguard Worker caffe2::FixedDivisor<int32_t> tile_range_j;
178*da0073e9SAndroid Build Coastguard Worker caffe2::FixedDivisor<int32_t> tile_range_k;
179*da0073e9SAndroid Build Coastguard Worker size_t range_i;
180*da0073e9SAndroid Build Coastguard Worker size_t range_j;
181*da0073e9SAndroid Build Coastguard Worker size_t range_k;
182*da0073e9SAndroid Build Coastguard Worker size_t tile_i;
183*da0073e9SAndroid Build Coastguard Worker size_t tile_j;
184*da0073e9SAndroid Build Coastguard Worker size_t tile_k;
185*da0073e9SAndroid Build Coastguard Worker };
186*da0073e9SAndroid Build Coastguard Worker
compute_3d_tiled(void * context_,size_t linear_index)187*da0073e9SAndroid Build Coastguard Worker static void compute_3d_tiled(
188*da0073e9SAndroid Build Coastguard Worker void* context_,
189*da0073e9SAndroid Build Coastguard Worker size_t linear_index) {
190*da0073e9SAndroid Build Coastguard Worker int32_t tile_index_ij, tile_index_k;
191*da0073e9SAndroid Build Coastguard Worker const struct compute_3d_tiled_context* context = static_cast<compute_3d_tiled_context*>(context_);
192*da0073e9SAndroid Build Coastguard Worker context->tile_range_k.DivMod(
193*da0073e9SAndroid Build Coastguard Worker static_cast<int32_t>(linear_index), &tile_index_ij, &tile_index_k);
194*da0073e9SAndroid Build Coastguard Worker int32_t tile_index_i, tile_index_j;
195*da0073e9SAndroid Build Coastguard Worker context->tile_range_j.DivMod(tile_index_ij, &tile_index_i, &tile_index_j);
196*da0073e9SAndroid Build Coastguard Worker const size_t max_tile_i = context->tile_i;
197*da0073e9SAndroid Build Coastguard Worker const size_t max_tile_j = context->tile_j;
198*da0073e9SAndroid Build Coastguard Worker const size_t max_tile_k = context->tile_k;
199*da0073e9SAndroid Build Coastguard Worker const size_t index_i = static_cast<uint32_t>(tile_index_i) * max_tile_i;
200*da0073e9SAndroid Build Coastguard Worker const size_t index_j = static_cast<uint32_t>(tile_index_j) * max_tile_j;
201*da0073e9SAndroid Build Coastguard Worker const size_t index_k = static_cast<uint32_t>(tile_index_k) * max_tile_k;
202*da0073e9SAndroid Build Coastguard Worker const size_t tile_i = min(max_tile_i, context->range_i - index_i);
203*da0073e9SAndroid Build Coastguard Worker const size_t tile_j = min(max_tile_j, context->range_j - index_j);
204*da0073e9SAndroid Build Coastguard Worker const size_t tile_k = min(max_tile_k, context->range_k - index_k);
205*da0073e9SAndroid Build Coastguard Worker context->function(
206*da0073e9SAndroid Build Coastguard Worker context->argument, index_i, index_j, index_k, tile_i, tile_j, tile_k);
207*da0073e9SAndroid Build Coastguard Worker }
208*da0073e9SAndroid Build Coastguard Worker
legacy_pthreadpool_compute_3d_tiled(legacy_pthreadpool_t threadpool,legacy_pthreadpool_function_3d_tiled_t function,void * argument,size_t range_i,size_t range_j,size_t range_k,size_t tile_i,size_t tile_j,size_t tile_k)209*da0073e9SAndroid Build Coastguard Worker void legacy_pthreadpool_compute_3d_tiled(
210*da0073e9SAndroid Build Coastguard Worker legacy_pthreadpool_t threadpool,
211*da0073e9SAndroid Build Coastguard Worker legacy_pthreadpool_function_3d_tiled_t function,
212*da0073e9SAndroid Build Coastguard Worker void* argument,
213*da0073e9SAndroid Build Coastguard Worker size_t range_i,
214*da0073e9SAndroid Build Coastguard Worker size_t range_j,
215*da0073e9SAndroid Build Coastguard Worker size_t range_k,
216*da0073e9SAndroid Build Coastguard Worker size_t tile_i,
217*da0073e9SAndroid Build Coastguard Worker size_t tile_j,
218*da0073e9SAndroid Build Coastguard Worker size_t tile_k) {
219*da0073e9SAndroid Build Coastguard Worker if (threadpool == nullptr) {
220*da0073e9SAndroid Build Coastguard Worker /* No thread pool provided: execute function sequentially on the calling
221*da0073e9SAndroid Build Coastguard Worker * thread */
222*da0073e9SAndroid Build Coastguard Worker for (size_t i = 0; i < range_i; i += tile_i) {
223*da0073e9SAndroid Build Coastguard Worker for (size_t j = 0; j < range_j; j += tile_j) {
224*da0073e9SAndroid Build Coastguard Worker for (size_t k = 0; k < range_k; k += tile_k) {
225*da0073e9SAndroid Build Coastguard Worker function(
226*da0073e9SAndroid Build Coastguard Worker argument,
227*da0073e9SAndroid Build Coastguard Worker i,
228*da0073e9SAndroid Build Coastguard Worker j,
229*da0073e9SAndroid Build Coastguard Worker k,
230*da0073e9SAndroid Build Coastguard Worker min(range_i - i, tile_i),
231*da0073e9SAndroid Build Coastguard Worker min(range_j - j, tile_j),
232*da0073e9SAndroid Build Coastguard Worker min(range_k - k, tile_k));
233*da0073e9SAndroid Build Coastguard Worker }
234*da0073e9SAndroid Build Coastguard Worker }
235*da0073e9SAndroid Build Coastguard Worker }
236*da0073e9SAndroid Build Coastguard Worker } else {
237*da0073e9SAndroid Build Coastguard Worker /* Execute in parallel on the thread pool using linearized index */
238*da0073e9SAndroid Build Coastguard Worker const size_t tile_range_i = divide_round_up(range_i, tile_i);
239*da0073e9SAndroid Build Coastguard Worker const size_t tile_range_j = divide_round_up(range_j, tile_j);
240*da0073e9SAndroid Build Coastguard Worker const size_t tile_range_k = divide_round_up(range_k, tile_k);
241*da0073e9SAndroid Build Coastguard Worker TORCH_DCHECK_LE(
242*da0073e9SAndroid Build Coastguard Worker tile_range_i * tile_range_j * tile_range_k,
243*da0073e9SAndroid Build Coastguard Worker (size_t)std::numeric_limits<int>::max());
244*da0073e9SAndroid Build Coastguard Worker struct compute_3d_tiled_context context = {
245*da0073e9SAndroid Build Coastguard Worker /*.function = */ function,
246*da0073e9SAndroid Build Coastguard Worker /*.argument = */ argument,
247*da0073e9SAndroid Build Coastguard Worker /*.tile_range_j = */ caffe2::FixedDivisor<int>(tile_range_j),
248*da0073e9SAndroid Build Coastguard Worker /*.tile_range_k = */ caffe2::FixedDivisor<int>(tile_range_k),
249*da0073e9SAndroid Build Coastguard Worker /*.range_i = */ range_i,
250*da0073e9SAndroid Build Coastguard Worker /*.range_j = */ range_j,
251*da0073e9SAndroid Build Coastguard Worker /*.range_k = */ range_k,
252*da0073e9SAndroid Build Coastguard Worker /*.tile_i = */ tile_i,
253*da0073e9SAndroid Build Coastguard Worker /*.tile_j = */ tile_j,
254*da0073e9SAndroid Build Coastguard Worker /*.tile_k = */ tile_k};
255*da0073e9SAndroid Build Coastguard Worker legacy_pthreadpool_compute_1d(
256*da0073e9SAndroid Build Coastguard Worker threadpool,
257*da0073e9SAndroid Build Coastguard Worker (legacy_pthreadpool_function_1d_t)compute_3d_tiled,
258*da0073e9SAndroid Build Coastguard Worker &context,
259*da0073e9SAndroid Build Coastguard Worker tile_range_i * tile_range_j * tile_range_k);
260*da0073e9SAndroid Build Coastguard Worker }
261*da0073e9SAndroid Build Coastguard Worker }
262*da0073e9SAndroid Build Coastguard Worker
263*da0073e9SAndroid Build Coastguard Worker struct compute_4d_tiled_context {
264*da0073e9SAndroid Build Coastguard Worker legacy_pthreadpool_function_4d_tiled_t function;
265*da0073e9SAndroid Build Coastguard Worker void* argument;
266*da0073e9SAndroid Build Coastguard Worker caffe2::FixedDivisor<int32_t> tile_range_kl;
267*da0073e9SAndroid Build Coastguard Worker caffe2::FixedDivisor<int32_t> tile_range_j;
268*da0073e9SAndroid Build Coastguard Worker caffe2::FixedDivisor<int32_t> tile_range_l;
269*da0073e9SAndroid Build Coastguard Worker size_t range_i;
270*da0073e9SAndroid Build Coastguard Worker size_t range_j;
271*da0073e9SAndroid Build Coastguard Worker size_t range_k;
272*da0073e9SAndroid Build Coastguard Worker size_t range_l;
273*da0073e9SAndroid Build Coastguard Worker size_t tile_i;
274*da0073e9SAndroid Build Coastguard Worker size_t tile_j;
275*da0073e9SAndroid Build Coastguard Worker size_t tile_k;
276*da0073e9SAndroid Build Coastguard Worker size_t tile_l;
277*da0073e9SAndroid Build Coastguard Worker };
278*da0073e9SAndroid Build Coastguard Worker
compute_4d_tiled(void * context_,size_t linear_index)279*da0073e9SAndroid Build Coastguard Worker static void compute_4d_tiled(
280*da0073e9SAndroid Build Coastguard Worker void* context_,
281*da0073e9SAndroid Build Coastguard Worker size_t linear_index) {
282*da0073e9SAndroid Build Coastguard Worker int32_t tile_index_ij, tile_index_kl;
283*da0073e9SAndroid Build Coastguard Worker const struct compute_4d_tiled_context* context = static_cast<compute_4d_tiled_context*>(context_);
284*da0073e9SAndroid Build Coastguard Worker context->tile_range_kl.DivMod(
285*da0073e9SAndroid Build Coastguard Worker static_cast<int32_t>(linear_index), &tile_index_ij, &tile_index_kl);
286*da0073e9SAndroid Build Coastguard Worker int32_t tile_index_i, tile_index_j;
287*da0073e9SAndroid Build Coastguard Worker context->tile_range_j.DivMod(tile_index_ij, &tile_index_i, &tile_index_j);
288*da0073e9SAndroid Build Coastguard Worker int32_t tile_index_k, tile_index_l;
289*da0073e9SAndroid Build Coastguard Worker context->tile_range_l.DivMod(tile_index_kl, &tile_index_k, &tile_index_l);
290*da0073e9SAndroid Build Coastguard Worker const size_t max_tile_i = context->tile_i;
291*da0073e9SAndroid Build Coastguard Worker const size_t max_tile_j = context->tile_j;
292*da0073e9SAndroid Build Coastguard Worker const size_t max_tile_k = context->tile_k;
293*da0073e9SAndroid Build Coastguard Worker const size_t max_tile_l = context->tile_l;
294*da0073e9SAndroid Build Coastguard Worker const size_t index_i = static_cast<uint32_t>(tile_index_i) * max_tile_i;
295*da0073e9SAndroid Build Coastguard Worker const size_t index_j = static_cast<uint32_t>(tile_index_j) * max_tile_j;
296*da0073e9SAndroid Build Coastguard Worker const size_t index_k = static_cast<uint32_t>(tile_index_k) * max_tile_k;
297*da0073e9SAndroid Build Coastguard Worker const size_t index_l = static_cast<uint32_t>(tile_index_l) * max_tile_l;
298*da0073e9SAndroid Build Coastguard Worker const size_t tile_i = min(max_tile_i, context->range_i - index_i);
299*da0073e9SAndroid Build Coastguard Worker const size_t tile_j = min(max_tile_j, context->range_j - index_j);
300*da0073e9SAndroid Build Coastguard Worker const size_t tile_k = min(max_tile_k, context->range_k - index_k);
301*da0073e9SAndroid Build Coastguard Worker const size_t tile_l = min(max_tile_l, context->range_l - index_l);
302*da0073e9SAndroid Build Coastguard Worker context->function(
303*da0073e9SAndroid Build Coastguard Worker context->argument,
304*da0073e9SAndroid Build Coastguard Worker index_i,
305*da0073e9SAndroid Build Coastguard Worker index_j,
306*da0073e9SAndroid Build Coastguard Worker index_k,
307*da0073e9SAndroid Build Coastguard Worker index_l,
308*da0073e9SAndroid Build Coastguard Worker tile_i,
309*da0073e9SAndroid Build Coastguard Worker tile_j,
310*da0073e9SAndroid Build Coastguard Worker tile_k,
311*da0073e9SAndroid Build Coastguard Worker tile_l);
312*da0073e9SAndroid Build Coastguard Worker }
313*da0073e9SAndroid Build Coastguard Worker
legacy_pthreadpool_compute_4d_tiled(legacy_pthreadpool_t threadpool,legacy_pthreadpool_function_4d_tiled_t function,void * argument,size_t range_i,size_t range_j,size_t range_k,size_t range_l,size_t tile_i,size_t tile_j,size_t tile_k,size_t tile_l)314*da0073e9SAndroid Build Coastguard Worker void legacy_pthreadpool_compute_4d_tiled(
315*da0073e9SAndroid Build Coastguard Worker legacy_pthreadpool_t threadpool,
316*da0073e9SAndroid Build Coastguard Worker legacy_pthreadpool_function_4d_tiled_t function,
317*da0073e9SAndroid Build Coastguard Worker void* argument,
318*da0073e9SAndroid Build Coastguard Worker size_t range_i,
319*da0073e9SAndroid Build Coastguard Worker size_t range_j,
320*da0073e9SAndroid Build Coastguard Worker size_t range_k,
321*da0073e9SAndroid Build Coastguard Worker size_t range_l,
322*da0073e9SAndroid Build Coastguard Worker size_t tile_i,
323*da0073e9SAndroid Build Coastguard Worker size_t tile_j,
324*da0073e9SAndroid Build Coastguard Worker size_t tile_k,
325*da0073e9SAndroid Build Coastguard Worker size_t tile_l) {
326*da0073e9SAndroid Build Coastguard Worker if (threadpool == nullptr) {
327*da0073e9SAndroid Build Coastguard Worker /* No thread pool provided: execute function sequentially on the calling
328*da0073e9SAndroid Build Coastguard Worker * thread */
329*da0073e9SAndroid Build Coastguard Worker for (size_t i = 0; i < range_i; i += tile_i) {
330*da0073e9SAndroid Build Coastguard Worker for (size_t j = 0; j < range_j; j += tile_j) {
331*da0073e9SAndroid Build Coastguard Worker for (size_t k = 0; k < range_k; k += tile_k) {
332*da0073e9SAndroid Build Coastguard Worker for (size_t l = 0; l < range_l; l += tile_l) {
333*da0073e9SAndroid Build Coastguard Worker function(
334*da0073e9SAndroid Build Coastguard Worker argument,
335*da0073e9SAndroid Build Coastguard Worker i,
336*da0073e9SAndroid Build Coastguard Worker j,
337*da0073e9SAndroid Build Coastguard Worker k,
338*da0073e9SAndroid Build Coastguard Worker l,
339*da0073e9SAndroid Build Coastguard Worker min(range_i - i, tile_i),
340*da0073e9SAndroid Build Coastguard Worker min(range_j - j, tile_j),
341*da0073e9SAndroid Build Coastguard Worker min(range_k - k, tile_k),
342*da0073e9SAndroid Build Coastguard Worker min(range_l - l, tile_l));
343*da0073e9SAndroid Build Coastguard Worker }
344*da0073e9SAndroid Build Coastguard Worker }
345*da0073e9SAndroid Build Coastguard Worker }
346*da0073e9SAndroid Build Coastguard Worker }
347*da0073e9SAndroid Build Coastguard Worker } else {
348*da0073e9SAndroid Build Coastguard Worker /* Execute in parallel on the thread pool using linearized index */
349*da0073e9SAndroid Build Coastguard Worker const size_t tile_range_i = divide_round_up(range_i, tile_i);
350*da0073e9SAndroid Build Coastguard Worker const size_t tile_range_j = divide_round_up(range_j, tile_j);
351*da0073e9SAndroid Build Coastguard Worker const size_t tile_range_k = divide_round_up(range_k, tile_k);
352*da0073e9SAndroid Build Coastguard Worker const size_t tile_range_l = divide_round_up(range_l, tile_l);
353*da0073e9SAndroid Build Coastguard Worker TORCH_DCHECK_LE(
354*da0073e9SAndroid Build Coastguard Worker tile_range_i * tile_range_j * tile_range_k * tile_range_l,
355*da0073e9SAndroid Build Coastguard Worker (size_t)std::numeric_limits<int>::max());
356*da0073e9SAndroid Build Coastguard Worker struct compute_4d_tiled_context context = {
357*da0073e9SAndroid Build Coastguard Worker /*.function = */ function,
358*da0073e9SAndroid Build Coastguard Worker /*.argument = */ argument,
359*da0073e9SAndroid Build Coastguard Worker /*.tile_range_kl = */
360*da0073e9SAndroid Build Coastguard Worker caffe2::FixedDivisor<int>(tile_range_k * tile_range_l),
361*da0073e9SAndroid Build Coastguard Worker /*.tile_range_j = */ caffe2::FixedDivisor<int>(tile_range_j),
362*da0073e9SAndroid Build Coastguard Worker /*.tile_range_l = */ caffe2::FixedDivisor<int>(tile_range_l),
363*da0073e9SAndroid Build Coastguard Worker /*.range_i = */ range_i,
364*da0073e9SAndroid Build Coastguard Worker /*.range_j = */ range_j,
365*da0073e9SAndroid Build Coastguard Worker /*.range_k = */ range_k,
366*da0073e9SAndroid Build Coastguard Worker /*.range_l = */ range_l,
367*da0073e9SAndroid Build Coastguard Worker /*.tile_i = */ tile_i,
368*da0073e9SAndroid Build Coastguard Worker /*.tile_j = */ tile_j,
369*da0073e9SAndroid Build Coastguard Worker /*.tile_k = */ tile_k,
370*da0073e9SAndroid Build Coastguard Worker /*.tile_l = */ tile_l};
371*da0073e9SAndroid Build Coastguard Worker legacy_pthreadpool_compute_1d(
372*da0073e9SAndroid Build Coastguard Worker threadpool,
373*da0073e9SAndroid Build Coastguard Worker (legacy_pthreadpool_function_1d_t)compute_4d_tiled,
374*da0073e9SAndroid Build Coastguard Worker &context,
375*da0073e9SAndroid Build Coastguard Worker tile_range_i * tile_range_j * tile_range_k * tile_range_l);
376*da0073e9SAndroid Build Coastguard Worker }
377*da0073e9SAndroid Build Coastguard Worker }
378