1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <memory>
16 #include <vector>
17 
18 #include "gmock/gmock.h"
19 #include "source/opt/loop_descriptor.h"
20 #include "source/opt/pass.h"
21 #include "test/opt/assembly_builder.h"
22 #include "test/opt/function_utils.h"
23 #include "test/opt/pass_fixture.h"
24 #include "test/opt/pass_utils.h"
25 
26 namespace spvtools {
27 namespace opt {
28 namespace {
29 
30 using ::testing::UnorderedElementsAre;
31 using PassClassTest = PassTest<::testing::Test>;
32 
33 /*
34 Generated from the following GLSL
35 #version 330 core
36 layout(location = 0) out vec4 c;
37 void main() {
38   int i = 0;
39   for(; i < 10; ++i) {
40   }
41 }
42 */
TEST_F(PassClassTest,BasicVisitFromEntryPoint)43 TEST_F(PassClassTest, BasicVisitFromEntryPoint) {
44   const std::string text = R"(
45                 OpCapability Shader
46           %1 = OpExtInstImport "GLSL.std.450"
47                OpMemoryModel Logical GLSL450
48                OpEntryPoint Fragment %2 "main" %3
49                OpExecutionMode %2 OriginUpperLeft
50                OpSource GLSL 330
51                OpName %2 "main"
52                OpName %5 "i"
53                OpName %3 "c"
54                OpDecorate %3 Location 0
55           %6 = OpTypeVoid
56           %7 = OpTypeFunction %6
57           %8 = OpTypeInt 32 1
58           %9 = OpTypePointer Function %8
59          %10 = OpConstant %8 0
60          %11 = OpConstant %8 10
61          %12 = OpTypeBool
62          %13 = OpConstant %8 1
63          %14 = OpTypeFloat 32
64          %15 = OpTypeVector %14 4
65          %16 = OpTypePointer Output %15
66           %3 = OpVariable %16 Output
67           %2 = OpFunction %6 None %7
68          %17 = OpLabel
69           %5 = OpVariable %9 Function
70                OpStore %5 %10
71                OpBranch %18
72          %18 = OpLabel
73                OpLoopMerge %19 %20 None
74                OpBranch %21
75          %21 = OpLabel
76          %22 = OpLoad %8 %5
77          %23 = OpSLessThan %12 %22 %11
78                OpBranchConditional %23 %24 %19
79          %24 = OpLabel
80                OpBranch %20
81          %20 = OpLabel
82          %25 = OpLoad %8 %5
83          %26 = OpIAdd %8 %25 %13
84                OpStore %5 %26
85                OpBranch %18
86          %19 = OpLabel
87                OpReturn
88                OpFunctionEnd
89   )";
90   // clang-format on
91   std::unique_ptr<IRContext> context =
92       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
93                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
94   Module* module = context->module();
95   EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
96                              << text << std::endl;
97   const Function* f = spvtest::GetFunction(module, 2);
98   LoopDescriptor& ld = *context->GetLoopDescriptor(f);
99 
100   EXPECT_EQ(ld.NumLoops(), 1u);
101 
102   Loop& loop = ld.GetLoopByIndex(0);
103   EXPECT_EQ(loop.GetHeaderBlock(), spvtest::GetBasicBlock(f, 18));
104   EXPECT_EQ(loop.GetLatchBlock(), spvtest::GetBasicBlock(f, 20));
105   EXPECT_EQ(loop.GetMergeBlock(), spvtest::GetBasicBlock(f, 19));
106 
107   EXPECT_FALSE(loop.HasNestedLoops());
108   EXPECT_FALSE(loop.IsNested());
109   EXPECT_EQ(loop.GetDepth(), 1u);
110 }
111 
112 /*
113 Generated from the following GLSL:
114 #version 330 core
115 layout(location = 0) out vec4 c;
116 void main() {
117   for(int i = 0; i < 10; ++i) {}
118   for(int i = 0; i < 10; ++i) {}
119 }
120 
121 But it was "hacked" to make the first loop merge block the second loop header.
122 */
TEST_F(PassClassTest,LoopWithNoPreHeader)123 TEST_F(PassClassTest, LoopWithNoPreHeader) {
124   const std::string text = R"(
125                OpCapability Shader
126           %1 = OpExtInstImport "GLSL.std.450"
127                OpMemoryModel Logical GLSL450
128                OpEntryPoint Fragment %2 "main" %3
129                OpExecutionMode %2 OriginUpperLeft
130                OpSource GLSL 330
131                OpName %2 "main"
132                OpName %4 "i"
133                OpName %5 "i"
134                OpName %3 "c"
135                OpDecorate %3 Location 0
136           %6 = OpTypeVoid
137           %7 = OpTypeFunction %6
138           %8 = OpTypeInt 32 1
139           %9 = OpTypePointer Function %8
140          %10 = OpConstant %8 0
141          %11 = OpConstant %8 10
142          %12 = OpTypeBool
143          %13 = OpConstant %8 1
144          %14 = OpTypeFloat 32
145          %15 = OpTypeVector %14 4
146          %16 = OpTypePointer Output %15
147           %3 = OpVariable %16 Output
148           %2 = OpFunction %6 None %7
149          %17 = OpLabel
150           %4 = OpVariable %9 Function
151           %5 = OpVariable %9 Function
152                OpStore %4 %10
153                OpStore %5 %10
154                OpBranch %18
155          %18 = OpLabel
156                OpLoopMerge %27 %20 None
157                OpBranch %21
158          %21 = OpLabel
159          %22 = OpLoad %8 %4
160          %23 = OpSLessThan %12 %22 %11
161                OpBranchConditional %23 %24 %27
162          %24 = OpLabel
163                OpBranch %20
164          %20 = OpLabel
165          %25 = OpLoad %8 %4
166          %26 = OpIAdd %8 %25 %13
167                OpStore %4 %26
168                OpBranch %18
169          %27 = OpLabel
170                OpLoopMerge %28 %29 None
171                OpBranch %30
172          %30 = OpLabel
173          %31 = OpLoad %8 %5
174          %32 = OpSLessThan %12 %31 %11
175                OpBranchConditional %32 %33 %28
176          %33 = OpLabel
177                OpBranch %29
178          %29 = OpLabel
179          %34 = OpLoad %8 %5
180          %35 = OpIAdd %8 %34 %13
181                OpStore %5 %35
182                OpBranch %27
183          %28 = OpLabel
184                OpReturn
185                OpFunctionEnd
186   )";
187   // clang-format on
188   std::unique_ptr<IRContext> context =
189       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
190                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
191   Module* module = context->module();
192   EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
193                              << text << std::endl;
194   const Function* f = spvtest::GetFunction(module, 2);
195   LoopDescriptor& ld = *context->GetLoopDescriptor(f);
196 
197   EXPECT_EQ(ld.NumLoops(), 2u);
198 
199   Loop* loop = ld[27];
200   EXPECT_EQ(loop->GetPreHeaderBlock(), nullptr);
201   EXPECT_NE(loop->GetOrCreatePreHeaderBlock(), nullptr);
202 }
203 
204 /*
205 Generated from the following GLSL + --eliminate-local-multi-store
206 
207 #version 330 core
208 in vec4 c;
209 void main() {
210   int i = 0;
211   bool cond = c[0] == 0;
212   for (; i < 10; i++) {
213     if (cond) {
214       return;
215     }
216     else {
217       return;
218     }
219   }
220   bool cond2 = i == 9;
221 }
222 */
TEST_F(PassClassTest,NoLoop)223 TEST_F(PassClassTest, NoLoop) {
224   const std::string text = R"(; SPIR-V
225 ; Version: 1.0
226 ; Generator: Khronos Glslang Reference Front End; 3
227 ; Bound: 47
228 ; Schema: 0
229                OpCapability Shader
230           %1 = OpExtInstImport "GLSL.std.450"
231                OpMemoryModel Logical GLSL450
232                OpEntryPoint Fragment %4 "main" %16
233                OpExecutionMode %4 OriginUpperLeft
234                OpSource GLSL 330
235                OpName %4 "main"
236                OpName %16 "c"
237                OpDecorate %16 Location 0
238           %2 = OpTypeVoid
239           %3 = OpTypeFunction %2
240           %6 = OpTypeInt 32 1
241           %7 = OpTypePointer Function %6
242           %9 = OpConstant %6 0
243          %10 = OpTypeBool
244          %11 = OpTypePointer Function %10
245          %13 = OpTypeFloat 32
246          %14 = OpTypeVector %13 4
247          %15 = OpTypePointer Input %14
248          %16 = OpVariable %15 Input
249          %17 = OpTypeInt 32 0
250          %18 = OpConstant %17 0
251          %19 = OpTypePointer Input %13
252          %22 = OpConstant %13 0
253          %30 = OpConstant %6 10
254          %39 = OpConstant %6 1
255          %46 = OpUndef %6
256           %4 = OpFunction %2 None %3
257           %5 = OpLabel
258          %20 = OpAccessChain %19 %16 %18
259          %21 = OpLoad %13 %20
260          %23 = OpFOrdEqual %10 %21 %22
261                OpBranch %24
262          %24 = OpLabel
263          %45 = OpPhi %6 %9 %5 %40 %27
264                OpLoopMerge %26 %27 None
265                OpBranch %28
266          %28 = OpLabel
267          %31 = OpSLessThan %10 %45 %30
268                OpBranchConditional %31 %25 %26
269          %25 = OpLabel
270                OpSelectionMerge %34 None
271                OpBranchConditional %23 %33 %36
272          %33 = OpLabel
273                OpReturn
274          %36 = OpLabel
275                OpReturn
276          %34 = OpLabel
277                OpBranch %27
278          %27 = OpLabel
279          %40 = OpIAdd %6 %46 %39
280                OpBranch %24
281          %26 = OpLabel
282                OpReturn
283                OpFunctionEnd
284   )";
285 
286   std::unique_ptr<IRContext> context =
287       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
288                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
289   Module* module = context->module();
290   EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
291                              << text << std::endl;
292   const Function* f = spvtest::GetFunction(module, 4);
293   LoopDescriptor ld{context.get(), f};
294 
295   EXPECT_EQ(ld.NumLoops(), 0u);
296 }
297 
298 /*
299 Generated from following GLSL with latch block artificially inserted to be
300 separate from continue.
301 #version 430
302 void main(void) {
303     float x[10];
304     for (int i = 0; i < 10; ++i) {
305       x[i] = i;
306     }
307 }
308 */
TEST_F(PassClassTest,LoopLatchNotContinue)309 TEST_F(PassClassTest, LoopLatchNotContinue) {
310   const std::string text = R"(OpCapability Shader
311           %1 = OpExtInstImport "GLSL.std.450"
312                OpMemoryModel Logical GLSL450
313                OpEntryPoint Fragment %2 "main"
314                OpExecutionMode %2 OriginUpperLeft
315                OpSource GLSL 430
316                OpName %2 "main"
317                OpName %3 "i"
318                OpName %4 "x"
319           %5 = OpTypeVoid
320           %6 = OpTypeFunction %5
321           %7 = OpTypeInt 32 1
322           %8 = OpTypePointer Function %7
323           %9 = OpConstant %7 0
324          %10 = OpConstant %7 10
325          %11 = OpTypeBool
326          %12 = OpTypeFloat 32
327          %13 = OpTypeInt 32 0
328          %14 = OpConstant %13 10
329          %15 = OpTypeArray %12 %14
330          %16 = OpTypePointer Function %15
331          %17 = OpTypePointer Function %12
332          %18 = OpConstant %7 1
333           %2 = OpFunction %5 None %6
334          %19 = OpLabel
335           %3 = OpVariable %8 Function
336           %4 = OpVariable %16 Function
337                OpStore %3 %9
338                OpBranch %20
339          %20 = OpLabel
340          %21 = OpPhi %7 %9 %19 %22 %30
341                OpLoopMerge %24 %23 None
342                OpBranch %25
343          %25 = OpLabel
344          %26 = OpSLessThan %11 %21 %10
345                OpBranchConditional %26 %27 %24
346          %27 = OpLabel
347          %28 = OpConvertSToF %12 %21
348          %29 = OpAccessChain %17 %4 %21
349                OpStore %29 %28
350                OpBranch %23
351          %23 = OpLabel
352          %22 = OpIAdd %7 %21 %18
353                OpStore %3 %22
354                OpBranch %30
355          %30 = OpLabel
356                OpBranch %20
357          %24 = OpLabel
358                OpReturn
359                OpFunctionEnd
360   )";
361 
362   std::unique_ptr<IRContext> context =
363       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
364                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
365   Module* module = context->module();
366   EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
367                              << text << std::endl;
368   const Function* f = spvtest::GetFunction(module, 2);
369   LoopDescriptor ld{context.get(), f};
370 
371   EXPECT_EQ(ld.NumLoops(), 1u);
372 
373   Loop& loop = ld.GetLoopByIndex(0u);
374 
375   EXPECT_NE(loop.GetLatchBlock(), loop.GetContinueBlock());
376 
377   EXPECT_EQ(loop.GetContinueBlock()->id(), 23u);
378   EXPECT_EQ(loop.GetLatchBlock()->id(), 30u);
379 }
380 
TEST_F(PassClassTest,UnreachableMerge)381 TEST_F(PassClassTest, UnreachableMerge) {
382   const std::string text = R"(
383                OpCapability Shader
384                OpMemoryModel Logical GLSL450
385                OpEntryPoint Fragment %1 "main"
386                OpExecutionMode %1 OriginUpperLeft
387        %void = OpTypeVoid
388           %3 = OpTypeFunction %void
389           %1 = OpFunction %void None %3
390           %4 = OpLabel
391                OpBranch %5
392           %5 = OpLabel
393                OpLoopMerge %6 %7 None
394                OpBranch %8
395           %8 = OpLabel
396                OpBranch %9
397           %9 = OpLabel
398                OpBranch %7
399           %7 = OpLabel
400                OpBranch %5
401           %6 = OpLabel
402                OpUnreachable
403                OpFunctionEnd
404 )";
405 
406   std::unique_ptr<IRContext> context =
407       BuildModule(SPV_ENV_UNIVERSAL_1_3, nullptr, text,
408                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
409   Module* module = context->module();
410   EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
411                              << text << std::endl;
412   const Function* f = spvtest::GetFunction(module, 1);
413   LoopDescriptor ld{context.get(), f};
414 
415   EXPECT_EQ(ld.NumLoops(), 1u);
416 }
417 
418 }  // namespace
419 }  // namespace opt
420 }  // namespace spvtools
421