1 // Copyright (c) 2021 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/dataflow.h"
16
17 #include <map>
18 #include <set>
19
20 #include "gtest/gtest.h"
21 #include "opt/function_utils.h"
22 #include "source/opt/build_module.h"
23
24 namespace spvtools {
25 namespace opt {
26 namespace {
27
28 using DataFlowTest = ::testing::Test;
29
30 // Simple analyses for testing:
31
32 // Stores the result IDs of visited instructions in visit order.
33 struct VisitOrder : public ForwardDataFlowAnalysis {
34 std::vector<uint32_t> visited_result_ids;
35
VisitOrderspvtools::opt::__anon3f9560080111::VisitOrder36 VisitOrder(IRContext& context, LabelPosition label_position)
37 : ForwardDataFlowAnalysis(context, label_position) {}
38
Visitspvtools::opt::__anon3f9560080111::VisitOrder39 VisitResult Visit(Instruction* inst) override {
40 if (inst->HasResultId()) {
41 visited_result_ids.push_back(inst->result_id());
42 }
43 return DataFlowAnalysis::VisitResult::kResultFixed;
44 }
45 };
46
47 // For each block, stores the set of blocks it can be preceded by.
48 // For example, with the following CFG:
49 // V-----------.
50 // -> 11 -> 12 -> 13 -> 15
51 // \-> 14 ---^
52 //
53 // The answer is:
54 // 11: 11, 12, 13
55 // 12: 11, 12, 13
56 // 13: 11, 12, 13
57 // 14: 11, 12, 13
58 // 15: 11, 12, 13, 14
59 struct BackwardReachability : public ForwardDataFlowAnalysis {
60 std::map<uint32_t, std::set<uint32_t>> reachable_from;
61
BackwardReachabilityspvtools::opt::__anon3f9560080111::BackwardReachability62 BackwardReachability(IRContext& context)
63 : ForwardDataFlowAnalysis(
64 context, ForwardDataFlowAnalysis::LabelPosition::kLabelsOnly) {}
65
Visitspvtools::opt::__anon3f9560080111::BackwardReachability66 VisitResult Visit(Instruction* inst) override {
67 // Conditional branches can be enqueued from labels, so skip them.
68 if (inst->opcode() != spv::Op::OpLabel)
69 return DataFlowAnalysis::VisitResult::kResultFixed;
70 uint32_t id = inst->result_id();
71 VisitResult ret = DataFlowAnalysis::VisitResult::kResultFixed;
72 std::set<uint32_t>& precedents = reachable_from[id];
73 for (uint32_t pred : context().cfg()->preds(id)) {
74 bool pred_inserted = precedents.insert(pred).second;
75 if (pred_inserted) {
76 ret = DataFlowAnalysis::VisitResult::kResultChanged;
77 }
78 for (uint32_t block : reachable_from[pred]) {
79 bool inserted = precedents.insert(block).second;
80 if (inserted) {
81 ret = DataFlowAnalysis::VisitResult::kResultChanged;
82 }
83 }
84 }
85 return ret;
86 }
87
InitializeWorklistspvtools::opt::__anon3f9560080111::BackwardReachability88 void InitializeWorklist(Function* function,
89 bool is_first_iteration) override {
90 // Since successor function is exact, only need one pass.
91 if (is_first_iteration) {
92 ForwardDataFlowAnalysis::InitializeWorklist(function, true);
93 }
94 }
95 };
96
TEST_F(DataFlowTest, ReversePostOrder) {
  // Labels and result ids are intentionally out of order.
  //
  // CFG (the first target of each conditional branch is drawn lower):
  //          V-----------.
  // -> 50 -> 40 -> 20 -> 60 -> 70
  //            \-> 30 ---^
  //
  // DFS tree, each block annotated with its reverse-post-order index:
  // -> 50[0] -> 40[1] -> 20[2]    60[4] -> 70[5]
  //                  \-> 30[3] ---^
  const std::string assembly = R"(
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %2 "main"
OpExecutionMode %2 OriginUpperLeft
OpSource GLSL 430
%3 = OpTypeVoid
%4 = OpTypeFunction %3
%6 = OpTypeBool
%5 = OpConstantTrue %6
%2 = OpFunction %3 None %4
%50 = OpLabel
%51 = OpUndef %6
%52 = OpUndef %6
OpBranch %40
%70 = OpLabel
%69 = OpUndef %6
OpReturn
%60 = OpLabel
%61 = OpUndef %6
OpBranchConditional %5 %70 %40
%30 = OpLabel
%29 = OpUndef %6
OpBranch %60
%20 = OpLabel
%21 = OpUndef %6
OpBranch %60
%40 = OpLabel
%39 = OpUndef %6
OpBranchConditional %5 %30 %20
OpFunctionEnd
)";

  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, assembly,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  ASSERT_NE(context, nullptr);

  Function* function = spvtest::GetFunction(context->module(), 2);

  // Expected visit order of result ids for each label placement.
  using Position = ForwardDataFlowAnalysis::LabelPosition;
  const std::map<Position, std::vector<uint32_t>> expected_order = {
      {Position::kLabelsOnly, {50, 40, 20, 30, 60, 70}},
      {Position::kLabelsAtBeginning,
       {50, 51, 52, 40, 39, 20, 21, 30, 29, 60, 61, 70, 69}},
      {Position::kLabelsAtEnd,
       {51, 52, 50, 39, 40, 21, 20, 29, 30, 61, 60, 69, 70}},
      {Position::kNoLabels, {51, 52, 39, 21, 29, 61, 69}},
  };

  for (const auto& entry : expected_order) {
    VisitOrder analysis(*context, entry.first);
    analysis.Run(function);
    EXPECT_EQ(entry.second, analysis.visited_result_ids);
  }
}
171
TEST_F(DataFlowTest, BackwardReachability) {
  // CFG:
  //      V-----------.
  //   -> 11 -> 12 -> 13 -> 15
  //              \-> 14 ---^
  const std::string assembly = R"(
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %2 "main"
OpExecutionMode %2 OriginUpperLeft
OpSource GLSL 430
%3 = OpTypeVoid
%4 = OpTypeFunction %3
%6 = OpTypeBool
%5 = OpConstantTrue %6
%2 = OpFunction %3 None %4
%11 = OpLabel
OpBranch %12
%12 = OpLabel
OpBranchConditional %5 %14 %13
%13 = OpLabel
OpBranchConditional %5 %15 %11
%14 = OpLabel
OpBranch %15
%15 = OpLabel
OpReturn
OpFunctionEnd
)";

  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, assembly,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  ASSERT_NE(context, nullptr);

  BackwardReachability analysis(*context);
  analysis.Run(spvtest::GetFunction(context->module(), 2));

  // Label id -> every block that can reach it (including itself on a cycle).
  const std::map<uint32_t, std::set<uint32_t>> expected_result = {
      {11, {11, 12, 13}},
      {12, {11, 12, 13}},
      {13, {11, 12, 13}},
      {14, {11, 12, 13}},
      {15, {11, 12, 13, 14}},
  };
  EXPECT_EQ(expected_result, analysis.reachable_from);
}
221
222 } // namespace
223 } // namespace opt
224 } // namespace spvtools
225