xref: /aosp_15_r20/external/swiftshader/third_party/SPIRV-Tools/test/opt/dataflow.cpp (revision 03ce13f70fcc45d86ee91b7ee4cab1936a95046e)
1 // Copyright (c) 2021 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/dataflow.h"
16 
17 #include <map>
18 #include <set>
19 
20 #include "gtest/gtest.h"
21 #include "opt/function_utils.h"
22 #include "source/opt/build_module.h"
23 
24 namespace spvtools {
25 namespace opt {
26 namespace {
27 
28 using DataFlowTest = ::testing::Test;
29 
30 // Simple analyses for testing:
31 
32 // Stores the result IDs of visited instructions in visit order.
33 struct VisitOrder : public ForwardDataFlowAnalysis {
34   std::vector<uint32_t> visited_result_ids;
35 
VisitOrderspvtools::opt::__anon3f9560080111::VisitOrder36   VisitOrder(IRContext& context, LabelPosition label_position)
37       : ForwardDataFlowAnalysis(context, label_position) {}
38 
Visitspvtools::opt::__anon3f9560080111::VisitOrder39   VisitResult Visit(Instruction* inst) override {
40     if (inst->HasResultId()) {
41       visited_result_ids.push_back(inst->result_id());
42     }
43     return DataFlowAnalysis::VisitResult::kResultFixed;
44   }
45 };
46 
47 // For each block, stores the set of blocks it can be preceded by.
48 // For example, with the following CFG:
49 //    V-----------.
50 // -> 11 -> 12 -> 13 -> 15
51 //            \-> 14 ---^
52 //
53 // The answer is:
54 // 11: 11, 12, 13
55 // 12: 11, 12, 13
56 // 13: 11, 12, 13
57 // 14: 11, 12, 13
58 // 15: 11, 12, 13, 14
59 struct BackwardReachability : public ForwardDataFlowAnalysis {
60   std::map<uint32_t, std::set<uint32_t>> reachable_from;
61 
BackwardReachabilityspvtools::opt::__anon3f9560080111::BackwardReachability62   BackwardReachability(IRContext& context)
63       : ForwardDataFlowAnalysis(
64             context, ForwardDataFlowAnalysis::LabelPosition::kLabelsOnly) {}
65 
Visitspvtools::opt::__anon3f9560080111::BackwardReachability66   VisitResult Visit(Instruction* inst) override {
67     // Conditional branches can be enqueued from labels, so skip them.
68     if (inst->opcode() != spv::Op::OpLabel)
69       return DataFlowAnalysis::VisitResult::kResultFixed;
70     uint32_t id = inst->result_id();
71     VisitResult ret = DataFlowAnalysis::VisitResult::kResultFixed;
72     std::set<uint32_t>& precedents = reachable_from[id];
73     for (uint32_t pred : context().cfg()->preds(id)) {
74       bool pred_inserted = precedents.insert(pred).second;
75       if (pred_inserted) {
76         ret = DataFlowAnalysis::VisitResult::kResultChanged;
77       }
78       for (uint32_t block : reachable_from[pred]) {
79         bool inserted = precedents.insert(block).second;
80         if (inserted) {
81           ret = DataFlowAnalysis::VisitResult::kResultChanged;
82         }
83       }
84     }
85     return ret;
86   }
87 
InitializeWorklistspvtools::opt::__anon3f9560080111::BackwardReachability88   void InitializeWorklist(Function* function,
89                           bool is_first_iteration) override {
90     // Since successor function is exact, only need one pass.
91     if (is_first_iteration) {
92       ForwardDataFlowAnalysis::InitializeWorklist(function, true);
93     }
94   }
95 };
96 
TEST_F(DataFlowTest, ReversePostOrder) {
  // Note: labels and IDs are intentionally out of order.
  //
  // CFG: (order of branches is from bottom to top)
  //          V-----------.
  // -> 50 -> 40 -> 20 -> 60 -> 70
  //            \-> 30 ---^

  // DFS tree with RPO numbering:
  // -> 50[0] -> 40[1] -> 20[2]    60[4] -> 70[5]
  //                  \-> 30[3] ---^

  const std::string text = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %2 "main"
               OpExecutionMode %2 OriginUpperLeft
               OpSource GLSL 430
          %3 = OpTypeVoid
          %4 = OpTypeFunction %3
          %6 = OpTypeBool
          %5 = OpConstantTrue %6
          %2 = OpFunction %3 None %4
         %50 = OpLabel
         %51 = OpUndef %6
         %52 = OpUndef %6
               OpBranch %40
         %70 = OpLabel
         %69 = OpUndef %6
               OpReturn
         %60 = OpLabel
         %61 = OpUndef %6
               OpBranchConditional %5 %70 %40
         %30 = OpLabel
         %29 = OpUndef %6
               OpBranch %60
         %20 = OpLabel
         %21 = OpUndef %6
               OpBranch %60
         %40 = OpLabel
         %39 = OpUndef %6
               OpBranchConditional %5 %30 %20
               OpFunctionEnd
  )";

  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  ASSERT_NE(context, nullptr);

  Function* function = spvtest::GetFunction(context->module(), 2);
  // Fail cleanly rather than running the analysis on a null function should
  // the lookup come back empty.
  ASSERT_NE(function, nullptr);

  // Expected visit order (result IDs) for each label placement. Blocks appear
  // in the RPO order drawn above; within a block, non-label instructions keep
  // program order and the label is placed per the LabelPosition. Branches
  // have no result ID and are therefore never recorded.
  std::map<ForwardDataFlowAnalysis::LabelPosition, std::vector<uint32_t>>
      expected_order;
  expected_order[ForwardDataFlowAnalysis::LabelPosition::kLabelsOnly] = {
      50, 40, 20, 30, 60, 70,
  };
  expected_order[ForwardDataFlowAnalysis::LabelPosition::kLabelsAtBeginning] = {
      50, 51, 52, 40, 39, 20, 21, 30, 29, 60, 61, 70, 69,
  };
  expected_order[ForwardDataFlowAnalysis::LabelPosition::kLabelsAtEnd] = {
      51, 52, 50, 39, 40, 21, 20, 29, 30, 61, 60, 69, 70,
  };
  expected_order[ForwardDataFlowAnalysis::LabelPosition::kNoLabels] = {
      51, 52, 39, 21, 29, 61, 69,
  };

  for (const auto& test_case : expected_order) {
    // Tag any failure with the label placement that produced it; otherwise
    // the four cases are indistinguishable in the test output.
    SCOPED_TRACE(::testing::Message()
                 << "LabelPosition = " << static_cast<int>(test_case.first));
    VisitOrder analysis(*context, test_case.first);
    analysis.Run(function);
    EXPECT_EQ(test_case.second, analysis.visited_result_ids);
  }
}
171 
TEST_F(DataFlowTest, BackwardReachability) {
  // CFG:
  //    V-----------.
  // -> 11 -> 12 -> 13 -> 15
  //            \-> 14 ---^

  const std::string text = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %2 "main"
               OpExecutionMode %2 OriginUpperLeft
               OpSource GLSL 430
          %3 = OpTypeVoid
          %4 = OpTypeFunction %3
          %6 = OpTypeBool
          %5 = OpConstantTrue %6
          %2 = OpFunction %3 None %4
         %11 = OpLabel
               OpBranch %12
         %12 = OpLabel
               OpBranchConditional %5 %14 %13
         %13 = OpLabel
               OpBranchConditional %5 %15 %11
         %14 = OpLabel
               OpBranch %15
         %15 = OpLabel
               OpReturn
               OpFunctionEnd
  )";

  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  ASSERT_NE(context, nullptr);

  Function* function = spvtest::GetFunction(context->module(), 2);
  // Fail cleanly rather than running the analysis on a null function should
  // the lookup come back empty.
  ASSERT_NE(function, nullptr);

  BackwardReachability analysis(*context);
  analysis.Run(function);

  // 11, 12 and 13 form a cycle (via the back edge 13 -> 11), so each of them,
  // and 14 hanging off 12, is reachable from all three. 15 is additionally
  // reachable from 14. Nothing is reachable from 15 (it returns), so 15 never
  // appears in any set.
  std::map<uint32_t, std::set<uint32_t>> expected_result;
  expected_result[11] = {11, 12, 13};
  expected_result[12] = {11, 12, 13};
  expected_result[13] = {11, 12, 13};
  expected_result[14] = {11, 12, 13};
  expected_result[15] = {11, 12, 13, 14};
  EXPECT_EQ(expected_result, analysis.reachable_from);
}
221 
222 }  // namespace
223 }  // namespace opt
224 }  // namespace spvtools
225