1 // Copyright (c) 2022 The Khronos Group Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // Tests ray tracing instructions from SPV_KHR_ray_tracing.
16 
17 #include <sstream>
18 #include <string>
19 
20 #include "gmock/gmock.h"
21 #include "test/val/val_fixtures.h"
22 
23 namespace spvtools {
24 namespace val {
25 namespace {
26 
27 using ::testing::HasSubstr;
28 using ::testing::Values;
29 
30 using ValidateRayTracing = spvtest::ValidateBase<bool>;
31 
TEST_F(ValidateRayTracing, IgnoreIntersectionSuccess) {
  // An AnyHitKHR entry point whose only block ends in
  // OpIgnoreIntersectionKHR (used here as the block terminator, no OpReturn)
  // must validate cleanly.
  const std::string spirv = R"(
OpCapability RayTracingKHR
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint AnyHitKHR %main "main"
OpName %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%main = OpFunction %void None %func
%label = OpLabel
OpIgnoreIntersectionKHR
OpFunctionEnd
)";

  CompileSuccessfully(spirv.c_str());
  const auto result = ValidateInstructions();
  EXPECT_EQ(SPV_SUCCESS, result);
}
50 
TEST_F(ValidateRayTracing, IgnoreIntersectionExecutionModel) {
  // OpIgnoreIntersectionKHR inside a CallableKHR entry point must be rejected
  // with an execution-model diagnostic.
  const std::string spirv = R"(
OpCapability RayTracingKHR
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint CallableKHR %main "main"
OpName %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%main = OpFunction %void None %func
%label = OpLabel
OpIgnoreIntersectionKHR
OpFunctionEnd
)";

  CompileSuccessfully(spirv.c_str());
  const auto result = ValidateInstructions();
  EXPECT_EQ(SPV_ERROR_INVALID_ID, result);
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("OpIgnoreIntersectionKHR requires AnyHitKHR "
                        "execution model"));
}
72 
TEST_F(ValidateRayTracing, TerminateRaySuccess) {
  // An AnyHitKHR entry point whose only block ends in OpTerminateRayKHR must
  // validate cleanly. OpTerminateRayKHR is used as the block terminator, so
  // no OpReturn follows it.
  //
  // Previously this module used OpIgnoreIntersectionKHR, which duplicated
  // IgnoreIntersectionSuccess and never exercised the OpTerminateRayKHR
  // success path that TerminateRayExecutionModel is the negative
  // counterpart of.
  const std::string body = R"(
OpCapability RayTracingKHR
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint AnyHitKHR %main "main"
OpName %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%main = OpFunction %void None %func
%label = OpLabel
OpTerminateRayKHR
OpFunctionEnd
)";

  CompileSuccessfully(body.c_str());
  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
}
91 
TEST_F(ValidateRayTracing, TerminateRayExecutionModel) {
  // OpTerminateRayKHR inside a MissKHR entry point must be rejected with an
  // execution-model diagnostic.
  const std::string spirv = R"(
OpCapability RayTracingKHR
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint MissKHR %main "main"
OpName %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%main = OpFunction %void None %func
%label = OpLabel
OpTerminateRayKHR
OpFunctionEnd
)";

  CompileSuccessfully(spirv.c_str());
  const auto result = ValidateInstructions();
  EXPECT_EQ(SPV_ERROR_INVALID_ID, result);
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("OpTerminateRayKHR requires AnyHitKHR "
                        "execution model"));
}
113 
TEST_F(ValidateRayTracing, ReportIntersectionRaySuccess) {
  // Well-formed OpReportIntersectionKHR in an IntersectionKHR entry point:
  // bool result, 32-bit float Hit, 32-bit unsigned int Hit Kind.
  const std::string spirv = R"(
OpCapability RayTracingKHR
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint IntersectionKHR %main "main"
OpName %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%float = OpTypeFloat 32
%float_1 = OpConstant %float 1
%uint = OpTypeInt 32 0
%uint_1 = OpConstant %uint 1
%bool = OpTypeBool
%main = OpFunction %void None %func
%label = OpLabel
%report = OpReportIntersectionKHR %bool %float_1 %uint_1
OpReturn
OpFunctionEnd
)";

  CompileSuccessfully(spirv.c_str());
  const auto result = ValidateInstructions();
  EXPECT_EQ(SPV_SUCCESS, result);
}
138 
TEST_F(ValidateRayTracing, ReportIntersectionExecutionModel) {
  // A well-typed OpReportIntersectionKHR placed in a MissKHR entry point must
  // be rejected because of the execution model alone.
  const std::string spirv = R"(
OpCapability RayTracingKHR
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint MissKHR %main "main"
OpName %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%float = OpTypeFloat 32
%float_1 = OpConstant %float 1
%uint = OpTypeInt 32 0
%uint_1 = OpConstant %uint 1
%bool = OpTypeBool
%main = OpFunction %void None %func
%label = OpLabel
%report = OpReportIntersectionKHR %bool %float_1 %uint_1
OpReturn
OpFunctionEnd
)";

  CompileSuccessfully(spirv.c_str());
  const auto result = ValidateInstructions();
  EXPECT_EQ(SPV_ERROR_INVALID_ID, result);
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("OpReportIntersectionKHR requires IntersectionKHR "
                        "execution model"));
}
167 
TEST_F(ValidateRayTracing, ReportIntersectionReturnType) {
  // OpReportIntersectionKHR whose Result Type is %uint instead of a bool
  // scalar must be rejected.
  const std::string spirv = R"(
OpCapability RayTracingKHR
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint IntersectionKHR %main "main"
OpName %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%float = OpTypeFloat 32
%float_1 = OpConstant %float 1
%uint = OpTypeInt 32 0
%uint_1 = OpConstant %uint 1
%main = OpFunction %void None %func
%label = OpLabel
%report = OpReportIntersectionKHR %uint %float_1 %uint_1
OpReturn
OpFunctionEnd
)";

  CompileSuccessfully(spirv.c_str());
  const auto result = ValidateInstructions();
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("expected Result Type to be bool scalar type"));
}
193 
TEST_F(ValidateRayTracing, ReportIntersectionHit) {
  // A 64-bit float Hit operand must be rejected; Float64 is declared only so
  // the 64-bit type itself is legal.
  // NOTE(review): the expected diagnostic says "int" although the operand
  // under test is a float; it mirrors the validator's current message text —
  // confirm against the validator before changing either side.
  const std::string spirv = R"(
OpCapability RayTracingKHR
OpCapability Float64
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint IntersectionKHR %main "main"
OpName %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%float64 = OpTypeFloat 64
%float64_1 = OpConstant %float64 1
%uint = OpTypeInt 32 0
%uint_1 = OpConstant %uint 1
%bool = OpTypeBool
%main = OpFunction %void None %func
%label = OpLabel
%report = OpReportIntersectionKHR %bool %float64_1 %uint_1
OpReturn
OpFunctionEnd
)";

  CompileSuccessfully(spirv.c_str());
  const auto result = ValidateInstructions();
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Hit must be a 32-bit int scalar"));
}
221 
TEST_F(ValidateRayTracing, ReportIntersectionHitKind) {
  // A signed 32-bit int Hit Kind operand must be rejected; only an unsigned
  // 32-bit int scalar is accepted.
  const std::string spirv = R"(
OpCapability RayTracingKHR
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint IntersectionKHR %main "main"
OpName %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%float = OpTypeFloat 32
%float_1 = OpConstant %float 1
%sint = OpTypeInt 32 1
%sint_1 = OpConstant %sint 1
%bool = OpTypeBool
%main = OpFunction %void None %func
%label = OpLabel
%report = OpReportIntersectionKHR %bool %float_1 %sint_1
OpReturn
OpFunctionEnd
)";

  CompileSuccessfully(spirv.c_str());
  const auto result = ValidateInstructions();
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Hit Kind must be a 32-bit unsigned int scalar"));
}
248 
TEST_F(ValidateRayTracing, ExecuteCallableSuccess) {
  // From a CallableKHR entry point, OpExecuteCallableKHR accepts Callable
  // Data in both the CallableDataKHR and IncomingCallableDataKHR storage
  // classes.
  const std::string spirv = R"(
OpCapability RayTracingKHR
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint CallableKHR %main "main"
OpName %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%int = OpTypeInt 32 1
%uint = OpTypeInt 32 0
%uint_0 = OpConstant %uint 0
%data_ptr = OpTypePointer CallableDataKHR %int
%data = OpVariable %data_ptr CallableDataKHR
%inData_ptr = OpTypePointer IncomingCallableDataKHR %int
%inData = OpVariable %inData_ptr IncomingCallableDataKHR
%main = OpFunction %void None %func
%label = OpLabel
OpExecuteCallableKHR %uint_0 %data
OpExecuteCallableKHR %uint_0 %inData
OpReturn
OpFunctionEnd
)";

  CompileSuccessfully(spirv.c_str());
  const auto result = ValidateInstructions();
  EXPECT_EQ(SPV_SUCCESS, result);
}
276 
TEST_F(ValidateRayTracing, ExecuteCallableExecutionModel) {
  // OpExecuteCallableKHR issued from an AnyHitKHR entry point must be
  // rejected with the list of permitted execution models.
  const std::string spirv = R"(
OpCapability RayTracingKHR
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint AnyHitKHR %main "main"
OpName %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%int = OpTypeInt 32 1
%uint = OpTypeInt 32 0
%uint_0 = OpConstant %uint 0
%data_ptr = OpTypePointer CallableDataKHR %int
%data = OpVariable %data_ptr CallableDataKHR
%inData_ptr = OpTypePointer IncomingCallableDataKHR %int
%inData = OpVariable %inData_ptr IncomingCallableDataKHR
%main = OpFunction %void None %func
%label = OpLabel
OpExecuteCallableKHR %uint_0 %data
OpExecuteCallableKHR %uint_0 %inData
OpReturn
OpFunctionEnd
)";

  CompileSuccessfully(spirv.c_str());
  const auto result = ValidateInstructions();
  EXPECT_EQ(SPV_ERROR_INVALID_ID, result);
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("OpExecuteCallableKHR requires RayGenerationKHR, "
                        "ClosestHitKHR, MissKHR and CallableKHR execution "
                        "models"));
}
308 
TEST_F(ValidateRayTracing, ExecuteCallableStorageClass) {
  // Passing a RayPayloadKHR variable as the Callable Data operand must be
  // rejected.
  const std::string spirv = R"(
OpCapability RayTracingKHR
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint RayGenerationKHR %main "main"
OpName %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%int = OpTypeInt 32 1
%uint = OpTypeInt 32 0
%uint_0 = OpConstant %uint 0
%data_ptr = OpTypePointer RayPayloadKHR %int
%data = OpVariable %data_ptr RayPayloadKHR
%main = OpFunction %void None %func
%label = OpLabel
OpExecuteCallableKHR %uint_0 %data
OpReturn
OpFunctionEnd
)";

  CompileSuccessfully(spirv.c_str());
  const auto result = ValidateInstructions();
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Callable Data must have storage class "
                        "CallableDataKHR or IncomingCallableDataKHR"));
}
336 
TEST_F(ValidateRayTracing, ExecuteCallableSbtIndex) {
  // A signed int SBT Index operand must be rejected; only an unsigned 32-bit
  // int scalar is accepted. (%uint_0 is declared but unused here.)
  const std::string spirv = R"(
OpCapability RayTracingKHR
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint CallableKHR %main "main"
OpName %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%int = OpTypeInt 32 1
%uint = OpTypeInt 32 0
%uint_0 = OpConstant %uint 0
%int_1 = OpConstant %int 1
%data_ptr = OpTypePointer CallableDataKHR %int
%data = OpVariable %data_ptr CallableDataKHR
%main = OpFunction %void None %func
%label = OpLabel
OpExecuteCallableKHR %int_1 %data
OpReturn
OpFunctionEnd
)";

  CompileSuccessfully(spirv.c_str());
  const auto result = ValidateInstructions();
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("SBT Index must be a 32-bit unsigned int scalar"));
}
364 
// Builds a complete SPIR-V assembly module around `body` for the
// OpTraceRayKHR tests.
//
// The generated module declares the RayTracingKHR and Float64 capabilities
// and a single entry point %main using `execution_model` (RayGenerationKHR by
// default). It also predeclares ids that `body` may reference:
//   %top_level_as / %type_as            acceleration structure variable
//                                       (DescriptorSet 0, Binding 0)
//   %uint, %uint_1, %int, %int_1        32-bit integer types and constants
//   %float, %float_0, %float64,
//   %float64_0                          32- and 64-bit float types/constants
//   %f32vec3 / %v3composite,
//   %f32vec4 / %v4composite             float vector types and zero vectors
//   %payload                            RayPayloadKHR variable of %int
//   %callable                           CallableDataKHR variable of %int
//   %var_uint, %var_float, %var_f32vec3 Private variables
// `body` is inserted verbatim after %main's OpLabel and is followed by
// OpReturn and OpFunctionEnd. The returned text begins with a newline and
// ends with "OpFunctionEnd" (no trailing newline).
//
// Both parameters are taken by const reference so calls do not copy the
// execution-model string.
std::string GenerateRayTraceCode(
    const std::string& body,
    const std::string& execution_model = "RayGenerationKHR") {
  std::ostringstream ss;
  ss << R"(
OpCapability RayTracingKHR
OpCapability Float64
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint )"
     << execution_model << R"( %main "main"
