1 // Copyright (c) 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "test/reduce/reduce_test_util.h"
16 
17 #include <iostream>
18 
19 #include "tools/io.h"
20 
21 namespace spvtools {
22 namespace reduce {
23 
24 const spvtools::MessageConsumer kConsoleMessageConsumer =
25     [](spv_message_level_t level, const char*, const spv_position_t& position,
__anon0c65bff50102(spv_message_level_t level, const char*, const spv_position_t& position, const char* message) 26        const char* message) -> void {
27   switch (level) {
28     case SPV_MSG_FATAL:
29     case SPV_MSG_INTERNAL_ERROR:
30     case SPV_MSG_ERROR:
31       std::cerr << "error: line " << position.index << ": " << message
32                 << std::endl;
33       break;
34     case SPV_MSG_WARNING:
35       std::cout << "warning: line " << position.index << ": " << message
36                 << std::endl;
37       break;
38     case SPV_MSG_INFO:
39       std::cout << "info: line " << position.index << ": " << message
40                 << std::endl;
41       break;
42     default:
43       break;
44   }
45 };
46 
CheckEqual(const spv_target_env env,const std::vector<uint32_t> & expected_binary,const std::vector<uint32_t> & actual_binary)47 void CheckEqual(const spv_target_env env,
48                 const std::vector<uint32_t>& expected_binary,
49                 const std::vector<uint32_t>& actual_binary) {
50   if (expected_binary != actual_binary) {
51     SpirvTools t(env);
52     std::string expected_disassembled;
53     std::string actual_disassembled;
54     ASSERT_TRUE(t.Disassemble(expected_binary, &expected_disassembled,
55                               kReduceDisassembleOption));
56     ASSERT_TRUE(t.Disassemble(actual_binary, &actual_disassembled,
57                               kReduceDisassembleOption));
58     ASSERT_EQ(expected_disassembled, actual_disassembled);
59   }
60 }
61 
CheckEqual(const spv_target_env env,const std::string & expected_text,const std::vector<uint32_t> & actual_binary)62 void CheckEqual(const spv_target_env env, const std::string& expected_text,
63                 const std::vector<uint32_t>& actual_binary) {
64   std::vector<uint32_t> expected_binary;
65   SpirvTools t(env);
66   ASSERT_TRUE(
67       t.Assemble(expected_text, &expected_binary, kReduceAssembleOption));
68   CheckEqual(env, expected_binary, actual_binary);
69 }
70 
CheckEqual(const spv_target_env env,const std::string & expected_text,const opt::IRContext * actual_ir)71 void CheckEqual(const spv_target_env env, const std::string& expected_text,
72                 const opt::IRContext* actual_ir) {
73   std::vector<uint32_t> actual_binary;
74   actual_ir->module()->ToBinary(&actual_binary, false);
75   CheckEqual(env, expected_text, actual_binary);
76 }
77 
CheckValid(spv_target_env env,const opt::IRContext * ir)78 void CheckValid(spv_target_env env, const opt::IRContext* ir) {
79   std::vector<uint32_t> binary;
80   ir->module()->ToBinary(&binary, false);
81   SpirvTools tools(env);
82   tools.SetMessageConsumer(kConsoleMessageConsumer);
83   ASSERT_TRUE(tools.Validate(binary));
84 }
85 
ToString(spv_target_env env,const opt::IRContext * ir)86 std::string ToString(spv_target_env env, const opt::IRContext* ir) {
87   std::vector<uint32_t> binary;
88   ir->module()->ToBinary(&binary, false);
89   SpirvTools t(env);
90   std::string result;
91   t.Disassemble(binary, &result, kReduceDisassembleOption);
92   return result;
93 }
94 
NopDiagnostic(spv_message_level_t,const char *,const spv_position_t &,const char *)95 void NopDiagnostic(spv_message_level_t /*level*/, const char* /*source*/,
96                    const spv_position_t& /*position*/,
97                    const char* /*message*/) {}
98 
CLIMessageConsumer(spv_message_level_t level,const char *,const spv_position_t & position,const char * message)99 void CLIMessageConsumer(spv_message_level_t level, const char*,
100                         const spv_position_t& position, const char* message) {
101   switch (level) {
102     case SPV_MSG_FATAL:
103     case SPV_MSG_INTERNAL_ERROR:
104     case SPV_MSG_ERROR:
105       std::cerr << "error: line " << position.index << ": " << message
106                 << std::endl;
107       break;
108     case SPV_MSG_WARNING:
109       std::cout << "warning: line " << position.index << ": " << message
110                 << std::endl;
111       break;
112     case SPV_MSG_INFO:
113       std::cout << "info: line " << position.index << ": " << message
114                 << std::endl;
115       break;
116     default:
117       break;
118   }
119 }
120 
DumpShader(opt::IRContext * context,const char * filename)121 void DumpShader(opt::IRContext* context, const char* filename) {
122   std::vector<uint32_t> binary;
123   context->module()->ToBinary(&binary, false);
124   DumpShader(binary, filename);
125 }
126 
DumpShader(const std::vector<uint32_t> & binary,const char * filename)127 void DumpShader(const std::vector<uint32_t>& binary, const char* filename) {
128   auto write_file_succeeded =
129       WriteFile(filename, "wb", &binary[0], binary.size());
130   if (!write_file_succeeded) {
131     std::cerr << "Failed to dump shader" << std::endl;
132   }
133 }
134 
135 }  // namespace reduce
136 }  // namespace spvtools
137