1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <string>
16 #include <unordered_set>
17 #include <vector>
18 
19 #include "gmock/gmock.h"
20 #include "test/opt/pass_fixture.h"
21 #include "test/opt/pass_utils.h"
22 
23 namespace spvtools {
24 namespace opt {
25 namespace {
26 
27 using ::testing::HasSubstr;
28 using ::testing::MatchesRegex;
29 using StrengthReductionBasicTest = PassTest<::testing::Test>;
30 
31 // Test to make sure we replace 5*8.
TEST_F(StrengthReductionBasicTest, BasicReplaceMulBy8) {
  // One OpIMul of two constants where the second operand (8) is a power of 2.
  const std::vector<const char*> module_lines = {
      // clang-format off
               "OpCapability Shader",
          "%1 = OpExtInstImport \"GLSL.std.450\"",
               "OpMemoryModel Logical GLSL450",
               "OpEntryPoint Vertex %main \"main\"",
               "OpName %main \"main\"",
       "%void = OpTypeVoid",
          "%4 = OpTypeFunction %void",
       "%uint = OpTypeInt 32 0",
     "%uint_5 = OpConstant %uint 5",
     "%uint_8 = OpConstant %uint 8",
       "%main = OpFunction %void None %4",
          "%8 = OpLabel",
          "%9 = OpIMul %uint %uint_5 %uint_8",
               "OpReturn",
               "OpFunctionEnd"
      // clang-format on
  };
  const std::string assembly = JoinAllInsts(module_lines);

  const auto run_result = SinglePassRunAndDisassemble<StrengthReductionPass>(
      assembly, /* skip_nop = */ true, /* do_validation = */ false);
  const Pass::Status status = std::get<1>(run_result);
  const std::string& disassembly = std::get<0>(run_result);

  EXPECT_EQ(Pass::Status::SuccessWithChange, status);
  // 5 * 8 is expected to become 5 << 3, leaving no multiply behind.
  EXPECT_THAT(disassembly, Not(HasSubstr("OpIMul")));
  EXPECT_THAT(disassembly,
              HasSubstr("OpShiftLeftLogical %uint %uint_5 %uint_3"));
}
61 
// TODO(dneto): Add Effcee as required dependency, and make this unconditional.
// Test to make sure we replace 16*5.
// This test also demonstrates the use of Effcee matching.
TEST_F(StrengthReductionBasicTest, BasicReplaceMulBy16) {
  // The module text doubles as the Effcee check script: lines starting with
  // ';' are assembly comments, and the CHECK directives among them are
  // matched against the disassembled output of the pass.
  const std::string annotated_module = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Vertex %main "main"
               OpName %main "main"
       %void = OpTypeVoid
          %4 = OpTypeFunction %void
; We know disassembly will produce %uint here, but
;  CHECK: %uint = OpTypeInt 32 0
;  CHECK-DAG: [[five:%[a-zA-Z_\d]+]] = OpConstant %uint 5

; We have RE2 regular expressions, so \w matches [_a-zA-Z0-9].
; This shows the preferred pattern for matching SPIR-V identifiers.
; (We could have cheated in this case since we know the disassembler will
; generate the 'nice' name of "%uint_4".
;  CHECK-DAG: [[four:%\w+]] = OpConstant %uint 4
       %uint = OpTypeInt 32 0
     %uint_5 = OpConstant %uint 5
    %uint_16 = OpConstant %uint 16
       %main = OpFunction %void None %4
; CHECK: OpLabel
          %8 = OpLabel
; CHECK-NEXT: OpShiftLeftLogical %uint [[five]] [[four]]
; The multiplication disappears.
; CHECK-NOT: OpIMul
          %9 = OpIMul %uint %uint_16 %uint_5
               OpReturn
; CHECK: OpFunctionEnd
               OpFunctionEnd)";

  // Here the power of two (16) is the first operand, so the other operand
  // (5) is expected to be shifted left by 4.
  SinglePassRunAndMatch<StrengthReductionPass>(annotated_module, false);
}
99 
// Test to make sure we replace a multiplication of 32 by 4.
TEST_F(StrengthReductionBasicTest, BasicTwoPowersOf2) {
  // Both operands (32 and 4) are powers of 2.  Exactly one of them should be
  // turned into the shift amount; here 32 becomes a shift of the other
  // operand by 5.
  // clang-format off
  const std::string module_text = R"(
          OpCapability Shader
     %1 = OpExtInstImport "GLSL.std.450"
          OpMemoryModel Logical GLSL450
          OpEntryPoint Vertex %main "main"
          OpName %main "main"
  %void = OpTypeVoid
     %4 = OpTypeFunction %void
   %int = OpTypeInt 32 1
%int_32 = OpConstant %int 32
 %int_4 = OpConstant %int 4
  %main = OpFunction %void None %4
     %8 = OpLabel
     %9 = OpIMul %int %int_32 %int_4
          OpReturn
          OpFunctionEnd
)";
  // clang-format on
  const auto run_result = SinglePassRunAndDisassemble<StrengthReductionPass>(
      module_text, /* skip_nop = */ true, /* do_validation = */ false);
  const Pass::Status status = std::get<1>(run_result);
  const std::string& disassembly = std::get<0>(run_result);

  EXPECT_EQ(Pass::Status::SuccessWithChange, status);
  EXPECT_THAT(disassembly, Not(HasSubstr("OpIMul")));
  // The result keeps the signed %int type while the shift amount is a %uint
  // constant.
  EXPECT_THAT(disassembly,
              HasSubstr("OpShiftLeftLogical %int %int_4 %uint_5"));
}
131 
132 // Test to make sure we don't replace 0*5.
TEST_F(StrengthReductionBasicTest, BasicDontReplace0) {
  // Neither 0 nor 5 is a power of 2, so the multiply must be left untouched.
  const std::vector<const char*> module_lines = {
      // clang-format off
               "OpCapability Shader",
          "%1 = OpExtInstImport \"GLSL.std.450\"",
               "OpMemoryModel Logical GLSL450",
               "OpEntryPoint Vertex %main \"main\"",
               "OpName %main \"main\"",
       "%void = OpTypeVoid",
          "%4 = OpTypeFunction %void",
        "%int = OpTypeInt 32 1",
      "%int_0 = OpConstant %int 0",
      "%int_5 = OpConstant %int 5",
       "%main = OpFunction %void None %4",
          "%8 = OpLabel",
          "%9 = OpIMul %int %int_0 %int_5",
               "OpReturn",
               "OpFunctionEnd"
      // clang-format on
  };

  const Pass::Status status =
      std::get<1>(SinglePassRunAndDisassemble<StrengthReductionPass>(
          JoinAllInsts(module_lines), /* skip_nop = */ true,
          /* do_validation = */ false));

  EXPECT_EQ(Pass::Status::SuccessWithoutChange, status);
}
159 
// Test to make sure we do not replace a multiplication of 5 by 7.
TEST_F(StrengthReductionBasicTest, BasicNoChange) {
  // 5 * 7 has no power-of-2 operand, so the pass should report no change.
  // The module uses raw numeric ids throughout.
  const std::vector<const char*> module_lines = {
      // clang-format off
             "OpCapability Shader",
        "%1 = OpExtInstImport \"GLSL.std.450\"",
             "OpMemoryModel Logical GLSL450",
             "OpEntryPoint Vertex %2 \"main\"",
             "OpName %2 \"main\"",
        "%3 = OpTypeVoid",
        "%4 = OpTypeFunction %3",
        "%5 = OpTypeInt 32 1",
        "%6 = OpTypeInt 32 0",
        "%7 = OpConstant %5 5",
        "%8 = OpConstant %5 7",
        "%2 = OpFunction %3 None %4",
        "%9 = OpLabel",
       "%10 = OpIMul %5 %7 %8",
             "OpReturn",
             "OpFunctionEnd",
      // clang-format on
  };

  const Pass::Status status =
      std::get<1>(SinglePassRunAndDisassemble<StrengthReductionPass>(
          JoinAllInsts(module_lines), /* skip_nop = */ true,
          /* do_validation = */ false));

  EXPECT_EQ(Pass::Status::SuccessWithoutChange, status);
}
188 
189 // Test to make sure constants and types are reused and not duplicated.
TEST_F(StrengthReductionBasicTest, NoDuplicateConstantsAndTypes) {
  // 8 * 3 becomes 3 << 3.  The shift amount 3 already exists as %uint_3 and
  // its type as %uint, so the pass must reuse both instead of adding copies.
  const std::vector<const char*> text = {
      // clang-format off
               "OpCapability Shader",
          "%1 = OpExtInstImport \"GLSL.std.450\"",
               "OpMemoryModel Logical GLSL450",
               "OpEntryPoint Vertex %main \"main\"",
               "OpName %main \"main\"",
       "%void = OpTypeVoid",
          "%4 = OpTypeFunction %void",
       "%uint = OpTypeInt 32 0",
     "%uint_8 = OpConstant %uint 8",
     "%uint_3 = OpConstant %uint 3",
       "%main = OpFunction %void None %4",
          "%8 = OpLabel",
          "%9 = OpIMul %uint %uint_8 %uint_3",
               "OpReturn",
               "OpFunctionEnd",
      // clang-format on
  };

  auto result = SinglePassRunAndDisassemble<StrengthReductionPass>(
      JoinAllInsts(text), /* skip_nop = */ true, /* do_validation = */ false);

  EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result));
  const std::string& output = std::get<0>(result);

  // Counts non-overlapping occurrences of |needle| in |output|.  This replaces
  // Not(MatchesRegex(".*X.*X.*")): GoogleTest's simple regex engine (used on
  // Windows) does not let '.' match '\n', so on multi-line disassembly those
  // negative assertions passed vacuously and could never detect a duplicate.
  const auto count_occurrences = [&output](const std::string& needle) {
    int count = 0;
    for (std::string::size_type pos = output.find(needle);
         pos != std::string::npos;
         pos = output.find(needle, pos + needle.size())) {
      ++count;
    }
    return count;
  };

  EXPECT_THAT(output, Not(HasSubstr("OpIMul")));
  // A duplicated shift constant would get a different friendly name, so this
  // also checks that the existing %uint_3 is the one being used.
  EXPECT_THAT(output, HasSubstr("OpShiftLeftLogical %uint %uint_3 %uint_3"));
  EXPECT_EQ(1, count_occurrences("OpConstant %uint 3"));
  EXPECT_EQ(1, count_occurrences("OpTypeInt 32 0"));
}
220 
// Test to make sure we generate the shift-amount constant only once, even when
// two multiplications need it.
TEST_F(StrengthReductionBasicTest, BasicCreateOneConst) {
  // Two multiplies by 128 each need a shift amount of 7, which is not in the
  // input.  If the pass created it twice, the second copy would get a
  // different friendly name and one of the checks below would fail.
  const std::vector<const char*> module_lines = {
      // clang-format off
               "OpCapability Shader",
          "%1 = OpExtInstImport \"GLSL.std.450\"",
               "OpMemoryModel Logical GLSL450",
               "OpEntryPoint Vertex %main \"main\"",
               "OpName %main \"main\"",
       "%void = OpTypeVoid",
          "%4 = OpTypeFunction %void",
       "%uint = OpTypeInt 32 0",
     "%uint_5 = OpConstant %uint 5",
     "%uint_9 = OpConstant %uint 9",
   "%uint_128 = OpConstant %uint 128",
       "%main = OpFunction %void None %4",
          "%8 = OpLabel",
          "%9 = OpIMul %uint %uint_5 %uint_128",
         "%10 = OpIMul %uint %uint_9 %uint_128",
               "OpReturn",
               "OpFunctionEnd"
      // clang-format on
  };
  const std::string assembly = JoinAllInsts(module_lines);

  const auto run_result = SinglePassRunAndDisassemble<StrengthReductionPass>(
      assembly, /* skip_nop = */ true, /* do_validation = */ false);
  const Pass::Status status = std::get<1>(run_result);
  const std::string& disassembly = std::get<0>(run_result);

  EXPECT_EQ(Pass::Status::SuccessWithChange, status);
  EXPECT_THAT(disassembly, Not(HasSubstr("OpIMul")));
  for (const char* expected_shift :
       {"OpShiftLeftLogical %uint %uint_5 %uint_7",
        "OpShiftLeftLogical %uint %uint_9 %uint_7"}) {
    EXPECT_THAT(disassembly, HasSubstr(expected_shift));
  }
}
254 
// Test to make sure we generate the instructions in the correct position and
// that the uses get replaced as well.  Here we check that the use in the return
// is replaced; we also check that we can replace two OpIMuls when one feeds the
// other.
TEST_F(StrengthReductionBasicTest,BasicCheckPositionAndReplacement)259 TEST_F(StrengthReductionBasicTest, BasicCheckPositionAndReplacement) {
260   // This is just the preamble to set up the test.
261   const std::vector<const char*> common_text = {
262       // clang-format off
263                "OpCapability Shader",
264           "%1 = OpExtInstImport \"GLSL.std.450\"",
265                "OpMemoryModel Logical GLSL450",
266                "OpEntryPoint Fragment %main \"main\" %gl_FragColor",
267                "OpExecutionMode %main OriginUpperLeft",
268                "OpName %main \"main\"",
269                "OpName %foo_i1_ \"foo(i1;\"",
270                "OpName %n \"n\"",
271                "OpName %gl_FragColor \"gl_FragColor\"",
272                "OpName %param \"param\"",
273                "OpDecorate %gl_FragColor Location 0",
274        "%void = OpTypeVoid",
275           "%3 = OpTypeFunction %void",
276         "%int = OpTypeInt 32 1",
277 "%_ptr_Function_int = OpTypePointer Function %int",
278           "%8 = OpTypeFunction %int %_ptr_Function_int",
279     "%int_256 = OpConstant %int 256",
280       "%int_2 = OpConstant %int 2",
281       "%float = OpTypeFloat 32",
282     "%v4float = OpTypeVector %float 4",
283 "%_ptr_Output_v4float = OpTypePointer Output %v4float",
284 "%gl_FragColor = OpVariable %_ptr_Output_v4float Output",
285     "%float_1 = OpConstant %float 1",
286      "%int_10 = OpConstant %int 10",
287   "%float_0_375 = OpConstant %float 0.375",
288   "%float_0_75 = OpConstant %float 0.75",
289        "%uint = OpTypeInt 32 0",
290      "%uint_8 = OpConstant %uint 8",
291      "%uint_1 = OpConstant %uint 1",
292        "%main = OpFunction %void None %3",
293           "%5 = OpLabel",
294       "%param = OpVariable %_ptr_Function_int Function",
295                "OpStore %param %int_10",
296          "%26 = OpFunctionCall %int %foo_i1_ %param",
297          "%27 = OpConvertSToF %float %26",
298          "%28 = OpFDiv %float %float_1 %27",
299          "%31 = OpCompositeConstruct %v4float %28 %float_0_375 %float_0_75 %float_1",
300                "OpStore %gl_FragColor %31",
301                "OpReturn",
302                "OpFunctionEnd"
303       // clang-format on
304   };
305 
306   // This is the real test.  The two OpIMul should be replaced.  The expected
307   // output is in |foo_after|.
308   const std::vector<const char*> foo_before = {
309       // clang-format off
310     "%foo_i1_ = OpFunction %int None %8",
311           "%n = OpFunctionParameter %_ptr_Function_int",
312          "%11 = OpLabel",
313          "%12 = OpLoad %int %n",
314          "%14 = OpIMul %int %12 %int_256",
315          "%16 = OpIMul %int %14 %int_2",
316                "OpReturnValue %16",
317                "OpFunctionEnd",
318 
319       // clang-format on
320   };
321 
322   const std::vector<const char*> foo_after = {
323       // clang-format off
324     "%foo_i1_ = OpFunction %int None %8",
325           "%n = OpFunctionParameter %_ptr_Function_int",
326          "%11 = OpLabel",
327          "%12 = OpLoad %int %n",
328          "%33 = OpShiftLeftLogical %int %12 %uint_8",
329          "%34 = OpShiftLeftLogical %int %33 %uint_1",
330                "OpReturnValue %34",
331                "OpFunctionEnd",
332       // clang-format on
333   };
334 
335   SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
336   SinglePassRunAndCheck<StrengthReductionPass>(
337       JoinAllInsts(Concat(common_text, foo_before)),
338       JoinAllInsts(Concat(common_text, foo_after)),
339       /* skip_nop = */ true, /* do_validate = */ true);
340 }
341 
// Test that, when the result of an OpIMul instruction has more than one use and
// the instruction is replaced, all uses of that result are replaced with the
// new result.
TEST_F(StrengthReductionBasicTest,BasicTestMultipleReplacements)345 TEST_F(StrengthReductionBasicTest, BasicTestMultipleReplacements) {
346   // This is just the preamble to set up the test.
347   const std::vector<const char*> common_text = {
348       // clang-format off
349                "OpCapability Shader",
350           "%1 = OpExtInstImport \"GLSL.std.450\"",
351                "OpMemoryModel Logical GLSL450",
352                "OpEntryPoint Fragment %main \"main\" %gl_FragColor",
353                "OpExecutionMode %main OriginUpperLeft",
354                "OpName %main \"main\"",
355                "OpName %foo_i1_ \"foo(i1;\"",
356                "OpName %n \"n\"",
357                "OpName %gl_FragColor \"gl_FragColor\"",
358                "OpName %param \"param\"",
359                "OpDecorate %gl_FragColor Location 0",
360        "%void = OpTypeVoid",
361           "%3 = OpTypeFunction %void",
362         "%int = OpTypeInt 32 1",
363 "%_ptr_Function_int = OpTypePointer Function %int",
364           "%8 = OpTypeFunction %int %_ptr_Function_int",
365     "%int_256 = OpConstant %int 256",
366       "%int_2 = OpConstant %int 2",
367       "%float = OpTypeFloat 32",
368     "%v4float = OpTypeVector %float 4",
369 "%_ptr_Output_v4float = OpTypePointer Output %v4float",
370 "%gl_FragColor = OpVariable %_ptr_Output_v4float Output",
371     "%float_1 = OpConstant %float 1",
372      "%int_10 = OpConstant %int 10",
373   "%float_0_375 = OpConstant %float 0.375",
374   "%float_0_75 = OpConstant %float 0.75",
375        "%uint = OpTypeInt 32 0",
376      "%uint_8 = OpConstant %uint 8",
377      "%uint_1 = OpConstant %uint 1",
378        "%main = OpFunction %void None %3",
379           "%5 = OpLabel",
380       "%param = OpVariable %_ptr_Function_int Function",
381                "OpStore %param %int_10",
382          "%26 = OpFunctionCall %int %foo_i1_ %param",
383          "%27 = OpConvertSToF %float %26",
384          "%28 = OpFDiv %float %float_1 %27",
385          "%31 = OpCompositeConstruct %v4float %28 %float_0_375 %float_0_75 %float_1",
386                "OpStore %gl_FragColor %31",
387                "OpReturn",
388                "OpFunctionEnd"
389       // clang-format on
390   };
391 
392   // This is the real test.  The two OpIMul instructions should be replaced.  In
393   // particular, we want to be sure that both uses of %16 are changed to use the
394   // new result.
395   const std::vector<const char*> foo_before = {
396       // clang-format off
397     "%foo_i1_ = OpFunction %int None %8",
398           "%n = OpFunctionParameter %_ptr_Function_int",
399          "%11 = OpLabel",
400          "%12 = OpLoad %int %n",
401          "%14 = OpIMul %int %12 %int_256",
402          "%16 = OpIMul %int %14 %int_2",
403          "%17 = OpIAdd %int %14 %16",
404                "OpReturnValue %17",
405                "OpFunctionEnd",
406 
407       // clang-format on
408   };
409 
410   const std::vector<const char*> foo_after = {
411       // clang-format off
412     "%foo_i1_ = OpFunction %int None %8",
413           "%n = OpFunctionParameter %_ptr_Function_int",
414          "%11 = OpLabel",
415          "%12 = OpLoad %int %n",
416          "%34 = OpShiftLeftLogical %int %12 %uint_8",
417          "%35 = OpShiftLeftLogical %int %34 %uint_1",
418          "%17 = OpIAdd %int %34 %35",
419                "OpReturnValue %17",
420                "OpFunctionEnd",
421       // clang-format on
422   };
423 
424   SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
425   SinglePassRunAndCheck<StrengthReductionPass>(
426       JoinAllInsts(Concat(common_text, foo_before)),
427       JoinAllInsts(Concat(common_text, foo_after)),
428       /* skip_nop = */ true, /* do_validate = */ true);
429 }
430 
431 }  // namespace
432 }  // namespace opt
433 }  // namespace spvtools
434