OpDecorate %top_level_as DescriptorSet 0
OpDecorate %top_level_as Binding 0
%void = OpTypeVoid
%func = OpTypeFunction %void
%type_as = OpTypeAccelerationStructureKHR
%as_uc_ptr = OpTypePointer UniformConstant %type_as
%top_level_as = OpVariable %as_uc_ptr UniformConstant
%uint = OpTypeInt 32 0
%uint_1 = OpConstant %uint 1
%float = OpTypeFloat 32
%float64 = OpTypeFloat 64
%f32vec3 = OpTypeVector %float 3
%f32vec4 = OpTypeVector %float 4
%float_0 = OpConstant %float 0
%float64_0 = OpConstant %float64 0
%v3composite = OpConstantComposite %f32vec3 %float_0 %float_0 %float_0
%v4composite = OpConstantComposite %f32vec4 %float_0 %float_0 %float_0 %float_0
%int = OpTypeInt 32 1
%int_1 = OpConstant %int 1
%payload_ptr = OpTypePointer RayPayloadKHR %int
%payload = OpVariable %payload_ptr RayPayloadKHR
%callable_ptr = OpTypePointer CallableDataKHR %int
%callable = OpVariable %callable_ptr CallableDataKHR
%ptr_uint = OpTypePointer Private %uint
%var_uint = OpVariable %ptr_uint Private
%ptr_float = OpTypePointer Private %float
%var_float = OpVariable %ptr_float Private
%ptr_f32vec3 = OpTypePointer Private %f32vec3
%var_f32vec3 = OpVariable %ptr_f32vec3 Private
%main = OpFunction %void None %func
%label = OpLabel
)";

  ss << body;

  // Close the function after the caller-supplied instructions.
  ss << R"(
