// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stddef.h>  // For size_t, NULL.
#include <stdint.h>  // For uint32_t, SIZE_MAX.
#include <string.h>  // For memset.

#include <xnnpack.h>            // For xnn_caches_t, xnn_operator_t.
#include <xnnpack/allocator.h>  // For XNN_ALLOCATION_ALIGNMENT.
#include <xnnpack/cache.h>      // For xnn_caches.
#include <xnnpack/math.h>       // For doz, divide_round_up.
#include <xnnpack/operator.h>   // For xnn_operator definition.

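// Returns a pointer to a buffer of aligned_weights_size bytes into which packed
// weights can be written: space reserved from the operator's weights cache when
// one is in use, or freshly allocated SIMD-aligned memory otherwise. The buffer
// is pre-filled with padding_byte. Returns NULL if the reservation or
// allocation fails.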
void* xnn_get_pointer_to_write_weights(
  xnn_operator_t op,
  size_t aligned_weights_size,
  int padding_byte)
{
  assert(aligned_weights_size % XNN_ALLOCATION_ALIGNMENT == 0);
  void* weights_ptr = NULL;
  if (use_weights_cache(op)) {
    weights_ptr = xnn_reserve_space_in_weights_cache(op->weights_cache, aligned_weights_size);
    if (weights_ptr == NULL) {
      return NULL;
    }
  } else {
    op->packed_weights.pointer = xnn_allocate_simd_memory(aligned_weights_size);
    if (op->packed_weights.pointer == NULL) {
      return NULL;
    }
    weights_ptr = op->packed_weights.pointer;
  }
  memset(weights_ptr, padding_byte, aligned_weights_size);
  return weights_ptr;
}

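// Computes the output size of a convolution along a single dimension:
//   effective_kernel = (kernel - 1) * dilation + 1
//   output = (padded_input - effective_kernel) / subsampling + 1
// with the subtraction saturating at zero (doz).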
size_t xnn_compute_convolution_output_dimension(
  size_t padded_input_dimension,
  size_t kernel_dimension,
  size_t dilation_dimension,
  size_t subsampling_dimension)
{
  const size_t effective_kernel_dimension = (kernel_dimension - 1) * dilation_dimension + 1;
  return doz(padded_input_dimension, effective_kernel_dimension) / subsampling_dimension + 1;
}

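// Computes the output size of a deconvolution along a single dimension:
//   effective_kernel = (kernel - 1) * dilation + 1
//   output = stride * (input - 1) + adjustment + effective_kernel - output_padding
// with the subtraction saturating at zero (doz).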
size_t xnn_compute_deconvolution_output_dimension(
  size_t input_dimension,
  size_t output_padding_dimension,
  size_t adjustment_dimension,
  size_t kernel_dimension,
  size_t dilation_dimension,
  size_t stride_dimension)
{
  const size_t effective_kernel_dimension = (kernel_dimension - 1) * dilation_dimension + 1;
  return doz(
    stride_dimension * (input_dimension - 1) + adjustment_dimension + effective_kernel_dimension,
    output_padding_dimension);
}

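// Computes the output size of an unpooling operator along a single dimension,
// i.e. a deconvolution with unit dilation, no adjustment, and stride equal to
// the kernel size.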
size_t xnn_compute_unpooling_output_dimension(
  size_t input_dimension,
  size_t input_padding_dimension,
  size_t kernel_dimension)
{
  return xnn_compute_deconvolution_output_dimension(
    input_dimension, input_padding_dimension, /*adjustment_dimension=*/0,
    kernel_dimension, /*dilation_dimension=*/1, /*stride_dimension=*/kernel_dimension);
}

// Calculate how much work a microkernel does.
// An MxN microkernel does M+N (scalar) loads and M*N (scalar) FMAs.
// So, given batch_size, the microkernel does:
//   divide_round_up(batch_size, mr) * (mr + nr) loads, and
//   divide_round_up(batch_size, mr) * (mr * nr) FMAs.
// The total cost is then a linear combination of these 2 operations. From experimental data, use a multiplier of 3
// for loads, to prefer higher tile sizes which have better computational intensity.
static size_t calculate_microkernel_cost(size_t batch_size, uint32_t mr, uint32_t nr)
{
  return divide_round_up(batch_size, mr) * (3 * (mr + nr) + mr * nr);
}

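// Picks the MR (row tile size) of the GEMM microkernel to use for a given batch
// size: the exact match when a microkernel with MR == batch_size exists,
// otherwise the available MR with the lowest cost according to
// calculate_microkernel_cost. The microkernel with MR == max_mr must be
// available.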
uint32_t xnn_get_heuristic_mr_gemm(
  size_t batch_size, uint32_t max_mr, uint32_t nr, struct xnn_hmp_gemm_ukernel *gemm_cases)
{
  assert(gemm_cases[max_mr-1].function[XNN_UARCH_DEFAULT] != NULL);
  if (batch_size <= max_mr && gemm_cases[batch_size-1].function[XNN_UARCH_DEFAULT] != NULL) {
    // We have a microkernel with MR that is the exact match with batch_size.
    return batch_size;
  }

  // Try to find the best fitting mr.
  // - use a cost heuristic to calculate how much work is done by the microkernel (see calculate_microkernel_cost)
  // - smaller cost is better
  uint32_t best_mr = max_mr;
  size_t best_cost = SIZE_MAX;
  for (uint32_t mr = 1; mr <= max_mr; mr++) {
    if (gemm_cases[mr-1].function[XNN_UARCH_DEFAULT] == NULL) {
      continue;
    }
    const size_t current_cost = calculate_microkernel_cost(batch_size, mr, nr);
    if (current_cost <= best_cost) {
      best_mr = mr;
      best_cost = current_cost;
    }
  }
  assert(gemm_cases[best_mr-1].function[XNN_UARCH_DEFAULT] != NULL);
  return best_mr;
}

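// Same MR selection heuristic as xnn_get_heuristic_mr_gemm, applied to the
// IGEMM microkernel table.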
uint32_t xnn_get_heuristic_mr_igemm(
  size_t batch_size, uint32_t max_mr, uint32_t nr, struct xnn_hmp_igemm_ukernel *igemm_cases)
{
  assert(igemm_cases[max_mr-1].function[XNN_UARCH_DEFAULT] != NULL);
  if (batch_size <= max_mr && igemm_cases[batch_size-1].function[XNN_UARCH_DEFAULT] != NULL) {
    // We have a microkernel with MR that is the exact match with batch_size.
    return batch_size;
  }

  // Try to find the best fitting mr.
  // - use a cost heuristic to calculate how much work is done by the microkernel (see calculate_microkernel_cost)
  // - smaller cost is better
  uint32_t best_mr = max_mr;
  size_t best_cost = SIZE_MAX;
  for (uint32_t mr = 1; mr <= max_mr; mr++) {
    if (igemm_cases[mr-1].function[XNN_UARCH_DEFAULT] == NULL) {
      continue;
    }
    const size_t current_cost = calculate_microkernel_cost(batch_size, mr, nr);
    if (current_cost <= best_cost) {
      best_mr = mr;
      best_cost = current_cost;
    }
  }
  assert(igemm_cases[best_mr-1].function[XNN_UARCH_DEFAULT] != NULL);
  return best_mr;
}