1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/workgroup_selection.h"
17
18 #include <math.h>
19
20 #include <cmath>
21 #include <set>
22 #include <vector>
23
24 #include "tensorflow/lite/delegates/gpu/common/util.h"
25
26 namespace tflite {
27 namespace gpu {
28
29 namespace {
30
31 template <typename T>
AddCornerCases(const T & grid,int max_work_group_total_size,const T & max_work_group_sizes,WorkGroupSizeAlignment x_alignment,WorkGroupSizeAlignment y_alignment,WorkGroupSizeAlignment z_alignment,std::vector<T> * work_groups)32 void AddCornerCases(const T& grid, int max_work_group_total_size,
33 const T& max_work_group_sizes,
34 WorkGroupSizeAlignment x_alignment,
35 WorkGroupSizeAlignment y_alignment,
36 WorkGroupSizeAlignment z_alignment,
37 std::vector<T>* work_groups) {
38 for (int x = 1; x <= 4; ++x) {
39 for (int y = 1; y <= 4; ++y) {
40 for (int z = 1; z <= 4; ++z) {
41 int wg_x = DivideRoundUp(grid.x, x);
42 int wg_y = DivideRoundUp(grid.y, y);
43 int wg_z = DivideRoundUp(grid.z, z);
44 if (wg_x > max_work_group_sizes.x || wg_y > max_work_group_sizes.y ||
45 wg_z > max_work_group_sizes.z ||
46 wg_x * wg_y * wg_z > max_work_group_total_size) {
47 continue;
48 }
49 if (x_alignment == WorkGroupSizeAlignment::PRECISE &&
50 grid.x % wg_x != 0) {
51 continue;
52 }
53 if (y_alignment == WorkGroupSizeAlignment::PRECISE &&
54 grid.y % wg_y != 0) {
55 continue;
56 }
57 if (z_alignment == WorkGroupSizeAlignment::PRECISE &&
58 grid.z % wg_z != 0) {
59 continue;
60 }
61 work_groups->push_back({wg_x, wg_y, wg_z});
62 }
63 }
64 }
65
66 // this will add at least {1, 1, 1} always.
67 for (int x = 1; x <= 4; ++x) {
68 for (int y = 1; y <= 4; ++y) {
69 for (int z = 1; z <= 4; ++z) {
70 if (x > max_work_group_sizes.x || y > max_work_group_sizes.y ||
71 z > max_work_group_sizes.z ||
72 x * y * z > max_work_group_total_size) {
73 continue;
74 }
75 if (x_alignment == WorkGroupSizeAlignment::PRECISE && grid.x % x != 0) {
76 continue;
77 }
78 if (y_alignment == WorkGroupSizeAlignment::PRECISE && grid.y % y != 0) {
79 continue;
80 }
81 if (z_alignment == WorkGroupSizeAlignment::PRECISE && grid.z % z != 0) {
82 continue;
83 }
84 work_groups->push_back({x, y, z});
85 }
86 }
87 }
88 }
89
// Returns all positive divisors of `number`, emitted in complementary pairs
// (a small divisor immediately followed by its cofactor), so the result is
// NOT sorted, e.g. GetDivisors(12) -> {1, 12, 2, 6, 3, 4}.
std::vector<int> GetDivisors(int number) {
  const int sqrt_bound = static_cast<int>(std::sqrt(number));
  std::vector<int> result;
  // The divisor count is unknown up front, so the reserve is a heuristic.
  result.reserve(sqrt_bound / 3 + 1);
  for (int small = 1; small <= sqrt_bound; ++small) {
    if (number % small != 0) continue;
    result.push_back(small);
    const int paired = number / small;
    // Avoid pushing the square root twice for perfect squares.
    if (paired != small) {
      result.push_back(paired);
    }
  }
  return result;
}
106
// Returns, in ascending order, every integer d that divides at least one
// value in [number, number + range]. Any such d pairs with a cofactor, and
// one member of the pair is <= sqrt(number + range), so scanning small
// candidates and inserting both members finds every divisor.
std::vector<int> GetDivisorsForRange(int number, int range) {
  const int last_number = number + range;
  const int sqrt_bound = static_cast<int>(std::sqrt(last_number));
  std::set<int> found;
  for (int d = 1; d <= sqrt_bound; ++d) {
    const int remainder = number % d;
    // Smallest multiple of d that is >= number.
    const int first_multiple = number + (d - remainder) % d;
    if (first_multiple <= last_number) {
      found.insert(d);
    }
    // Each multiple in range contributes its cofactor, which can exceed
    // sqrt_bound.
    for (int multiple = first_multiple; multiple <= last_number;
         multiple += d) {
      const int cofactor = multiple / d;
      if (cofactor != d) {
        found.insert(cofactor);
      }
    }
  }
  return std::vector<int>(found.begin(), found.end());
}
127
128 } // namespace
129
GetPossibleSizes(int number,WorkGroupSizeAlignment z_alignment)130 std::vector<int> GetPossibleSizes(int number,
131 WorkGroupSizeAlignment z_alignment) {
132 if (z_alignment == WorkGroupSizeAlignment::PRECISE) {
133 // we will use for potential sizes, sizes that cover grid precisely
134 // work group size * k (k is integer) == grid_size
135 return GetDivisors(number);
136 } else {
137 // when we chose work group size we can use work group size that
138 // work group size * k (k is integer) != grid_size (slightly bigger)
139 // so in this heuristic we trying to find potential size, that satisfies
140 // to this : work group size * k (k is integer) <= grid_size + 5
141 // and this : work group size * k (k is integer) >= grid_size
142 return GetDivisorsForRange(number, 5);
143 }
144 }
145
146 template <typename T>
GenerateWorkGroupSizes(const T & grid,int min_work_group_total_size,int max_work_group_total_size,const T & max_work_group_sizes,WorkGroupSizeAlignment x_alignment,WorkGroupSizeAlignment y_alignment,WorkGroupSizeAlignment z_alignment)147 std::vector<T> GenerateWorkGroupSizes(
148 const T& grid, int min_work_group_total_size, int max_work_group_total_size,
149 const T& max_work_group_sizes, WorkGroupSizeAlignment x_alignment,
150 WorkGroupSizeAlignment y_alignment, WorkGroupSizeAlignment z_alignment) {
151 std::vector<T> work_groups;
152 work_groups.reserve(64);
153
154 std::vector<int> sizes_x = GetPossibleSizes(grid.x, x_alignment);
155 std::vector<int> sizes_y = GetPossibleSizes(grid.y, y_alignment);
156 std::vector<int> sizes_z = GetPossibleSizes(grid.z, z_alignment);
157
158 for (auto x : sizes_x) {
159 if (x > max_work_group_sizes.x) continue;
160 for (auto y : sizes_y) {
161 if (y > max_work_group_sizes.y) continue;
162 for (auto z : sizes_z) {
163 if (z > max_work_group_sizes.z) continue;
164 const int work_group_size = x * y * z;
165 if (work_group_size < min_work_group_total_size ||
166 work_group_size > max_work_group_total_size)
167 continue;
168 work_groups.push_back({x, y, z});
169 }
170 }
171 }
172
173 return work_groups;
174 }
175
// Explicit instantiations of GenerateWorkGroupSizes for int3 and uint3.

template std::vector<int3> GenerateWorkGroupSizes(
    const int3& grid, int min_work_group_total_size,
    int max_work_group_total_size, const int3& max_work_group_sizes,
    WorkGroupSizeAlignment x_alignment, WorkGroupSizeAlignment y_alignment,
    WorkGroupSizeAlignment z_alignment);

template std::vector<uint3> GenerateWorkGroupSizes(
    const uint3& grid, int min_work_group_total_size,
    int max_work_group_total_size, const uint3& max_work_group_sizes,
    WorkGroupSizeAlignment x_alignment, WorkGroupSizeAlignment y_alignment,
    WorkGroupSizeAlignment z_alignment);
189
190 template <typename T>
GenerateWorkGroupSizesAlignedToGrid(const T & grid,const T & max_work_group_size,const int max_work_group_total_size,std::vector<T> * work_groups)191 void GenerateWorkGroupSizesAlignedToGrid(const T& grid,
192 const T& max_work_group_size,
193 const int max_work_group_total_size,
194 std::vector<T>* work_groups) {
195 auto alignment = WorkGroupSizeAlignment::PRECISE;
196 *work_groups = GenerateWorkGroupSizes<T>(
197 grid, /*min_work_group_total_size = */ 32, max_work_group_total_size,
198 max_work_group_size, alignment, alignment, alignment);
199 // If the grid parameter too small, method below cannot generate workgroups.
200 if (work_groups->empty()) {
201 AddCornerCases(grid, max_work_group_total_size, max_work_group_size,
202 alignment, alignment, alignment, work_groups);
203 }
204 }
205
// Explicit instantiations of GenerateWorkGroupSizesAlignedToGrid for int3
// and uint3.

template void GenerateWorkGroupSizesAlignedToGrid(
    const int3& grid, const int3& max_work_group_size,
    const int max_work_group_total_size, std::vector<int3>* work_groups);

template void GenerateWorkGroupSizesAlignedToGrid(
    const uint3& grid, const uint3& max_work_group_size,
    const int max_work_group_total_size, std::vector<uint3>* work_groups);
215
216 } // namespace gpu
217 } // namespace tflite
218