OpReturn
OpFunctionEnd)";
  return ss.str();
}
416 
TEST_F(ValidateRayTracing, TraceRaySuccess) {
  // OpTraceRayKHR must validate both with constant operands and with operands
  // loaded from Private variables.
  const std::string trace_ops = R"(
%as = OpLoad %type_as %top_level_as
OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload

%_uint = OpLoad %uint %var_uint
%_float = OpLoad %float %var_float
%_f32vec3 = OpLoad %f32vec3 %var_f32vec3
OpTraceRayKHR %as %_uint %_uint %_uint %_uint %_uint %_f32vec3 %_float %_f32vec3 %_float %payload
)";

  const std::string spirv = GenerateRayTraceCode(trace_ops);
  CompileSuccessfully(spirv.c_str());
  const auto result = ValidateInstructions();
  EXPECT_EQ(SPV_SUCCESS, result);
}
431 
TEST_F(ValidateRayTracing, TraceRayExecutionModel) {
  // A well-formed OpTraceRayKHR must still be rejected when the entry point
  // uses the CallableKHR execution model.
  const std::string trace_ops = R"(
%as = OpLoad %type_as %top_level_as
OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload
)";

  const std::string spirv = GenerateRayTraceCode(trace_ops, "CallableKHR");
  CompileSuccessfully(spirv.c_str());
  const auto result = ValidateInstructions();
  EXPECT_EQ(SPV_ERROR_INVALID_ID, result);
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("OpTraceRayKHR requires RayGenerationKHR, "
                        "ClosestHitKHR and MissKHR execution models"));
}
444 
TEST_F(ValidateRayTracing, TraceRayAccelerationStructure) {
  // A %uint value in the Acceleration Structure operand slot must be
  // rejected.
  const std::string trace_ops = R"(
%_uint = OpLoad %uint %var_uint
OpTraceRayKHR %_uint %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload
)";

  const std::string spirv = GenerateRayTraceCode(trace_ops);
  CompileSuccessfully(spirv.c_str());
  const auto result = ValidateInstructions();
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Expected Acceleration Structure to be of type "
                        "OpTypeAccelerationStructureKHR"));
}
457 
TEST_F(ValidateRayTracing, TraceRayRayFlags) {
  // A float in the Ray Flags operand slot must be rejected.
  const std::string trace_ops = R"(
%as = OpLoad %type_as %top_level_as
OpTraceRayKHR %as %float_0 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload
)";

  const std::string spirv = GenerateRayTraceCode(trace_ops);
  CompileSuccessfully(spirv.c_str());
  const auto result = ValidateInstructions();
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Ray Flags must be a 32-bit int scalar"));
}
469 
TEST_F(ValidateRayTracing, TraceRayCullMask) {
  // A float in the Cull Mask operand slot must be rejected.
  const std::string trace_ops = R"(
%as = OpLoad %type_as %top_level_as
OpTraceRayKHR %as %uint_1 %float_0 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload
)";

  const std::string spirv = GenerateRayTraceCode(trace_ops);
  CompileSuccessfully(spirv.c_str());
  const auto result = ValidateInstructions();
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Cull Mask must be a 32-bit int scalar"));
}
481 
TEST_F(ValidateRayTracing, TraceRaySbtOffest) {
  // A float in the SBT Offset operand slot must be rejected.
  // (The test name's "Offest" spelling is kept so the test ID stays stable.)
  const std::string trace_ops = R"(
%as = OpLoad %type_as %top_level_as
OpTraceRayKHR %as %uint_1 %uint_1 %float_0 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload
)";

  const std::string spirv = GenerateRayTraceCode(trace_ops);
  CompileSuccessfully(spirv.c_str());
  const auto result = ValidateInstructions();
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("SBT Offset must be a 32-bit int scalar"));
}
493 
TEST_F(ValidateRayTracing, TraceRaySbtStride) {
  // A float in the SBT Stride operand slot must be rejected.
  const std::string trace_ops = R"(
%as = OpLoad %type_as %top_level_as
OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %float_0 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload
)";

  const std::string spirv = GenerateRayTraceCode(trace_ops);
  CompileSuccessfully(spirv.c_str());
  const auto result = ValidateInstructions();
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("SBT Stride must be a 32-bit int scalar"));
}
505 
TEST_F(ValidateRayTracing, TraceRayMissIndex) {
  // A float in the Miss Index operand slot must be rejected.
  const std::string trace_ops = R"(
%as = OpLoad %type_as %top_level_as
OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %float_0 %v3composite %float_0 %v3composite %float_0 %payload
)";

  const std::string spirv = GenerateRayTraceCode(trace_ops);
  CompileSuccessfully(spirv.c_str());
  const auto result = ValidateInstructions();
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Miss Index must be a 32-bit int scalar"));
}
517 
TEST_F(ValidateRayTracing, TraceRayRayOrigin) {
  // A scalar float in the Ray Origin operand slot (which expects a 3-component
  // float vector) must be rejected.
  const std::string trace_ops = R"(
%as = OpLoad %type_as %top_level_as
OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %float_0 %float_0 %v3composite %float_0 %payload
)";

  const std::string spirv = GenerateRayTraceCode(trace_ops);
  CompileSuccessfully(spirv.c_str());
  const auto result = ValidateInstructions();
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Ray Origin must be a 32-bit float 3-component "
                        "vector"));
}
530 
TEST_F(ValidateRayTracing, TraceRayRayTMin) {
  // An integer in the Ray TMin operand slot must be rejected.
  const std::string trace_ops = R"(
%as = OpLoad %type_as %top_level_as
OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %uint_1 %v3composite %float_0 %payload
)";

  const std::string spirv = GenerateRayTraceCode(trace_ops);
  CompileSuccessfully(spirv.c_str());
  const auto result = ValidateInstructions();
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Ray TMin must be a 32-bit float scalar"));
}
542 
TEST_F(ValidateRayTracing, TraceRayRayDirection) {
  // A 4-component vector in the Ray Direction operand slot must be rejected.
  const std::string trace_ops = R"(
%as = OpLoad %type_as %top_level_as
OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v4composite %float_0 %payload
)";

  const std::string spirv = GenerateRayTraceCode(trace_ops);
  CompileSuccessfully(spirv.c_str());
  const auto result = ValidateInstructions();
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Ray Direction must be a 32-bit float 3-component "
                        "vector"));
}
555 
TEST_F(ValidateRayTracing, TraceRayRayTMax) {
  // A 64-bit float in the Ray TMax operand slot must be rejected.
  const std::string trace_ops = R"(
%as = OpLoad %type_as %top_level_as
OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float64_0 %payload
)";

  const std::string spirv = GenerateRayTraceCode(trace_ops);
  CompileSuccessfully(spirv.c_str());
  const auto result = ValidateInstructions();
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Ray TMax must be a 32-bit float scalar"));
}
567 
TEST_F(ValidateRayTracing, TraceRayPayload) {
  // A CallableDataKHR variable in the Payload operand slot must be rejected.
  const std::string trace_ops = R"(
%as = OpLoad %type_as %top_level_as
OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %callable
)";

  const std::string spirv = GenerateRayTraceCode(trace_ops);
  CompileSuccessfully(spirv.c_str());
  const auto result = ValidateInstructions();
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Payload must have storage class RayPayloadKHR or "
                        "IncomingRayPayloadKHR"));
}
580 
TEST_F(ValidateRayTracing, InterfaceIncomingRayPayload) {
  // Under Vulkan 1.2 rules, an entry point listing two IncomingRayPayloadKHR
  // variables in its interface must be rejected with VUID 04700.
  const std::string spirv = R"(
OpCapability RayTracingKHR
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint CallableKHR %main "main" %inData1 %inData2
OpName %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%int = OpTypeInt 32 1
%inData_ptr = OpTypePointer IncomingRayPayloadKHR %int
%inData1 = OpVariable %inData_ptr IncomingRayPayloadKHR
%inData2 = OpVariable %inData_ptr IncomingRayPayloadKHR
%main = OpFunction %void None %func
%label = OpLabel
OpReturn
OpFunctionEnd
)";

  CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_2);
  const auto result = ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_2);
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
  const auto diagnostic = getDiagnosticString();
  EXPECT_THAT(diagnostic,
              AnyVUID("VUID-StandaloneSpirv-IncomingRayPayloadKHR-04700"));
  EXPECT_THAT(diagnostic,
              HasSubstr("Entry-point has more than one variable with the "
                        "IncomingRayPayloadKHR storage class in the "
                        "interface"));
}
610 
TEST_F(ValidateRayTracing, InterfaceHitAttribute) {
  // Under Vulkan 1.2 rules, an entry point listing two HitAttributeKHR
  // variables in its interface must be rejected with VUID 04702.
  const std::string spirv = R"(
OpCapability RayTracingKHR
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint CallableKHR %main "main" %inData1 %inData2
OpName %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%int = OpTypeInt 32 1
%inData_ptr = OpTypePointer HitAttributeKHR %int
%inData1 = OpVariable %inData_ptr HitAttributeKHR
%inData2 = OpVariable %inData_ptr HitAttributeKHR
%main = OpFunction %void None %func
%label = OpLabel
OpReturn
OpFunctionEnd
)";

  CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_2);
  const auto result = ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_2);
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
  const auto diagnostic = getDiagnosticString();
  EXPECT_THAT(diagnostic,
              AnyVUID("VUID-StandaloneSpirv-HitAttributeKHR-04702"));
  EXPECT_THAT(diagnostic,
              HasSubstr("Entry-point has more than one variable with the "
                        "HitAttributeKHR storage class in the interface"));
}
639 
TEST_F(ValidateRayTracing, InterfaceIncomingCallableData) {
  // Under Vulkan 1.2 rules, an entry point listing two IncomingCallableDataKHR
  // variables in its interface must be rejected with VUID 04706.
  const std::string spirv = R"(
OpCapability RayTracingKHR
OpExtension "SPV_KHR_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint CallableKHR %main "main" %inData1 %inData2
OpName %main "main"
%void = OpTypeVoid
%func = OpTypeFunction %void
%int = OpTypeInt 32 1
%inData_ptr = OpTypePointer IncomingCallableDataKHR %int
%inData1 = OpVariable %inData_ptr IncomingCallableDataKHR
%inData2 = OpVariable %inData_ptr IncomingCallableDataKHR
%main = OpFunction %void None %func
%label = OpLabel
OpReturn
OpFunctionEnd
)";

  CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_2);
  const auto result = ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_2);
  EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
  const auto diagnostic = getDiagnosticString();
  EXPECT_THAT(diagnostic,
              AnyVUID("VUID-StandaloneSpirv-IncomingCallableDataKHR-04706"));
  EXPECT_THAT(diagnostic,
              HasSubstr("Entry-point has more than one variable with the "
                        "IncomingCallableDataKHR storage class in the "
                        "interface"));
}
669 
670 }  // namespace
671 }  // namespace val
672 }  // namespace spvtools
673