xref: /aosp_15_r20/external/deqp/external/vulkancts/framework/vulkan/vkRayTracingUtil.cpp (revision 35238bce31c2a825756842865a792f8cf7f89930)
1 /*-------------------------------------------------------------------------
2  * Vulkan CTS Framework
3  * --------------------
4  *
5  * Copyright (c) 2020 The Khronos Group Inc.
6  *
7  * Licensed under the Apache License, Version 2.0 (the "License");
8  * you may not use this file except in compliance with the License.
9  * You may obtain a copy of the License at
10  *
11  *      http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  *
19  *//*!
20  * \file
21  * \brief Utilities for creating commonly used Vulkan objects
22  *//*--------------------------------------------------------------------*/
23 
24 #include "vkRayTracingUtil.hpp"
25 
26 #include "vkRefUtil.hpp"
27 #include "vkQueryUtil.hpp"
28 #include "vkObjUtil.hpp"
29 #include "vkBarrierUtil.hpp"
30 #include "vkCmdUtil.hpp"
31 
32 #include "deStringUtil.hpp"
33 #include "deSTLUtil.hpp"
34 
35 #include <vector>
36 #include <string>
37 #include <thread>
38 #include <limits>
39 #include <type_traits>
40 #include <map>
41 
42 #include "SPIRV/spirv.hpp"
43 
44 namespace vk
45 {
46 
47 #ifndef CTS_USES_VULKANSC
48 
// How often long-running loops should touch the test watchdog so the framework
// does not flag a hang (usage presumably in build/copy loops later in this file
// — confirm against the rest of the translation unit).
static const uint32_t WATCHDOG_INTERVAL = 16384; // Touch watchDog every N iterations.

// Parameter bundle handed to each worker thread joining a deferred host
// operation; 'result' receives the outcome reported to that thread.
struct DeferredThreadParams
{
    const DeviceInterface &vk;                // Device driver interface used for the join/result calls.
    VkDevice device;                          // Device that owns the deferred operation.
    VkDeferredOperationKHR deferredOperation; // The operation this thread joins.
    VkResult result;                          // Output: per-thread join result.
};
58 
getFormatSimpleName(vk::VkFormat format)59 std::string getFormatSimpleName(vk::VkFormat format)
60 {
61     constexpr size_t kPrefixLen = 10; // strlen("VK_FORMAT_")
62     return de::toLower(de::toString(format).substr(kPrefixLen));
63 }
64 
pointInTriangle2D(const tcu::Vec3 & p,const tcu::Vec3 & p0,const tcu::Vec3 & p1,const tcu::Vec3 & p2)65 bool pointInTriangle2D(const tcu::Vec3 &p, const tcu::Vec3 &p0, const tcu::Vec3 &p1, const tcu::Vec3 &p2)
66 {
67     float s = p0.y() * p2.x() - p0.x() * p2.y() + (p2.y() - p0.y()) * p.x() + (p0.x() - p2.x()) * p.y();
68     float t = p0.x() * p1.y() - p0.y() * p1.x() + (p0.y() - p1.y()) * p.x() + (p1.x() - p0.x()) * p.y();
69 
70     if ((s < 0) != (t < 0))
71         return false;
72 
73     float a = -p1.y() * p2.x() + p0.y() * (p2.x() - p1.x()) + p0.x() * (p1.y() - p2.y()) + p1.x() * p2.y();
74 
75     return a < 0 ? (s <= 0 && s + t >= a) : (s >= 0 && s + t <= a);
76 }
77 
78 // Returns true if VK_FORMAT_FEATURE_ACCELERATION_STRUCTURE_VERTEX_BUFFER_BIT_KHR needs to be supported for the given format.
isMandatoryAccelerationStructureVertexBufferFormat(vk::VkFormat format)79 static bool isMandatoryAccelerationStructureVertexBufferFormat(vk::VkFormat format)
80 {
81     bool mandatory = false;
82 
83     switch (format)
84     {
85     case VK_FORMAT_R32G32_SFLOAT:
86     case VK_FORMAT_R32G32B32_SFLOAT:
87     case VK_FORMAT_R16G16_SFLOAT:
88     case VK_FORMAT_R16G16B16A16_SFLOAT:
89     case VK_FORMAT_R16G16_SNORM:
90     case VK_FORMAT_R16G16B16A16_SNORM:
91         mandatory = true;
92         break;
93     default:
94         break;
95     }
96 
97     return mandatory;
98 }
99 
// Checks that 'format' supports VK_FORMAT_FEATURE_ACCELERATION_STRUCTURE_VERTEX_BUFFER_BIT_KHR.
// Fails the test outright for formats with mandatory support; otherwise throws
// NotSupportedError so the test is reported as unsupported rather than failed.
void checkAccelerationStructureVertexBufferFormat(const vk::InstanceInterface &vki, vk::VkPhysicalDevice physicalDevice,
                                                  vk::VkFormat format)
{
    const vk::VkFormatProperties formatProperties = getPhysicalDeviceFormatProperties(vki, physicalDevice, format);

    if ((formatProperties.bufferFeatures & vk::VK_FORMAT_FEATURE_ACCELERATION_STRUCTURE_VERTEX_BUFFER_BIT_KHR) == 0u)
    {
        const std::string errorMsg = "Format not supported for acceleration structure vertex buffers";
        // Missing support for a mandatory format is a driver conformance failure, not a skip.
        if (isMandatoryAccelerationStructureVertexBufferFormat(format))
            TCU_FAIL(errorMsg);
        TCU_THROW(NotSupportedError, errorMsg);
    }
}
113 
// Returns GLSL source for a basic ray generation shader: one ray per launch
// cell, cast from the cell center along -Z into 'topLevelAS'.
std::string getCommonRayGenerationShader(void)
{
    // Raw string literal keeps the GLSL readable; the produced string is
    // byte-identical to the former per-line concatenation (every shader line,
    // including the last, ends with '\n').
    return R"glsl(#version 460 core
#extension GL_EXT_ray_tracing : require
layout(location = 0) rayPayloadEXT vec3 hitValue;
layout(set = 0, binding = 1) uniform accelerationStructureEXT topLevelAS;

void main()
{
  uint  rayFlags = 0;
  uint  cullMask = 0xFF;
  float tmin     = 0.0;
  float tmax     = 9.0;
  vec3  origin   = vec3((float(gl_LaunchIDEXT.x) + 0.5f) / float(gl_LaunchSizeEXT.x), (float(gl_LaunchIDEXT.y) + 0.5f) / float(gl_LaunchSizeEXT.y), 0.0);
  vec3  direct   = vec3(0.0, 0.0, -1.0);
  traceRayEXT(topLevelAS, rayFlags, cullMask, 0, 0, 0, origin, tmin, direct, tmax, 0);
}
)glsl";
}
133 
// Base-class state for a raytraced geometry: geometry kind and vertex/index
// formats are fixed at construction; flags start cleared and no opacity
// micromap is attached.
RaytracedGeometryBase::RaytracedGeometryBase(VkGeometryTypeKHR geometryType, VkFormat vertexFormat,
                                             VkIndexType indexType)
    : m_geometryType(geometryType)
    , m_vertexFormat(vertexFormat)
    , m_indexType(indexType)
    , m_geometryFlags((VkGeometryFlagsKHR)0u)
    , m_hasOpacityMicromap(false)
{
    // AABB geometry data is only accepted as VK_FORMAT_R32G32B32_SFLOAT.
    if (m_geometryType == VK_GEOMETRY_TYPE_AABBS_KHR)
        DE_ASSERT(m_vertexFormat == VK_FORMAT_R32G32B32_SFLOAT);
}
145 
// Out-of-line destructor anchoring the polymorphic base class; no resources to release here.
RaytracedGeometryBase::~RaytracedGeometryBase()
{
}
149 
// Parameters controlling how buildRaytracedGeometry() instantiates a RaytracedGeometry.
struct GeometryBuilderParams
{
    VkGeometryTypeKHR geometryType; // Triangles or AABBs.
    bool usePadding;                // Forwarded as 1u/0u to the RaytracedGeometry constructor's padding argument.
};
155 
// Instantiates a RaytracedGeometry with vertex type V and index type I; the
// caller takes ownership of the returned object.
template <typename V, typename I>
RaytracedGeometryBase *buildRaytracedGeometry(const GeometryBuilderParams &params)
{
    return new RaytracedGeometry<V, I>(params.geometryType, (params.usePadding ? 1u : 0u));
}
161 
makeRaytracedGeometry(VkGeometryTypeKHR geometryType,VkFormat vertexFormat,VkIndexType indexType,bool padVertices)162 de::SharedPtr<RaytracedGeometryBase> makeRaytracedGeometry(VkGeometryTypeKHR geometryType, VkFormat vertexFormat,
163                                                            VkIndexType indexType, bool padVertices)
164 {
165     const GeometryBuilderParams builderParams{geometryType, padVertices};
166 
167     switch (vertexFormat)
168     {
169     case VK_FORMAT_R32G32_SFLOAT:
170         switch (indexType)
171         {
172         case VK_INDEX_TYPE_UINT16:
173             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::Vec2, uint16_t>(builderParams));
174         case VK_INDEX_TYPE_UINT32:
175             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::Vec2, uint32_t>(builderParams));
176         case VK_INDEX_TYPE_NONE_KHR:
177             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::Vec2, EmptyIndex>(builderParams));
178         default:
179             TCU_THROW(InternalError, "Wrong index type");
180         }
181     case VK_FORMAT_R32G32B32_SFLOAT:
182         switch (indexType)
183         {
184         case VK_INDEX_TYPE_UINT16:
185             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::Vec3, uint16_t>(builderParams));
186         case VK_INDEX_TYPE_UINT32:
187             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::Vec3, uint32_t>(builderParams));
188         case VK_INDEX_TYPE_NONE_KHR:
189             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::Vec3, EmptyIndex>(builderParams));
190         default:
191             TCU_THROW(InternalError, "Wrong index type");
192         }
193     case VK_FORMAT_R32G32B32A32_SFLOAT:
194         switch (indexType)
195         {
196         case VK_INDEX_TYPE_UINT16:
197             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::Vec4, uint16_t>(builderParams));
198         case VK_INDEX_TYPE_UINT32:
199             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::Vec4, uint32_t>(builderParams));
200         case VK_INDEX_TYPE_NONE_KHR:
201             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::Vec4, EmptyIndex>(builderParams));
202         default:
203             TCU_THROW(InternalError, "Wrong index type");
204         }
205     case VK_FORMAT_R16G16_SFLOAT:
206         switch (indexType)
207         {
208         case VK_INDEX_TYPE_UINT16:
209             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec2_16, uint16_t>(builderParams));
210         case VK_INDEX_TYPE_UINT32:
211             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec2_16, uint32_t>(builderParams));
212         case VK_INDEX_TYPE_NONE_KHR:
213             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec2_16, EmptyIndex>(builderParams));
214         default:
215             TCU_THROW(InternalError, "Wrong index type");
216         }
217     case VK_FORMAT_R16G16B16_SFLOAT:
218         switch (indexType)
219         {
220         case VK_INDEX_TYPE_UINT16:
221             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec3_16, uint16_t>(builderParams));
222         case VK_INDEX_TYPE_UINT32:
223             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec3_16, uint32_t>(builderParams));
224         case VK_INDEX_TYPE_NONE_KHR:
225             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec3_16, EmptyIndex>(builderParams));
226         default:
227             TCU_THROW(InternalError, "Wrong index type");
228         }
229     case VK_FORMAT_R16G16B16A16_SFLOAT:
230         switch (indexType)
231         {
232         case VK_INDEX_TYPE_UINT16:
233             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec4_16, uint16_t>(builderParams));
234         case VK_INDEX_TYPE_UINT32:
235             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec4_16, uint32_t>(builderParams));
236         case VK_INDEX_TYPE_NONE_KHR:
237             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec4_16, EmptyIndex>(builderParams));
238         default:
239             TCU_THROW(InternalError, "Wrong index type");
240         }
241     case VK_FORMAT_R16G16_SNORM:
242         switch (indexType)
243         {
244         case VK_INDEX_TYPE_UINT16:
245             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec2_16SNorm, uint16_t>(builderParams));
246         case VK_INDEX_TYPE_UINT32:
247             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec2_16SNorm, uint32_t>(builderParams));
248         case VK_INDEX_TYPE_NONE_KHR:
249             return de::SharedPtr<RaytracedGeometryBase>(
250                 buildRaytracedGeometry<Vec2_16SNorm, EmptyIndex>(builderParams));
251         default:
252             TCU_THROW(InternalError, "Wrong index type");
253         }
254     case VK_FORMAT_R16G16B16_SNORM:
255         switch (indexType)
256         {
257         case VK_INDEX_TYPE_UINT16:
258             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec3_16SNorm, uint16_t>(builderParams));
259         case VK_INDEX_TYPE_UINT32:
260             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec3_16SNorm, uint32_t>(builderParams));
261         case VK_INDEX_TYPE_NONE_KHR:
262             return de::SharedPtr<RaytracedGeometryBase>(
263                 buildRaytracedGeometry<Vec3_16SNorm, EmptyIndex>(builderParams));
264         default:
265             TCU_THROW(InternalError, "Wrong index type");
266         }
267     case VK_FORMAT_R16G16B16A16_SNORM:
268         switch (indexType)
269         {
270         case VK_INDEX_TYPE_UINT16:
271             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec4_16SNorm, uint16_t>(builderParams));
272         case VK_INDEX_TYPE_UINT32:
273             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec4_16SNorm, uint32_t>(builderParams));
274         case VK_INDEX_TYPE_NONE_KHR:
275             return de::SharedPtr<RaytracedGeometryBase>(
276                 buildRaytracedGeometry<Vec4_16SNorm, EmptyIndex>(builderParams));
277         default:
278             TCU_THROW(InternalError, "Wrong index type");
279         }
280     case VK_FORMAT_R64G64_SFLOAT:
281         switch (indexType)
282         {
283         case VK_INDEX_TYPE_UINT16:
284             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::DVec2, uint16_t>(builderParams));
285         case VK_INDEX_TYPE_UINT32:
286             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::DVec2, uint32_t>(builderParams));
287         case VK_INDEX_TYPE_NONE_KHR:
288             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::DVec2, EmptyIndex>(builderParams));
289         default:
290             TCU_THROW(InternalError, "Wrong index type");
291         }
292     case VK_FORMAT_R64G64B64_SFLOAT:
293         switch (indexType)
294         {
295         case VK_INDEX_TYPE_UINT16:
296             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::DVec3, uint16_t>(builderParams));
297         case VK_INDEX_TYPE_UINT32:
298             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::DVec3, uint32_t>(builderParams));
299         case VK_INDEX_TYPE_NONE_KHR:
300             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::DVec3, EmptyIndex>(builderParams));
301         default:
302             TCU_THROW(InternalError, "Wrong index type");
303         }
304     case VK_FORMAT_R64G64B64A64_SFLOAT:
305         switch (indexType)
306         {
307         case VK_INDEX_TYPE_UINT16:
308             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::DVec4, uint16_t>(builderParams));
309         case VK_INDEX_TYPE_UINT32:
310             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::DVec4, uint32_t>(builderParams));
311         case VK_INDEX_TYPE_NONE_KHR:
312             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<tcu::DVec4, EmptyIndex>(builderParams));
313         default:
314             TCU_THROW(InternalError, "Wrong index type");
315         }
316     case VK_FORMAT_R8G8_SNORM:
317         switch (indexType)
318         {
319         case VK_INDEX_TYPE_UINT16:
320             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec2_8SNorm, uint16_t>(builderParams));
321         case VK_INDEX_TYPE_UINT32:
322             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec2_8SNorm, uint32_t>(builderParams));
323         case VK_INDEX_TYPE_NONE_KHR:
324             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec2_8SNorm, EmptyIndex>(builderParams));
325         default:
326             TCU_THROW(InternalError, "Wrong index type");
327         }
328     case VK_FORMAT_R8G8B8_SNORM:
329         switch (indexType)
330         {
331         case VK_INDEX_TYPE_UINT16:
332             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec3_8SNorm, uint16_t>(builderParams));
333         case VK_INDEX_TYPE_UINT32:
334             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec3_8SNorm, uint32_t>(builderParams));
335         case VK_INDEX_TYPE_NONE_KHR:
336             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec3_8SNorm, EmptyIndex>(builderParams));
337         default:
338             TCU_THROW(InternalError, "Wrong index type");
339         }
340     case VK_FORMAT_R8G8B8A8_SNORM:
341         switch (indexType)
342         {
343         case VK_INDEX_TYPE_UINT16:
344             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec4_8SNorm, uint16_t>(builderParams));
345         case VK_INDEX_TYPE_UINT32:
346             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec4_8SNorm, uint32_t>(builderParams));
347         case VK_INDEX_TYPE_NONE_KHR:
348             return de::SharedPtr<RaytracedGeometryBase>(buildRaytracedGeometry<Vec4_8SNorm, EmptyIndex>(builderParams));
349         default:
350             TCU_THROW(InternalError, "Wrong index type");
351         }
352     default:
353         TCU_THROW(InternalError, "Wrong vertex format");
354     }
355 }
356 
getBufferDeviceAddress(const DeviceInterface & vk,const VkDevice device,const VkBuffer buffer,VkDeviceSize offset)357 VkDeviceAddress getBufferDeviceAddress(const DeviceInterface &vk, const VkDevice device, const VkBuffer buffer,
358                                        VkDeviceSize offset)
359 {
360 
361     if (buffer == DE_NULL)
362         return 0;
363 
364     VkBufferDeviceAddressInfo deviceAddressInfo{
365         VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO, // VkStructureType    sType
366         DE_NULL,                                      // const void*        pNext
367         buffer                                        // VkBuffer           buffer;
368     };
369     return vk.getBufferDeviceAddress(device, &deviceAddressInfo) + offset;
370 }
371 
// Creates a query pool holding 'queryCount' queries of 'queryType', with
// default flags and no pipeline statistics.
static inline Move<VkQueryPool> makeQueryPool(const DeviceInterface &vk, const VkDevice device,
                                              const VkQueryType queryType, uint32_t queryCount)
{
    const VkQueryPoolCreateInfo queryPoolCreateInfo = {
        VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO, // sType
        DE_NULL,                                  // pNext
        (VkQueryPoolCreateFlags)0,                // flags
        queryType,                                // queryType
        queryCount,                               // queryCount
        0u,                                       // pipelineStatistics
    };
    return createQueryPool(vk, device, &queryPoolCreateInfo);
}
385 
makeVkAccelerationStructureGeometryDataKHR(const VkAccelerationStructureGeometryTrianglesDataKHR & triangles)386 static inline VkAccelerationStructureGeometryDataKHR makeVkAccelerationStructureGeometryDataKHR(
387     const VkAccelerationStructureGeometryTrianglesDataKHR &triangles)
388 {
389     VkAccelerationStructureGeometryDataKHR result;
390 
391     deMemset(&result, 0, sizeof(result));
392 
393     result.triangles = triangles;
394 
395     return result;
396 }
397 
makeVkAccelerationStructureGeometryDataKHR(const VkAccelerationStructureGeometryAabbsDataKHR & aabbs)398 static inline VkAccelerationStructureGeometryDataKHR makeVkAccelerationStructureGeometryDataKHR(
399     const VkAccelerationStructureGeometryAabbsDataKHR &aabbs)
400 {
401     VkAccelerationStructureGeometryDataKHR result;
402 
403     deMemset(&result, 0, sizeof(result));
404 
405     result.aabbs = aabbs;
406 
407     return result;
408 }
409 
makeVkAccelerationStructureInstancesDataKHR(const VkAccelerationStructureGeometryInstancesDataKHR & instances)410 static inline VkAccelerationStructureGeometryDataKHR makeVkAccelerationStructureInstancesDataKHR(
411     const VkAccelerationStructureGeometryInstancesDataKHR &instances)
412 {
413     VkAccelerationStructureGeometryDataKHR result;
414 
415     deMemset(&result, 0, sizeof(result));
416 
417     result.instances = instances;
418 
419     return result;
420 }
421 
makeVkAccelerationStructureInstanceKHR(const VkTransformMatrixKHR & transform,uint32_t instanceCustomIndex,uint32_t mask,uint32_t instanceShaderBindingTableRecordOffset,VkGeometryInstanceFlagsKHR flags,uint64_t accelerationStructureReference)422 static inline VkAccelerationStructureInstanceKHR makeVkAccelerationStructureInstanceKHR(
423     const VkTransformMatrixKHR &transform, uint32_t instanceCustomIndex, uint32_t mask,
424     uint32_t instanceShaderBindingTableRecordOffset, VkGeometryInstanceFlagsKHR flags,
425     uint64_t accelerationStructureReference)
426 {
427     VkAccelerationStructureInstanceKHR instance     = {transform, 0, 0, 0, 0, accelerationStructureReference};
428     instance.instanceCustomIndex                    = instanceCustomIndex & 0xFFFFFF;
429     instance.mask                                   = mask & 0xFF;
430     instance.instanceShaderBindingTableRecordOffset = instanceShaderBindingTableRecordOffset & 0xFFFFFF;
431     instance.flags                                  = flags & 0xFF;
432     return instance;
433 }
434 
// Thin wrapper over vkGetRayTracingShaderGroupHandlesKHR; copies 'groupCount'
// group handles starting at 'firstGroup' into 'pData' (of 'dataSize' bytes).
VkResult getRayTracingShaderGroupHandlesKHR(const DeviceInterface &vk, const VkDevice device, const VkPipeline pipeline,
                                            const uint32_t firstGroup, const uint32_t groupCount,
                                            const uintptr_t dataSize, void *pData)
{
    return vk.getRayTracingShaderGroupHandlesKHR(device, pipeline, firstGroup, groupCount, dataSize, pData);
}
441 
// API-neutral entry point; currently always forwards to the KHR variant.
VkResult getRayTracingShaderGroupHandles(const DeviceInterface &vk, const VkDevice device, const VkPipeline pipeline,
                                         const uint32_t firstGroup, const uint32_t groupCount, const uintptr_t dataSize,
                                         void *pData)
{
    return getRayTracingShaderGroupHandlesKHR(vk, device, pipeline, firstGroup, groupCount, dataSize, pData);
}
448 
// Thin wrapper over vkGetRayTracingCaptureReplayShaderGroupHandlesKHR; fetches
// opaque capture/replay handles for the given shader group range.
VkResult getRayTracingCaptureReplayShaderGroupHandles(const DeviceInterface &vk, const VkDevice device,
                                                      const VkPipeline pipeline, const uint32_t firstGroup,
                                                      const uint32_t groupCount, const uintptr_t dataSize, void *pData)
{
    return vk.getRayTracingCaptureReplayShaderGroupHandlesKHR(device, pipeline, firstGroup, groupCount, dataSize,
                                                              pData);
}
456 
// Joins the calling thread to 'deferredOperation' and drives it to completion,
// yielding while other threads hold the remaining work. Returns the operation's
// final result as reported by vkGetDeferredOperationResultKHR.
VkResult finishDeferredOperation(const DeviceInterface &vk, VkDevice device, VkDeferredOperationKHR deferredOperation)
{
    VkResult result = vk.deferredOperationJoinKHR(device, deferredOperation);

    // VK_THREAD_IDLE_KHR: no work is available for this thread right now, but
    // the operation is not complete either - yield and try to join again.
    while (result == VK_THREAD_IDLE_KHR)
    {
        std::this_thread::yield();
        result = vk.deferredOperationJoinKHR(device, deferredOperation);
    }

    switch (result)
    {
    case VK_SUCCESS:
    {
        // Deferred operation has finished. Query its result
        result = vk.getDeferredOperationResultKHR(device, deferredOperation);

        break;
    }

    case VK_THREAD_DONE_KHR:
    {
        // Deferred operation is being wrapped up by another thread
        // wait for that thread to finish
        do
        {
            std::this_thread::yield();
            result = vk.getDeferredOperationResultKHR(device, deferredOperation);
        } while (result == VK_NOT_READY);

        break;
    }

    default:
    {
        // deferredOperationJoinKHR is not expected to return any other value here.
        DE_ASSERT(false);

        break;
    }
    }

    return result;
}
500 
// Worker-thread entry point: joins the deferred operation and stores the
// outcome back into the params struct for the spawning code to inspect.
void finishDeferredOperationThreaded(DeferredThreadParams *deferredThreadParams)
{
    deferredThreadParams->result = finishDeferredOperation(deferredThreadParams->vk, deferredThreadParams->device,
                                                           deferredThreadParams->deferredOperation);
}
506 
// Completes 'deferredOperation', optionally with multiple worker threads.
// workerThreadCount == 0 finishes on the calling thread; UINT32_MAX means
// "use the maximum concurrency the driver reports" (capped at 256). The run
// passes if at least one joined thread reports VK_SUCCESS.
void finishDeferredOperation(const DeviceInterface &vk, VkDevice device, VkDeferredOperationKHR deferredOperation,
                             const uint32_t workerThreadCount, const bool operationNotDeferred)
{

    if (operationNotDeferred)
    {
        // when the operation deferral returns VK_OPERATION_NOT_DEFERRED_KHR,
        // the deferred operation should act as if no command was deferred
        VK_CHECK(vk.getDeferredOperationResultKHR(device, deferredOperation));

        // there is not need to join any threads to the deferred operation,
        // so below can be skipped.
        return;
    }

    if (workerThreadCount == 0)
    {
        VK_CHECK(finishDeferredOperation(vk, device, deferredOperation));
    }
    else
    {
        // Cap at 256 threads regardless of what the driver claims it can use.
        const uint32_t maxThreadCountSupported =
            deMinu32(256u, vk.getDeferredOperationMaxConcurrencyKHR(device, deferredOperation));
        const uint32_t requestedThreadCount = workerThreadCount;
        const uint32_t testThreadCount      = requestedThreadCount == std::numeric_limits<uint32_t>::max() ?
                                                  maxThreadCountSupported :
                                                  requestedThreadCount;

        if (maxThreadCountSupported == 0)
            TCU_FAIL("vkGetDeferredOperationMaxConcurrencyKHR must not return 0");

        const DeferredThreadParams deferredThreadParams = {
            vk,                 //  const DeviceInterface& vk;
            device,             //  VkDevice device;
            deferredOperation,  //  VkDeferredOperationKHR deferredOperation;
            VK_RESULT_MAX_ENUM, //  VkResult result;
        };
        std::vector<DeferredThreadParams> threadParams(testThreadCount, deferredThreadParams);
        std::vector<de::MovePtr<std::thread>> threads(testThreadCount);
        bool executionResult = false;

        DE_ASSERT(threads.size() > 0 && threads.size() == testThreadCount);

        // Join every thread to the operation; each records its own result.
        for (uint32_t threadNdx = 0; threadNdx < testThreadCount; ++threadNdx)
            threads[threadNdx] =
                de::MovePtr<std::thread>(new std::thread(finishDeferredOperationThreaded, &threadParams[threadNdx]));

        for (uint32_t threadNdx = 0; threadNdx < testThreadCount; ++threadNdx)
            threads[threadNdx]->join();

        // At least one thread must have observed the operation succeed.
        for (uint32_t threadNdx = 0; threadNdx < testThreadCount; ++threadNdx)
            if (threadParams[threadNdx].result == VK_SUCCESS)
                executionResult = true;

        if (!executionResult)
            TCU_FAIL("Neither reported VK_SUCCESS");
    }
}
565 
// Allocates a host-visible, device-addressable buffer of 'storageSize' bytes
// for (de)serializing an acceleration structure. Prefers cached memory and
// falls back to uncached when the implementation does not support it.
SerialStorage::SerialStorage(const DeviceInterface &vk, const VkDevice device, Allocator &allocator,
                             const VkAccelerationStructureBuildTypeKHR buildType, const VkDeviceSize storageSize)
    : m_buildType(buildType)
    , m_storageSize(storageSize)
    , m_serialInfo()
{
    const VkBufferCreateInfo bufferCreateInfo =
        makeBufferCreateInfo(storageSize, VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_STORAGE_BIT_KHR |
                                              VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
    try
    {
        m_buffer = de::MovePtr<BufferWithMemory>(
            new BufferWithMemory(vk, device, allocator, bufferCreateInfo,
                                 MemoryRequirement::Cached | MemoryRequirement::HostVisible |
                                     MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress));
    }
    catch (const tcu::NotSupportedError &)
    {
        // retry without Cached flag
        m_buffer = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
            vk, device, allocator, bufferCreateInfo,
            MemoryRequirement::HostVisible | MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress));
    }
}
590 
SerialStorage(const DeviceInterface & vk,const VkDevice device,Allocator & allocator,const VkAccelerationStructureBuildTypeKHR buildType,const SerialInfo & serialInfo)591 SerialStorage::SerialStorage(const DeviceInterface &vk, const VkDevice device, Allocator &allocator,
592                              const VkAccelerationStructureBuildTypeKHR buildType, const SerialInfo &serialInfo)
593     : m_buildType(buildType)
594     , m_storageSize(serialInfo.sizes()[0]) // raise assertion if serialInfo is empty
595     , m_serialInfo(serialInfo)
596 {
597     DE_ASSERT(serialInfo.sizes().size() >= 2u);
598 
599     // create buffer for top-level acceleration structure
600     {
601         const VkBufferCreateInfo bufferCreateInfo =
602             makeBufferCreateInfo(m_storageSize, VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_STORAGE_BIT_KHR |
603                                                     VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
604         m_buffer = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
605             vk, device, allocator, bufferCreateInfo,
606             MemoryRequirement::HostVisible | MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress));
607     }
608 
609     // create buffers for bottom-level acceleration structures
610     {
611         std::vector<uint64_t> addrs;
612 
613         for (std::size_t i = 1; i < serialInfo.addresses().size(); ++i)
614         {
615             const uint64_t &lookAddr = serialInfo.addresses()[i];
616             auto end                 = addrs.end();
617             auto match = std::find_if(addrs.begin(), end, [&](const uint64_t &item) { return item == lookAddr; });
618             if (match == end)
619             {
620                 addrs.emplace_back(lookAddr);
621                 m_bottoms.emplace_back(de::SharedPtr<SerialStorage>(
622                     new SerialStorage(vk, device, allocator, buildType, serialInfo.sizes()[i])));
623             }
624         }
625     }
626 }
627 
// Returns the storage buffer address in the domain matching 'buildType':
// a device address for device builds, a host pointer for host builds.
VkDeviceOrHostAddressKHR SerialStorage::getAddress(const DeviceInterface &vk, const VkDevice device,
                                                   const VkAccelerationStructureBuildTypeKHR buildType)
{
    if (buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
        return makeDeviceOrHostAddressKHR(vk, device, m_buffer->get(), 0);
    else
        return makeDeviceOrHostAddressKHR(m_buffer->getAllocation().getHostPtr());
}
636 
getASHeader()637 SerialStorage::AccelerationStructureHeader *SerialStorage::getASHeader()
638 {
639     return reinterpret_cast<AccelerationStructureHeader *>(getHostAddress().hostAddress);
640 }
641 
hasDeepFormat() const642 bool SerialStorage::hasDeepFormat() const
643 {
644     return (m_serialInfo.sizes().size() >= 2u);
645 }
646 
getBottomStorage(uint32_t index) const647 de::SharedPtr<SerialStorage> SerialStorage::getBottomStorage(uint32_t index) const
648 {
649     return m_bottoms[index];
650 }
651 
getHostAddress(VkDeviceSize offset)652 VkDeviceOrHostAddressKHR SerialStorage::getHostAddress(VkDeviceSize offset)
653 {
654     DE_ASSERT(offset < m_storageSize);
655     return makeDeviceOrHostAddressKHR(static_cast<uint8_t *>(m_buffer->getAllocation().getHostPtr()) + offset);
656 }
657 
getHostAddressConst(VkDeviceSize offset)658 VkDeviceOrHostAddressConstKHR SerialStorage::getHostAddressConst(VkDeviceSize offset)
659 {
660     return makeDeviceOrHostAddressConstKHR(static_cast<uint8_t *>(m_buffer->getAllocation().getHostPtr()) + offset);
661 }
662 
getAddressConst(const DeviceInterface & vk,const VkDevice device,const VkAccelerationStructureBuildTypeKHR buildType)663 VkDeviceOrHostAddressConstKHR SerialStorage::getAddressConst(const DeviceInterface &vk, const VkDevice device,
664                                                              const VkAccelerationStructureBuildTypeKHR buildType)
665 {
666     if (buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
667         return makeDeviceOrHostAddressConstKHR(vk, device, m_buffer->get(), 0);
668     else
669         return getHostAddressConst();
670 }
671 
getStorageSize() const672 inline VkDeviceSize SerialStorage::getStorageSize() const
673 {
674     return m_storageSize;
675 }
676 
getSerialInfo() const677 inline const SerialInfo &SerialStorage::getSerialInfo() const
678 {
679     return m_serialInfo;
680 }
681 
getDeserializedSize()682 uint64_t SerialStorage::getDeserializedSize()
683 {
684     uint64_t result         = 0;
685     const uint8_t *startPtr = static_cast<uint8_t *>(m_buffer->getAllocation().getHostPtr());
686 
687     DE_ASSERT(sizeof(result) == DESERIALIZED_SIZE_SIZE);
688 
689     deMemcpy(&result, startPtr + DESERIALIZED_SIZE_OFFSET, sizeof(result));
690 
691     return result;
692 }
693 
~BottomLevelAccelerationStructure()694 BottomLevelAccelerationStructure::~BottomLevelAccelerationStructure()
695 {
696 }
697 
BottomLevelAccelerationStructure()698 BottomLevelAccelerationStructure::BottomLevelAccelerationStructure()
699     : m_structureSize(0u)
700     , m_updateScratchSize(0u)
701     , m_buildScratchSize(0u)
702 {
703 }
704 
setGeometryData(const std::vector<tcu::Vec3> & geometryData,const bool triangles,const VkGeometryFlagsKHR geometryFlags)705 void BottomLevelAccelerationStructure::setGeometryData(const std::vector<tcu::Vec3> &geometryData, const bool triangles,
706                                                        const VkGeometryFlagsKHR geometryFlags)
707 {
708     if (triangles)
709         DE_ASSERT((geometryData.size() % 3) == 0);
710     else
711         DE_ASSERT((geometryData.size() % 2) == 0);
712 
713     setGeometryCount(1u);
714 
715     addGeometry(geometryData, triangles, geometryFlags);
716 }
717 
setDefaultGeometryData(const VkShaderStageFlagBits testStage,const VkGeometryFlagsKHR geometryFlags)718 void BottomLevelAccelerationStructure::setDefaultGeometryData(const VkShaderStageFlagBits testStage,
719                                                               const VkGeometryFlagsKHR geometryFlags)
720 {
721     bool trianglesData = false;
722     float z            = 0.0f;
723     std::vector<tcu::Vec3> geometryData;
724 
725     switch (testStage)
726     {
727     case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
728         z             = -1.0f;
729         trianglesData = true;
730         break;
731     case VK_SHADER_STAGE_ANY_HIT_BIT_KHR:
732         z             = -1.0f;
733         trianglesData = true;
734         break;
735     case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
736         z             = -1.0f;
737         trianglesData = true;
738         break;
739     case VK_SHADER_STAGE_MISS_BIT_KHR:
740         z             = -9.9f;
741         trianglesData = true;
742         break;
743     case VK_SHADER_STAGE_INTERSECTION_BIT_KHR:
744         z             = -1.0f;
745         trianglesData = false;
746         break;
747     case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
748         z             = -1.0f;
749         trianglesData = true;
750         break;
751     default:
752         TCU_THROW(InternalError, "Unacceptable stage");
753     }
754 
755     if (trianglesData)
756     {
757         geometryData.reserve(6);
758 
759         geometryData.push_back(tcu::Vec3(-1.0f, -1.0f, z));
760         geometryData.push_back(tcu::Vec3(-1.0f, +1.0f, z));
761         geometryData.push_back(tcu::Vec3(+1.0f, -1.0f, z));
762         geometryData.push_back(tcu::Vec3(+1.0f, -1.0f, z));
763         geometryData.push_back(tcu::Vec3(-1.0f, +1.0f, z));
764         geometryData.push_back(tcu::Vec3(+1.0f, +1.0f, z));
765     }
766     else
767     {
768         geometryData.reserve(2);
769 
770         geometryData.push_back(tcu::Vec3(-1.0f, -1.0f, z));
771         geometryData.push_back(tcu::Vec3(+1.0f, +1.0f, z));
772     }
773 
774     setGeometryCount(1u);
775 
776     addGeometry(geometryData, trianglesData, geometryFlags);
777 }
778 
setGeometryCount(const size_t geometryCount)779 void BottomLevelAccelerationStructure::setGeometryCount(const size_t geometryCount)
780 {
781     m_geometriesData.clear();
782 
783     m_geometriesData.reserve(geometryCount);
784 }
785 
addGeometry(de::SharedPtr<RaytracedGeometryBase> & raytracedGeometry)786 void BottomLevelAccelerationStructure::addGeometry(de::SharedPtr<RaytracedGeometryBase> &raytracedGeometry)
787 {
788     m_geometriesData.push_back(raytracedGeometry);
789 }
790 
addGeometry(const std::vector<tcu::Vec3> & geometryData,const bool triangles,const VkGeometryFlagsKHR geometryFlags,const VkAccelerationStructureTrianglesOpacityMicromapEXT * opacityGeometryMicromap)791 void BottomLevelAccelerationStructure::addGeometry(
792     const std::vector<tcu::Vec3> &geometryData, const bool triangles, const VkGeometryFlagsKHR geometryFlags,
793     const VkAccelerationStructureTrianglesOpacityMicromapEXT *opacityGeometryMicromap)
794 {
795     DE_ASSERT(geometryData.size() > 0);
796     DE_ASSERT((triangles && geometryData.size() % 3 == 0) || (!triangles && geometryData.size() % 2 == 0));
797 
798     if (!triangles)
799         for (size_t posNdx = 0; posNdx < geometryData.size() / 2; ++posNdx)
800         {
801             DE_ASSERT(geometryData[2 * posNdx].x() <= geometryData[2 * posNdx + 1].x());
802             DE_ASSERT(geometryData[2 * posNdx].y() <= geometryData[2 * posNdx + 1].y());
803             DE_ASSERT(geometryData[2 * posNdx].z() <= geometryData[2 * posNdx + 1].z());
804         }
805 
806     de::SharedPtr<RaytracedGeometryBase> geometry =
807         makeRaytracedGeometry(triangles ? VK_GEOMETRY_TYPE_TRIANGLES_KHR : VK_GEOMETRY_TYPE_AABBS_KHR,
808                               VK_FORMAT_R32G32B32_SFLOAT, VK_INDEX_TYPE_NONE_KHR);
809     for (auto it = begin(geometryData), eit = end(geometryData); it != eit; ++it)
810         geometry->addVertex(*it);
811 
812     geometry->setGeometryFlags(geometryFlags);
813     if (opacityGeometryMicromap)
814         geometry->setOpacityMicromap(opacityGeometryMicromap);
815     addGeometry(geometry);
816 }
817 
getStructureBuildSizes() const818 VkAccelerationStructureBuildSizesInfoKHR BottomLevelAccelerationStructure::getStructureBuildSizes() const
819 {
820     return {
821         VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_SIZES_INFO_KHR, //  VkStructureType sType;
822         DE_NULL,                                                       //  const void* pNext;
823         m_structureSize,                                               //  VkDeviceSize accelerationStructureSize;
824         m_updateScratchSize,                                           //  VkDeviceSize updateScratchSize;
825         m_buildScratchSize                                             //  VkDeviceSize buildScratchSize;
826     };
827 };
828 
getVertexBufferSize(const std::vector<de::SharedPtr<RaytracedGeometryBase>> & geometriesData)829 VkDeviceSize getVertexBufferSize(const std::vector<de::SharedPtr<RaytracedGeometryBase>> &geometriesData)
830 {
831     DE_ASSERT(geometriesData.size() != 0);
832     VkDeviceSize bufferSizeBytes = 0;
833     for (size_t geometryNdx = 0; geometryNdx < geometriesData.size(); ++geometryNdx)
834         bufferSizeBytes += deAlignSize(geometriesData[geometryNdx]->getVertexByteSize(), 8);
835     return bufferSizeBytes;
836 }
837 
createVertexBuffer(const DeviceInterface & vk,const VkDevice device,Allocator & allocator,const VkDeviceSize bufferSizeBytes)838 BufferWithMemory *createVertexBuffer(const DeviceInterface &vk, const VkDevice device, Allocator &allocator,
839                                      const VkDeviceSize bufferSizeBytes)
840 {
841     const VkBufferCreateInfo bufferCreateInfo =
842         makeBufferCreateInfo(bufferSizeBytes, VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_BIT_KHR |
843                                                   VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
844     return new BufferWithMemory(vk, device, allocator, bufferCreateInfo,
845                                 MemoryRequirement::HostVisible | MemoryRequirement::Coherent |
846                                     MemoryRequirement::DeviceAddress);
847 }
848 
createVertexBuffer(const DeviceInterface & vk,const VkDevice device,Allocator & allocator,const std::vector<de::SharedPtr<RaytracedGeometryBase>> & geometriesData)849 BufferWithMemory *createVertexBuffer(const DeviceInterface &vk, const VkDevice device, Allocator &allocator,
850                                      const std::vector<de::SharedPtr<RaytracedGeometryBase>> &geometriesData)
851 {
852     return createVertexBuffer(vk, device, allocator, getVertexBufferSize(geometriesData));
853 }
854 
updateVertexBuffer(const DeviceInterface & vk,const VkDevice device,const std::vector<de::SharedPtr<RaytracedGeometryBase>> & geometriesData,BufferWithMemory * vertexBuffer,VkDeviceSize geometriesOffset=0)855 void updateVertexBuffer(const DeviceInterface &vk, const VkDevice device,
856                         const std::vector<de::SharedPtr<RaytracedGeometryBase>> &geometriesData,
857                         BufferWithMemory *vertexBuffer, VkDeviceSize geometriesOffset = 0)
858 {
859     const Allocation &geometryAlloc = vertexBuffer->getAllocation();
860     uint8_t *bufferStart            = static_cast<uint8_t *>(geometryAlloc.getHostPtr());
861     VkDeviceSize bufferOffset       = geometriesOffset;
862 
863     for (size_t geometryNdx = 0; geometryNdx < geometriesData.size(); ++geometryNdx)
864     {
865         const void *geometryPtr      = geometriesData[geometryNdx]->getVertexPointer();
866         const size_t geometryPtrSize = geometriesData[geometryNdx]->getVertexByteSize();
867 
868         deMemcpy(&bufferStart[bufferOffset], geometryPtr, geometryPtrSize);
869 
870         bufferOffset += deAlignSize(geometryPtrSize, 8);
871     }
872 
873     // Flush the whole allocation. We could flush only the interesting range, but we'd need to be sure both the offset and size
874     // align to VkPhysicalDeviceLimits::nonCoherentAtomSize, which we are not considering. Also note most code uses Coherent memory
875     // for the vertex and index buffers, so flushing is actually not needed.
876     flushAlloc(vk, device, geometryAlloc);
877 }
878 
getIndexBufferSize(const std::vector<de::SharedPtr<RaytracedGeometryBase>> & geometriesData)879 VkDeviceSize getIndexBufferSize(const std::vector<de::SharedPtr<RaytracedGeometryBase>> &geometriesData)
880 {
881     DE_ASSERT(!geometriesData.empty());
882 
883     VkDeviceSize bufferSizeBytes = 0;
884     for (size_t geometryNdx = 0; geometryNdx < geometriesData.size(); ++geometryNdx)
885         if (geometriesData[geometryNdx]->getIndexType() != VK_INDEX_TYPE_NONE_KHR)
886             bufferSizeBytes += deAlignSize(geometriesData[geometryNdx]->getIndexByteSize(), 8);
887     return bufferSizeBytes;
888 }
889 
createIndexBuffer(const DeviceInterface & vk,const VkDevice device,Allocator & allocator,const VkDeviceSize bufferSizeBytes)890 BufferWithMemory *createIndexBuffer(const DeviceInterface &vk, const VkDevice device, Allocator &allocator,
891                                     const VkDeviceSize bufferSizeBytes)
892 {
893     DE_ASSERT(bufferSizeBytes);
894     const VkBufferCreateInfo bufferCreateInfo =
895         makeBufferCreateInfo(bufferSizeBytes, VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_BIT_KHR |
896                                                   VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
897     return new BufferWithMemory(vk, device, allocator, bufferCreateInfo,
898                                 MemoryRequirement::HostVisible | MemoryRequirement::Coherent |
899                                     MemoryRequirement::DeviceAddress);
900 }
901 
createIndexBuffer(const DeviceInterface & vk,const VkDevice device,Allocator & allocator,const std::vector<de::SharedPtr<RaytracedGeometryBase>> & geometriesData)902 BufferWithMemory *createIndexBuffer(const DeviceInterface &vk, const VkDevice device, Allocator &allocator,
903                                     const std::vector<de::SharedPtr<RaytracedGeometryBase>> &geometriesData)
904 {
905     const VkDeviceSize bufferSizeBytes = getIndexBufferSize(geometriesData);
906     return bufferSizeBytes ? createIndexBuffer(vk, device, allocator, bufferSizeBytes) : nullptr;
907 }
908 
updateIndexBuffer(const DeviceInterface & vk,const VkDevice device,const std::vector<de::SharedPtr<RaytracedGeometryBase>> & geometriesData,BufferWithMemory * indexBuffer,VkDeviceSize geometriesOffset)909 void updateIndexBuffer(const DeviceInterface &vk, const VkDevice device,
910                        const std::vector<de::SharedPtr<RaytracedGeometryBase>> &geometriesData,
911                        BufferWithMemory *indexBuffer, VkDeviceSize geometriesOffset)
912 {
913     const Allocation &indexAlloc = indexBuffer->getAllocation();
914     uint8_t *bufferStart         = static_cast<uint8_t *>(indexAlloc.getHostPtr());
915     VkDeviceSize bufferOffset    = geometriesOffset;
916 
917     for (size_t geometryNdx = 0; geometryNdx < geometriesData.size(); ++geometryNdx)
918     {
919         if (geometriesData[geometryNdx]->getIndexType() != VK_INDEX_TYPE_NONE_KHR)
920         {
921             const void *indexPtr      = geometriesData[geometryNdx]->getIndexPointer();
922             const size_t indexPtrSize = geometriesData[geometryNdx]->getIndexByteSize();
923 
924             deMemcpy(&bufferStart[bufferOffset], indexPtr, indexPtrSize);
925 
926             bufferOffset += deAlignSize(indexPtrSize, 8);
927         }
928     }
929 
930     // Flush the whole allocation. We could flush only the interesting range, but we'd need to be sure both the offset and size
931     // align to VkPhysicalDeviceLimits::nonCoherentAtomSize, which we are not considering. Also note most code uses Coherent memory
932     // for the vertex and index buffers, so flushing is actually not needed.
933     flushAlloc(vk, device, indexAlloc);
934 }
935 
// KHR (VK_KHR_acceleration_structure) implementation of a bottom-level
// acceleration structure. Owns the structure buffer, optional vertex/index
// input buffers and the scratch buffer (device or host, depending on the
// configured build type), and exposes setters that configure how the
// structure is created and built.
class BottomLevelAccelerationStructureKHR : public BottomLevelAccelerationStructure
{
public:
    // Upper bound on the number of device memory allocations this class makes.
    static uint32_t getRequiredAllocationCount(void);

    BottomLevelAccelerationStructureKHR();
    BottomLevelAccelerationStructureKHR(const BottomLevelAccelerationStructureKHR &other) = delete;
    virtual ~BottomLevelAccelerationStructureKHR();

    // Configuration setters; these must be called before create()/build()
    // for the corresponding option to take effect.
    void setBuildType(const VkAccelerationStructureBuildTypeKHR buildType) override;
    VkAccelerationStructureBuildTypeKHR getBuildType() const override;
    void setCreateFlags(const VkAccelerationStructureCreateFlagsKHR createFlags) override;
    void setCreateGeneric(bool createGeneric) override;
    void setCreationBufferUnbounded(bool creationBufferUnbounded) override;
    void setBuildFlags(const VkBuildAccelerationStructureFlagsKHR buildFlags) override;
    void setBuildWithoutGeometries(bool buildWithoutGeometries) override;
    void setBuildWithoutPrimitives(bool buildWithoutPrimitives) override;
    void setDeferredOperation(const bool deferredOperation, const uint32_t workerThreadCount) override;
    void setUseArrayOfPointers(const bool useArrayOfPointers) override;
    void setUseMaintenance5(const bool useMaintenance5) override;
    void setIndirectBuildParameters(const VkBuffer indirectBuffer, const VkDeviceSize indirectBufferOffset,
                                    const uint32_t indirectBufferStride) override;
    VkBuildAccelerationStructureFlagsKHR getBuildFlags() const override;

    // Lifecycle: create the VkAccelerationStructureKHR (optionally in an
    // externally supplied buffer), build it, or copy/serialize/deserialize it.
    void create(const DeviceInterface &vk, const VkDevice device, Allocator &allocator, VkDeviceSize structureSize,
                VkDeviceAddress deviceAddress = 0u, const void *pNext = DE_NULL,
                const MemoryRequirement &addMemoryRequirement = MemoryRequirement::Any,
                const VkBuffer creationBuffer = VK_NULL_HANDLE, const VkDeviceSize creationBufferSize = 0u) override;
    void build(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
               BottomLevelAccelerationStructure *srcAccelerationStructure = DE_NULL) override;
    void copyFrom(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
                  BottomLevelAccelerationStructure *accelerationStructure, bool compactCopy) override;

    void serialize(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
                   SerialStorage *storage) override;
    void deserialize(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
                     SerialStorage *storage) override;

    const VkAccelerationStructureKHR *getPtr(void) const override;
    void updateGeometry(size_t geometryIndex, de::SharedPtr<RaytracedGeometryBase> &raytracedGeometry) override;

protected:
    // Build configuration state set via the setters above.
    VkAccelerationStructureBuildTypeKHR m_buildType;
    VkAccelerationStructureCreateFlagsKHR m_createFlags;
    bool m_createGeneric;
    bool m_creationBufferUnbounded;
    VkBuildAccelerationStructureFlagsKHR m_buildFlags;
    bool m_buildWithoutGeometries;
    bool m_buildWithoutPrimitives;
    bool m_deferredOperation;
    uint32_t m_workerThreadCount;
    bool m_useArrayOfPointers;
    bool m_useMaintenance5;
    // Buffers backing the structure, its inputs and the scratch space.
    de::MovePtr<BufferWithMemory> m_accelerationStructureBuffer;
    de::MovePtr<BufferWithMemory> m_vertexBuffer;
    de::MovePtr<BufferWithMemory> m_indexBuffer;
    de::MovePtr<BufferWithMemory> m_deviceScratchBuffer;
    de::UniquePtr<std::vector<uint8_t>> m_hostScratchBuffer;
    Move<VkAccelerationStructureKHR> m_accelerationStructureKHR;
    // Parameters for indirect builds (vkCmdBuildAccelerationStructuresIndirectKHR).
    VkBuffer m_indirectBuffer;
    VkDeviceSize m_indirectBufferOffset;
    uint32_t m_indirectBufferStride;

    // Translates m_geometriesData into the KHR structures consumed by the
    // build-size query and the build commands.
    void prepareGeometries(
        const DeviceInterface &vk, const VkDevice device,
        std::vector<VkAccelerationStructureGeometryKHR> &accelerationStructureGeometriesKHR,
        std::vector<VkAccelerationStructureGeometryKHR *> &accelerationStructureGeometriesKHRPointers,
        std::vector<VkAccelerationStructureBuildRangeInfoKHR> &accelerationStructureBuildRangeInfoKHR,
        std::vector<VkAccelerationStructureTrianglesOpacityMicromapEXT> &accelerationStructureGeometryMicromapsEXT,
        std::vector<uint32_t> &maxPrimitiveCounts, VkDeviceSize vertexBufferOffset = 0,
        VkDeviceSize indexBufferOffset = 0) const;

    // Buffer accessors; subclasses may redirect these to shared buffers.
    virtual BufferWithMemory *getAccelerationStructureBuffer() const
    {
        return m_accelerationStructureBuffer.get();
    }
    virtual BufferWithMemory *getDeviceScratchBuffer() const
    {
        return m_deviceScratchBuffer.get();
    }
    virtual std::vector<uint8_t> *getHostScratchBuffer() const
    {
        return m_hostScratchBuffer.get();
    }
    virtual BufferWithMemory *getVertexBuffer() const
    {
        return m_vertexBuffer.get();
    }
    virtual BufferWithMemory *getIndexBuffer() const
    {
        return m_indexBuffer.get();
    }

    // Offsets into the buffers above; zero here, overridable by subclasses
    // that pack multiple structures into one buffer.
    virtual VkDeviceSize getAccelerationStructureBufferOffset() const
    {
        return 0;
    }
    virtual VkDeviceSize getDeviceScratchBufferOffset() const
    {
        return 0;
    }
    virtual VkDeviceSize getVertexBufferOffset() const
    {
        return 0;
    }
    virtual VkDeviceSize getIndexBufferOffset() const
    {
        return 0;
    }
};
1046 
getRequiredAllocationCount(void)1047 uint32_t BottomLevelAccelerationStructureKHR::getRequiredAllocationCount(void)
1048 {
1049     /*
1050         de::MovePtr<BufferWithMemory>                            m_geometryBuffer; // but only when m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR
1051         de::MovePtr<Allocation>                                    m_accelerationStructureAlloc;
1052         de::MovePtr<BufferWithMemory>                            m_deviceScratchBuffer;
1053     */
1054     return 3u;
1055 }
1056 
~BottomLevelAccelerationStructureKHR()1057 BottomLevelAccelerationStructureKHR::~BottomLevelAccelerationStructureKHR()
1058 {
1059 }
1060 
// Initializes all configuration to its defaults: device-side build, no
// create/build flags, no deferred operation, no buffers allocated yet.
// NOTE: the member-initializer order mirrors the declaration order in the
// class, so keep both in sync when adding members.
BottomLevelAccelerationStructureKHR::BottomLevelAccelerationStructureKHR()
    : BottomLevelAccelerationStructure()
    , m_buildType(VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
    , m_createFlags(0u)
    , m_createGeneric(false)
    , m_creationBufferUnbounded(false)
    , m_buildFlags(0u)
    , m_buildWithoutGeometries(false)
    , m_buildWithoutPrimitives(false)
    , m_deferredOperation(false)
    , m_workerThreadCount(0)
    , m_useArrayOfPointers(false)
    , m_useMaintenance5(false)
    , m_accelerationStructureBuffer(DE_NULL)
    , m_vertexBuffer(DE_NULL)
    , m_indexBuffer(DE_NULL)
    , m_deviceScratchBuffer(DE_NULL)
    , m_hostScratchBuffer(new std::vector<uint8_t>)
    , m_accelerationStructureKHR()
    , m_indirectBuffer(DE_NULL)
    , m_indirectBufferOffset(0)
    , m_indirectBufferStride(0)
{
}
1085 
setBuildType(const VkAccelerationStructureBuildTypeKHR buildType)1086 void BottomLevelAccelerationStructureKHR::setBuildType(const VkAccelerationStructureBuildTypeKHR buildType)
1087 {
1088     m_buildType = buildType;
1089 }
1090 
getBuildType() const1091 VkAccelerationStructureBuildTypeKHR BottomLevelAccelerationStructureKHR::getBuildType() const
1092 {
1093     return m_buildType;
1094 }
1095 
setCreateFlags(const VkAccelerationStructureCreateFlagsKHR createFlags)1096 void BottomLevelAccelerationStructureKHR::setCreateFlags(const VkAccelerationStructureCreateFlagsKHR createFlags)
1097 {
1098     m_createFlags = createFlags;
1099 }
1100 
setCreateGeneric(bool createGeneric)1101 void BottomLevelAccelerationStructureKHR::setCreateGeneric(bool createGeneric)
1102 {
1103     m_createGeneric = createGeneric;
1104 }
1105 
setCreationBufferUnbounded(bool creationBufferUnbounded)1106 void BottomLevelAccelerationStructureKHR::setCreationBufferUnbounded(bool creationBufferUnbounded)
1107 {
1108     m_creationBufferUnbounded = creationBufferUnbounded;
1109 }
1110 
setBuildFlags(const VkBuildAccelerationStructureFlagsKHR buildFlags)1111 void BottomLevelAccelerationStructureKHR::setBuildFlags(const VkBuildAccelerationStructureFlagsKHR buildFlags)
1112 {
1113     m_buildFlags = buildFlags;
1114 }
1115 
setBuildWithoutGeometries(bool buildWithoutGeometries)1116 void BottomLevelAccelerationStructureKHR::setBuildWithoutGeometries(bool buildWithoutGeometries)
1117 {
1118     m_buildWithoutGeometries = buildWithoutGeometries;
1119 }
1120 
setBuildWithoutPrimitives(bool buildWithoutPrimitives)1121 void BottomLevelAccelerationStructureKHR::setBuildWithoutPrimitives(bool buildWithoutPrimitives)
1122 {
1123     m_buildWithoutPrimitives = buildWithoutPrimitives;
1124 }
1125 
setDeferredOperation(const bool deferredOperation,const uint32_t workerThreadCount)1126 void BottomLevelAccelerationStructureKHR::setDeferredOperation(const bool deferredOperation,
1127                                                                const uint32_t workerThreadCount)
1128 {
1129     m_deferredOperation = deferredOperation;
1130     m_workerThreadCount = workerThreadCount;
1131 }
1132 
setUseArrayOfPointers(const bool useArrayOfPointers)1133 void BottomLevelAccelerationStructureKHR::setUseArrayOfPointers(const bool useArrayOfPointers)
1134 {
1135     m_useArrayOfPointers = useArrayOfPointers;
1136 }
1137 
setUseMaintenance5(const bool useMaintenance5)1138 void BottomLevelAccelerationStructureKHR::setUseMaintenance5(const bool useMaintenance5)
1139 {
1140     m_useMaintenance5 = useMaintenance5;
1141 }
1142 
setIndirectBuildParameters(const VkBuffer indirectBuffer,const VkDeviceSize indirectBufferOffset,const uint32_t indirectBufferStride)1143 void BottomLevelAccelerationStructureKHR::setIndirectBuildParameters(const VkBuffer indirectBuffer,
1144                                                                      const VkDeviceSize indirectBufferOffset,
1145                                                                      const uint32_t indirectBufferStride)
1146 {
1147     m_indirectBuffer       = indirectBuffer;
1148     m_indirectBufferOffset = indirectBufferOffset;
1149     m_indirectBufferStride = indirectBufferStride;
1150 }
1151 
getBuildFlags() const1152 VkBuildAccelerationStructureFlagsKHR BottomLevelAccelerationStructureKHR::getBuildFlags() const
1153 {
1154     return m_buildFlags;
1155 }
1156 
create(const DeviceInterface & vk,const VkDevice device,Allocator & allocator,VkDeviceSize structureSize,VkDeviceAddress deviceAddress,const void * pNext,const MemoryRequirement & addMemoryRequirement,const VkBuffer creationBuffer,const VkDeviceSize creationBufferSize)1157 void BottomLevelAccelerationStructureKHR::create(const DeviceInterface &vk, const VkDevice device, Allocator &allocator,
1158                                                  VkDeviceSize structureSize, VkDeviceAddress deviceAddress,
1159                                                  const void *pNext, const MemoryRequirement &addMemoryRequirement,
1160                                                  const VkBuffer creationBuffer, const VkDeviceSize creationBufferSize)
1161 {
1162     // AS may be built from geometries using vkCmdBuildAccelerationStructuresKHR / vkBuildAccelerationStructuresKHR
1163     // or may be copied/compacted/deserialized from other AS ( in this case AS does not need geometries, but it needs to know its size before creation ).
1164     DE_ASSERT(!m_geometriesData.empty() != !(structureSize == 0)); // logical xor
1165 
1166     if (structureSize == 0)
1167     {
1168         std::vector<VkAccelerationStructureGeometryKHR> accelerationStructureGeometriesKHR;
1169         std::vector<VkAccelerationStructureGeometryKHR *> accelerationStructureGeometriesKHRPointers;
1170         std::vector<VkAccelerationStructureBuildRangeInfoKHR> accelerationStructureBuildRangeInfoKHR;
1171         std::vector<VkAccelerationStructureTrianglesOpacityMicromapEXT> accelerationStructureGeometryMicromapsEXT;
1172         std::vector<uint32_t> maxPrimitiveCounts;
1173         prepareGeometries(vk, device, accelerationStructureGeometriesKHR, accelerationStructureGeometriesKHRPointers,
1174                           accelerationStructureBuildRangeInfoKHR, accelerationStructureGeometryMicromapsEXT,
1175                           maxPrimitiveCounts);
1176 
1177         const VkAccelerationStructureGeometryKHR *accelerationStructureGeometriesKHRPointer =
1178             accelerationStructureGeometriesKHR.data();
1179         const VkAccelerationStructureGeometryKHR *const *accelerationStructureGeometry =
1180             accelerationStructureGeometriesKHRPointers.data();
1181 
1182         const uint32_t geometryCount =
1183             (m_buildWithoutGeometries ? 0u : static_cast<uint32_t>(accelerationStructureGeometriesKHR.size()));
1184         VkAccelerationStructureBuildGeometryInfoKHR accelerationStructureBuildGeometryInfoKHR = {
1185             VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_GEOMETRY_INFO_KHR, //  VkStructureType sType;
1186             DE_NULL,                                                          //  const void* pNext;
1187             VK_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL_KHR,                  //  VkAccelerationStructureTypeKHR type;
1188             m_buildFlags,                                   //  VkBuildAccelerationStructureFlagsKHR flags;
1189             VK_BUILD_ACCELERATION_STRUCTURE_MODE_BUILD_KHR, //  VkBuildAccelerationStructureModeKHR mode;
1190             DE_NULL,                                        //  VkAccelerationStructureKHR srcAccelerationStructure;
1191             DE_NULL,                                        //  VkAccelerationStructureKHR dstAccelerationStructure;
1192             geometryCount,                                  //  uint32_t geometryCount;
1193             m_useArrayOfPointers ?
1194                 DE_NULL :
1195                 accelerationStructureGeometriesKHRPointer, //  const VkAccelerationStructureGeometryKHR* pGeometries;
1196             m_useArrayOfPointers ? accelerationStructureGeometry :
1197                                    DE_NULL,     //  const VkAccelerationStructureGeometryKHR* const* ppGeometries;
1198             makeDeviceOrHostAddressKHR(DE_NULL) //  VkDeviceOrHostAddressKHR scratchData;
1199         };
1200         VkAccelerationStructureBuildSizesInfoKHR sizeInfo = {
1201             VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_SIZES_INFO_KHR, //  VkStructureType sType;
1202             DE_NULL,                                                       //  const void* pNext;
1203             0,                                                             //  VkDeviceSize accelerationStructureSize;
1204             0,                                                             //  VkDeviceSize updateScratchSize;
1205             0                                                              //  VkDeviceSize buildScratchSize;
1206         };
1207 
1208         vk.getAccelerationStructureBuildSizesKHR(device, m_buildType, &accelerationStructureBuildGeometryInfoKHR,
1209                                                  maxPrimitiveCounts.data(), &sizeInfo);
1210 
1211         m_structureSize     = sizeInfo.accelerationStructureSize;
1212         m_updateScratchSize = sizeInfo.updateScratchSize;
1213         m_buildScratchSize  = sizeInfo.buildScratchSize;
1214     }
1215     else
1216     {
1217         m_structureSize     = structureSize;
1218         m_updateScratchSize = 0u;
1219         m_buildScratchSize  = 0u;
1220     }
1221 
1222     const bool externalCreationBuffer = (creationBuffer != VK_NULL_HANDLE);
1223 
1224     if (externalCreationBuffer)
1225     {
1226         DE_UNREF(creationBufferSize); // For release builds.
1227         DE_ASSERT(creationBufferSize >= m_structureSize);
1228     }
1229 
1230     if (!externalCreationBuffer)
1231     {
1232         VkBufferCreateInfo bufferCreateInfo =
1233             makeBufferCreateInfo(m_structureSize, VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_STORAGE_BIT_KHR |
1234                                                       VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
1235         VkBufferUsageFlags2CreateInfoKHR bufferUsageFlags2 = vk::initVulkanStructure();
1236 
1237         if (m_useMaintenance5)
1238         {
1239             bufferUsageFlags2.usage = VK_BUFFER_USAGE_2_ACCELERATION_STRUCTURE_STORAGE_BIT_KHR |
1240                                       VK_BUFFER_USAGE_2_SHADER_DEVICE_ADDRESS_BIT_KHR;
1241             bufferCreateInfo.pNext = &bufferUsageFlags2;
1242             bufferCreateInfo.usage = 0;
1243         }
1244 
1245         const MemoryRequirement memoryRequirement = addMemoryRequirement | MemoryRequirement::HostVisible |
1246                                                     MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress;
1247         const bool bindMemOnCreation = (!m_creationBufferUnbounded);
1248 
1249         try
1250         {
1251             m_accelerationStructureBuffer = de::MovePtr<BufferWithMemory>(
1252                 new BufferWithMemory(vk, device, allocator, bufferCreateInfo,
1253                                      (MemoryRequirement::Cached | memoryRequirement), bindMemOnCreation));
1254         }
1255         catch (const tcu::NotSupportedError &)
1256         {
1257             // retry without Cached flag
1258             m_accelerationStructureBuffer = de::MovePtr<BufferWithMemory>(
1259                 new BufferWithMemory(vk, device, allocator, bufferCreateInfo, memoryRequirement, bindMemOnCreation));
1260         }
1261     }
1262 
1263     const auto createInfoBuffer = (externalCreationBuffer ? creationBuffer : getAccelerationStructureBuffer()->get());
1264     const auto createInfoOffset =
1265         (externalCreationBuffer ? static_cast<VkDeviceSize>(0) : getAccelerationStructureBufferOffset());
1266     {
1267         const VkAccelerationStructureTypeKHR structureType =
1268             (m_createGeneric ? VK_ACCELERATION_STRUCTURE_TYPE_GENERIC_KHR :
1269                                VK_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL_KHR);
1270         const VkAccelerationStructureCreateInfoKHR accelerationStructureCreateInfoKHR{
1271             VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_CREATE_INFO_KHR, //  VkStructureType sType;
1272             pNext,                                                    //  const void* pNext;
1273             m_createFlags,    //  VkAccelerationStructureCreateFlagsKHR createFlags;
1274             createInfoBuffer, //  VkBuffer buffer;
1275             createInfoOffset, //  VkDeviceSize offset;
1276             m_structureSize,  //  VkDeviceSize size;
1277             structureType,    //  VkAccelerationStructureTypeKHR type;
1278             deviceAddress     //  VkDeviceAddress deviceAddress;
1279         };
1280 
1281         m_accelerationStructureKHR =
1282             createAccelerationStructureKHR(vk, device, &accelerationStructureCreateInfoKHR, DE_NULL);
1283 
1284         // Make sure buffer memory is always bound after creation.
1285         if (!externalCreationBuffer)
1286             m_accelerationStructureBuffer->bindMemory();
1287     }
1288 
1289     if (m_buildScratchSize > 0u || m_updateScratchSize > 0u)
1290     {
1291         VkDeviceSize scratch_size = de::max(m_buildScratchSize, m_updateScratchSize);
1292         if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
1293         {
1294             const VkBufferCreateInfo bufferCreateInfo = makeBufferCreateInfo(
1295                 scratch_size, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
1296             m_deviceScratchBuffer = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
1297                 vk, device, allocator, bufferCreateInfo,
1298                 MemoryRequirement::HostVisible | MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress));
1299         }
1300         else
1301         {
1302             m_hostScratchBuffer->resize(static_cast<size_t>(scratch_size));
1303         }
1304     }
1305 
1306     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR && !m_geometriesData.empty())
1307     {
1308         VkBufferCreateInfo bufferCreateInfo =
1309             makeBufferCreateInfo(getVertexBufferSize(m_geometriesData),
1310                                  VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_BIT_KHR |
1311                                      VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
1312         VkBufferUsageFlags2CreateInfoKHR bufferUsageFlags2 = vk::initVulkanStructure();
1313 
1314         if (m_useMaintenance5)
1315         {
1316             bufferUsageFlags2.usage = vk::VK_BUFFER_USAGE_2_ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_BIT_KHR |
1317                                       VK_BUFFER_USAGE_2_SHADER_DEVICE_ADDRESS_BIT_KHR;
1318             bufferCreateInfo.pNext = &bufferUsageFlags2;
1319             bufferCreateInfo.usage = 0;
1320         }
1321 
1322         const vk::MemoryRequirement memoryRequirement =
1323             MemoryRequirement::HostVisible | MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress;
1324         m_vertexBuffer = de::MovePtr<BufferWithMemory>(
1325             new BufferWithMemory(vk, device, allocator, bufferCreateInfo, memoryRequirement));
1326 
1327         bufferCreateInfo.size = getIndexBufferSize(m_geometriesData);
1328         if (bufferCreateInfo.size)
1329             m_indexBuffer = de::MovePtr<BufferWithMemory>(
1330                 new BufferWithMemory(vk, device, allocator, bufferCreateInfo, memoryRequirement));
1331         else
1332             m_indexBuffer = de::MovePtr<BufferWithMemory>(nullptr);
1333     }
1334 }
1335 
// Record (device build) or perform (host build) the build of this bottom-level
// acceleration structure. When srcAccelerationStructure is non-null the build runs
// in UPDATE mode with that structure as source; otherwise a full BUILD is done.
void BottomLevelAccelerationStructureKHR::build(const DeviceInterface &vk, const VkDevice device,
                                                const VkCommandBuffer cmdBuffer,
                                                BottomLevelAccelerationStructure *srcAccelerationStructure)
{
    DE_ASSERT(!m_geometriesData.empty());
    DE_ASSERT(m_accelerationStructureKHR.get() != DE_NULL);
    DE_ASSERT(m_buildScratchSize != 0);

    if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
    {
        // Device builds read geometry from buffers: upload vertex (and optional index) data first.
        updateVertexBuffer(vk, device, m_geometriesData, getVertexBuffer(), getVertexBufferOffset());
        if (getIndexBuffer() != DE_NULL)
            updateIndexBuffer(vk, device, m_geometriesData, getIndexBuffer(), getIndexBufferOffset());
    }

    {
        // Per-geometry arrays filled by prepareGeometries(); they must stay alive until
        // the build call below, since the build info only stores pointers into them.
        std::vector<VkAccelerationStructureGeometryKHR> accelerationStructureGeometriesKHR;
        std::vector<VkAccelerationStructureGeometryKHR *> accelerationStructureGeometriesKHRPointers;
        std::vector<VkAccelerationStructureBuildRangeInfoKHR> accelerationStructureBuildRangeInfoKHR;
        std::vector<VkAccelerationStructureTrianglesOpacityMicromapEXT> accelerationStructureGeometryMicromapsEXT;
        std::vector<uint32_t> maxPrimitiveCounts;

        prepareGeometries(vk, device, accelerationStructureGeometriesKHR, accelerationStructureGeometriesKHRPointers,
                          accelerationStructureBuildRangeInfoKHR, accelerationStructureGeometryMicromapsEXT,
                          maxPrimitiveCounts, getVertexBufferOffset(), getIndexBufferOffset());

        const VkAccelerationStructureGeometryKHR *accelerationStructureGeometriesKHRPointer =
            accelerationStructureGeometriesKHR.data();
        const VkAccelerationStructureGeometryKHR *const *accelerationStructureGeometry =
            accelerationStructureGeometriesKHRPointers.data();
        // Scratch memory is a device buffer for device builds, plain host memory otherwise.
        VkDeviceOrHostAddressKHR scratchData =
            (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR) ?
                makeDeviceOrHostAddressKHR(vk, device, getDeviceScratchBuffer()->get(),
                                           getDeviceScratchBufferOffset()) :
                makeDeviceOrHostAddressKHR(getHostScratchBuffer()->data());
        // m_buildWithoutGeometries deliberately reports zero geometries to the build call.
        const uint32_t geometryCount =
            (m_buildWithoutGeometries ? 0u : static_cast<uint32_t>(accelerationStructureGeometriesKHR.size()));

        VkAccelerationStructureKHR srcStructure =
            (srcAccelerationStructure != DE_NULL) ? *(srcAccelerationStructure->getPtr()) : DE_NULL;
        VkBuildAccelerationStructureModeKHR mode = (srcAccelerationStructure != DE_NULL) ?
                                                       VK_BUILD_ACCELERATION_STRUCTURE_MODE_UPDATE_KHR :
                                                       VK_BUILD_ACCELERATION_STRUCTURE_MODE_BUILD_KHR;

        // Exactly one of pGeometries / ppGeometries is used, selected by m_useArrayOfPointers
        // so tests can exercise both ways of passing geometry to the implementation.
        VkAccelerationStructureBuildGeometryInfoKHR accelerationStructureBuildGeometryInfoKHR = {
            VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_GEOMETRY_INFO_KHR, //  VkStructureType sType;
            DE_NULL,                                                          //  const void* pNext;
            VK_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL_KHR,                  //  VkAccelerationStructureTypeKHR type;
            m_buildFlags,                     //  VkBuildAccelerationStructureFlagsKHR flags;
            mode,                             //  VkBuildAccelerationStructureModeKHR mode;
            srcStructure,                     //  VkAccelerationStructureKHR srcAccelerationStructure;
            m_accelerationStructureKHR.get(), //  VkAccelerationStructureKHR dstAccelerationStructure;
            geometryCount,                    //  uint32_t geometryCount;
            m_useArrayOfPointers ?
                DE_NULL :
                accelerationStructureGeometriesKHRPointer, //  const VkAccelerationStructureGeometryKHR* pGeometries;
            m_useArrayOfPointers ? accelerationStructureGeometry :
                                   DE_NULL, //  const VkAccelerationStructureGeometryKHR* const* ppGeometries;
            scratchData                     //  VkDeviceOrHostAddressKHR scratchData;
        };

        VkAccelerationStructureBuildRangeInfoKHR *accelerationStructureBuildRangeInfoKHRPtr =
            accelerationStructureBuildRangeInfoKHR.data();

        if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
        {
            if (m_indirectBuffer == DE_NULL)
                vk.cmdBuildAccelerationStructuresKHR(
                    cmdBuffer, 1u, &accelerationStructureBuildGeometryInfoKHR,
                    (const VkAccelerationStructureBuildRangeInfoKHR **)&accelerationStructureBuildRangeInfoKHRPtr);
            else
            {
                // Indirect path: build ranges come from m_indirectBuffer at the given offset/stride.
                VkDeviceAddress indirectDeviceAddress =
                    getBufferDeviceAddress(vk, device, m_indirectBuffer, m_indirectBufferOffset);
                uint32_t *pMaxPrimitiveCounts = maxPrimitiveCounts.data();
                vk.cmdBuildAccelerationStructuresIndirectKHR(cmdBuffer, 1u, &accelerationStructureBuildGeometryInfoKHR,
                                                             &indirectDeviceAddress, &m_indirectBufferStride,
                                                             &pMaxPrimitiveCounts);
            }
        }
        else if (!m_deferredOperation)
        {
            // Immediate host build.
            VK_CHECK(vk.buildAccelerationStructuresKHR(
                device, DE_NULL, 1u, &accelerationStructureBuildGeometryInfoKHR,
                (const VkAccelerationStructureBuildRangeInfoKHR **)&accelerationStructureBuildRangeInfoKHRPtr));
        }
        else
        {
            // Host build through a deferred operation; the implementation may complete it
            // immediately (VK_SUCCESS / VK_OPERATION_NOT_DEFERRED_KHR) or defer it.
            const auto deferredOperationPtr = createDeferredOperationKHR(vk, device);
            const auto deferredOperation    = deferredOperationPtr.get();

            VkResult result = vk.buildAccelerationStructuresKHR(
                device, deferredOperation, 1u, &accelerationStructureBuildGeometryInfoKHR,
                (const VkAccelerationStructureBuildRangeInfoKHR **)&accelerationStructureBuildRangeInfoKHRPtr);

            DE_ASSERT(result == VK_OPERATION_DEFERRED_KHR || result == VK_OPERATION_NOT_DEFERRED_KHR ||
                      result == VK_SUCCESS);

            finishDeferredOperation(vk, device, deferredOperation, m_workerThreadCount,
                                    result == VK_OPERATION_NOT_DEFERRED_KHR);
        }
    }

    if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
    {
        // Make the freshly built structure visible to all subsequent commands.
        const VkAccessFlags accessMasks =
            VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR | VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR;
        const VkMemoryBarrier memBarrier = makeMemoryBarrier(accessMasks, accessMasks);

        cmdPipelineMemoryBarrier(vk, cmdBuffer, VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR,
                                 VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, &memBarrier);
    }
}
1449 
copyFrom(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,BottomLevelAccelerationStructure * accelerationStructure,bool compactCopy)1450 void BottomLevelAccelerationStructureKHR::copyFrom(const DeviceInterface &vk, const VkDevice device,
1451                                                    const VkCommandBuffer cmdBuffer,
1452                                                    BottomLevelAccelerationStructure *accelerationStructure,
1453                                                    bool compactCopy)
1454 {
1455     DE_ASSERT(m_accelerationStructureKHR.get() != DE_NULL);
1456     DE_ASSERT(accelerationStructure != DE_NULL);
1457 
1458     VkCopyAccelerationStructureInfoKHR copyAccelerationStructureInfo = {
1459         VK_STRUCTURE_TYPE_COPY_ACCELERATION_STRUCTURE_INFO_KHR, // VkStructureType sType;
1460         DE_NULL,                                                // const void* pNext;
1461         *(accelerationStructure->getPtr()),                     // VkAccelerationStructureKHR src;
1462         *(getPtr()),                                            // VkAccelerationStructureKHR dst;
1463         compactCopy ? VK_COPY_ACCELERATION_STRUCTURE_MODE_COMPACT_KHR :
1464                       VK_COPY_ACCELERATION_STRUCTURE_MODE_CLONE_KHR // VkCopyAccelerationStructureModeKHR mode;
1465     };
1466 
1467     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
1468     {
1469         vk.cmdCopyAccelerationStructureKHR(cmdBuffer, &copyAccelerationStructureInfo);
1470     }
1471     else if (!m_deferredOperation)
1472     {
1473         VK_CHECK(vk.copyAccelerationStructureKHR(device, DE_NULL, &copyAccelerationStructureInfo));
1474     }
1475     else
1476     {
1477         const auto deferredOperationPtr = createDeferredOperationKHR(vk, device);
1478         const auto deferredOperation    = deferredOperationPtr.get();
1479 
1480         VkResult result = vk.copyAccelerationStructureKHR(device, deferredOperation, &copyAccelerationStructureInfo);
1481 
1482         DE_ASSERT(result == VK_OPERATION_DEFERRED_KHR || result == VK_OPERATION_NOT_DEFERRED_KHR ||
1483                   result == VK_SUCCESS);
1484 
1485         finishDeferredOperation(vk, device, deferredOperation, m_workerThreadCount,
1486                                 result == VK_OPERATION_NOT_DEFERRED_KHR);
1487     }
1488 
1489     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
1490     {
1491         const VkAccessFlags accessMasks =
1492             VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR | VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR;
1493         const VkMemoryBarrier memBarrier = makeMemoryBarrier(accessMasks, accessMasks);
1494 
1495         cmdPipelineMemoryBarrier(vk, cmdBuffer, VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR,
1496                                  VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, &memBarrier);
1497     }
1498 }
1499 
serialize(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,SerialStorage * storage)1500 void BottomLevelAccelerationStructureKHR::serialize(const DeviceInterface &vk, const VkDevice device,
1501                                                     const VkCommandBuffer cmdBuffer, SerialStorage *storage)
1502 {
1503     DE_ASSERT(m_accelerationStructureKHR.get() != DE_NULL);
1504     DE_ASSERT(storage != DE_NULL);
1505 
1506     const VkCopyAccelerationStructureToMemoryInfoKHR copyAccelerationStructureInfo = {
1507         VK_STRUCTURE_TYPE_COPY_ACCELERATION_STRUCTURE_TO_MEMORY_INFO_KHR, // VkStructureType sType;
1508         DE_NULL,                                                          // const void* pNext;
1509         *(getPtr()),                                                      // VkAccelerationStructureKHR src;
1510         storage->getAddress(vk, device, m_buildType),                     // VkDeviceOrHostAddressKHR dst;
1511         VK_COPY_ACCELERATION_STRUCTURE_MODE_SERIALIZE_KHR                 // VkCopyAccelerationStructureModeKHR mode;
1512     };
1513 
1514     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
1515     {
1516         vk.cmdCopyAccelerationStructureToMemoryKHR(cmdBuffer, &copyAccelerationStructureInfo);
1517     }
1518     else if (!m_deferredOperation)
1519     {
1520         VK_CHECK(vk.copyAccelerationStructureToMemoryKHR(device, DE_NULL, &copyAccelerationStructureInfo));
1521     }
1522     else
1523     {
1524         const auto deferredOperationPtr = createDeferredOperationKHR(vk, device);
1525         const auto deferredOperation    = deferredOperationPtr.get();
1526 
1527         const VkResult result =
1528             vk.copyAccelerationStructureToMemoryKHR(device, deferredOperation, &copyAccelerationStructureInfo);
1529 
1530         DE_ASSERT(result == VK_OPERATION_DEFERRED_KHR || result == VK_OPERATION_NOT_DEFERRED_KHR ||
1531                   result == VK_SUCCESS);
1532 
1533         finishDeferredOperation(vk, device, deferredOperation, m_workerThreadCount,
1534                                 result == VK_OPERATION_NOT_DEFERRED_KHR);
1535     }
1536 }
1537 
deserialize(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,SerialStorage * storage)1538 void BottomLevelAccelerationStructureKHR::deserialize(const DeviceInterface &vk, const VkDevice device,
1539                                                       const VkCommandBuffer cmdBuffer, SerialStorage *storage)
1540 {
1541     DE_ASSERT(m_accelerationStructureKHR.get() != DE_NULL);
1542     DE_ASSERT(storage != DE_NULL);
1543 
1544     const VkCopyMemoryToAccelerationStructureInfoKHR copyAccelerationStructureInfo = {
1545         VK_STRUCTURE_TYPE_COPY_MEMORY_TO_ACCELERATION_STRUCTURE_INFO_KHR, // VkStructureType sType;
1546         DE_NULL,                                                          // const void* pNext;
1547         storage->getAddressConst(vk, device, m_buildType),                // VkDeviceOrHostAddressConstKHR src;
1548         *(getPtr()),                                                      // VkAccelerationStructureKHR dst;
1549         VK_COPY_ACCELERATION_STRUCTURE_MODE_DESERIALIZE_KHR               // VkCopyAccelerationStructureModeKHR mode;
1550     };
1551 
1552     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
1553     {
1554         vk.cmdCopyMemoryToAccelerationStructureKHR(cmdBuffer, &copyAccelerationStructureInfo);
1555     }
1556     else if (!m_deferredOperation)
1557     {
1558         VK_CHECK(vk.copyMemoryToAccelerationStructureKHR(device, DE_NULL, &copyAccelerationStructureInfo));
1559     }
1560     else
1561     {
1562         const auto deferredOperationPtr = createDeferredOperationKHR(vk, device);
1563         const auto deferredOperation    = deferredOperationPtr.get();
1564 
1565         const VkResult result =
1566             vk.copyMemoryToAccelerationStructureKHR(device, deferredOperation, &copyAccelerationStructureInfo);
1567 
1568         DE_ASSERT(result == VK_OPERATION_DEFERRED_KHR || result == VK_OPERATION_NOT_DEFERRED_KHR ||
1569                   result == VK_SUCCESS);
1570 
1571         finishDeferredOperation(vk, device, deferredOperation, m_workerThreadCount,
1572                                 result == VK_OPERATION_NOT_DEFERRED_KHR);
1573     }
1574 
1575     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
1576     {
1577         const VkAccessFlags accessMasks =
1578             VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR | VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR;
1579         const VkMemoryBarrier memBarrier = makeMemoryBarrier(accessMasks, accessMasks);
1580 
1581         cmdPipelineMemoryBarrier(vk, cmdBuffer, VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR,
1582                                  VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, &memBarrier);
1583     }
1584 }
1585 
getPtr(void) const1586 const VkAccelerationStructureKHR *BottomLevelAccelerationStructureKHR::getPtr(void) const
1587 {
1588     return &m_accelerationStructureKHR.get();
1589 }
1590 
// Translate m_geometriesData into the parallel arrays consumed by the build commands:
// one VkAccelerationStructureGeometryKHR, one build-range info and one max-primitive
// count per geometry. vertexBufferOffset/indexBufferOffset are the starting offsets of
// the first geometry's data inside the shared vertex/index buffers (device builds only).
void BottomLevelAccelerationStructureKHR::prepareGeometries(
    const DeviceInterface &vk, const VkDevice device,
    std::vector<VkAccelerationStructureGeometryKHR> &accelerationStructureGeometriesKHR,
    std::vector<VkAccelerationStructureGeometryKHR *> &accelerationStructureGeometriesKHRPointers,
    std::vector<VkAccelerationStructureBuildRangeInfoKHR> &accelerationStructureBuildRangeInfoKHR,
    std::vector<VkAccelerationStructureTrianglesOpacityMicromapEXT> &accelerationStructureGeometryMicromapsEXT,
    std::vector<uint32_t> &maxPrimitiveCounts, VkDeviceSize vertexBufferOffset, VkDeviceSize indexBufferOffset) const
{
    // One entry per geometry in every output array.
    accelerationStructureGeometriesKHR.resize(m_geometriesData.size());
    accelerationStructureGeometriesKHRPointers.resize(m_geometriesData.size());
    accelerationStructureBuildRangeInfoKHR.resize(m_geometriesData.size());
    accelerationStructureGeometryMicromapsEXT.resize(m_geometriesData.size());
    maxPrimitiveCounts.resize(m_geometriesData.size());

    for (size_t geometryNdx = 0; geometryNdx < m_geometriesData.size(); ++geometryNdx)
    {
        const de::SharedPtr<RaytracedGeometryBase> &geometryData = m_geometriesData[geometryNdx];
        VkDeviceOrHostAddressConstKHR vertexData, indexData;
        if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
        {
            // Device builds address geometry via buffer device addresses; the running
            // offsets walk through the packed per-geometry regions of the shared buffers.
            if (getVertexBuffer() != DE_NULL)
            {
                vertexData = makeDeviceOrHostAddressConstKHR(vk, device, getVertexBuffer()->get(), vertexBufferOffset);
                if (m_indirectBuffer == DE_NULL)
                {
                    // 8-byte alignment matches how the vertex data was packed into the buffer.
                    // NOTE(review): the offset is deliberately not advanced for indirect
                    // builds — presumably those address vertex data differently; confirm
                    // against the cmdBuildAccelerationStructuresIndirectKHR path in build().
                    vertexBufferOffset += deAlignSize(geometryData->getVertexByteSize(), 8);
                }
            }
            else
                vertexData = makeDeviceOrHostAddressConstKHR(DE_NULL);

            if (getIndexBuffer() != DE_NULL && geometryData->getIndexType() != VK_INDEX_TYPE_NONE_KHR)
            {
                indexData = makeDeviceOrHostAddressConstKHR(vk, device, getIndexBuffer()->get(), indexBufferOffset);
                indexBufferOffset += deAlignSize(geometryData->getIndexByteSize(), 8);
            }
            else
                indexData = makeDeviceOrHostAddressConstKHR(DE_NULL);
        }
        else
        {
            // Host builds address the geometry data directly in host memory.
            vertexData = makeDeviceOrHostAddressConstKHR(geometryData->getVertexPointer());
            if (geometryData->getIndexType() != VK_INDEX_TYPE_NONE_KHR)
                indexData = makeDeviceOrHostAddressConstKHR(geometryData->getIndexPointer());
            else
                indexData = makeDeviceOrHostAddressConstKHR(DE_NULL);
        }

        VkAccelerationStructureGeometryTrianglesDataKHR accelerationStructureGeometryTrianglesDataKHR = {
            VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_GEOMETRY_TRIANGLES_DATA_KHR, //  VkStructureType sType;
            DE_NULL,                                                              //  const void* pNext;
            geometryData->getVertexFormat(),                                      //  VkFormat vertexFormat;
            vertexData,                                            //  VkDeviceOrHostAddressConstKHR vertexData;
            geometryData->getVertexStride(),                       //  VkDeviceSize vertexStride;
            static_cast<uint32_t>(geometryData->getVertexCount()), //  uint32_t maxVertex;
            geometryData->getIndexType(),                          //  VkIndexType indexType;
            indexData,                                             //  VkDeviceOrHostAddressConstKHR indexData;
            makeDeviceOrHostAddressConstKHR(DE_NULL),              //  VkDeviceOrHostAddressConstKHR transformData;
        };

        // Chain the opacity micromap description when the geometry carries one.
        if (geometryData->getHasOpacityMicromap())
            accelerationStructureGeometryTrianglesDataKHR.pNext = &geometryData->getOpacityMicromap();

        // For AABB geometries the vertex address computed above doubles as the AABB data pointer.
        const VkAccelerationStructureGeometryAabbsDataKHR accelerationStructureGeometryAabbsDataKHR = {
            VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_GEOMETRY_AABBS_DATA_KHR, //  VkStructureType sType;
            DE_NULL,                                                          //  const void* pNext;
            vertexData,                                                       //  VkDeviceOrHostAddressConstKHR data;
            geometryData->getAABBStride()                                     //  VkDeviceSize stride;
        };
        const VkAccelerationStructureGeometryDataKHR geometry =
            (geometryData->isTrianglesType()) ?
                makeVkAccelerationStructureGeometryDataKHR(accelerationStructureGeometryTrianglesDataKHR) :
                makeVkAccelerationStructureGeometryDataKHR(accelerationStructureGeometryAabbsDataKHR);
        const VkAccelerationStructureGeometryKHR accelerationStructureGeometryKHR = {
            VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_GEOMETRY_KHR, //  VkStructureType sType;
            DE_NULL,                                               //  const void* pNext;
            geometryData->getGeometryType(),                       //  VkGeometryTypeKHR geometryType;
            geometry,                                              //  VkAccelerationStructureGeometryDataKHR geometry;
            geometryData->getGeometryFlags()                       //  VkGeometryFlagsKHR flags;
        };

        // m_buildWithoutPrimitives reports zero primitives in the build range while
        // maxPrimitiveCounts below still carries the real count.
        const uint32_t primitiveCount = (m_buildWithoutPrimitives ? 0u : geometryData->getPrimitiveCount());

        const VkAccelerationStructureBuildRangeInfoKHR accelerationStructureBuildRangeInfosKHR = {
            primitiveCount, //  uint32_t primitiveCount;
            0,              //  uint32_t primitiveOffset;
            0,              //  uint32_t firstVertex;
            0               //  uint32_t firstTransform;
        };

        accelerationStructureGeometriesKHR[geometryNdx]         = accelerationStructureGeometryKHR;
        accelerationStructureGeometriesKHRPointers[geometryNdx] = &accelerationStructureGeometriesKHR[geometryNdx];
        accelerationStructureBuildRangeInfoKHR[geometryNdx]     = accelerationStructureBuildRangeInfosKHR;
        maxPrimitiveCounts[geometryNdx]                         = geometryData->getPrimitiveCount();
    }
}
1687 
// Forward to the KHR implementation's estimate of required memory allocations.
uint32_t BottomLevelAccelerationStructure::getRequiredAllocationCount(void)
{
    return BottomLevelAccelerationStructureKHR::getRequiredAllocationCount();
}
1692 
// Convenience wrapper: create the acceleration structure and immediately build it.
void BottomLevelAccelerationStructure::createAndBuild(const DeviceInterface &vk, const VkDevice device,
                                                      const VkCommandBuffer cmdBuffer, Allocator &allocator,
                                                      VkDeviceAddress deviceAddress)
{
    // structureSize 0u: create() computes the size itself (presumably from the
    // registered geometry via vkGetAccelerationStructureBuildSizesKHR — see create()).
    create(vk, device, allocator, 0u, deviceAddress);
    build(vk, device, cmdBuffer);
}
1700 
createAndCopyFrom(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,Allocator & allocator,BottomLevelAccelerationStructure * accelerationStructure,VkDeviceSize compactCopySize,VkDeviceAddress deviceAddress)1701 void BottomLevelAccelerationStructure::createAndCopyFrom(const DeviceInterface &vk, const VkDevice device,
1702                                                          const VkCommandBuffer cmdBuffer, Allocator &allocator,
1703                                                          BottomLevelAccelerationStructure *accelerationStructure,
1704                                                          VkDeviceSize compactCopySize, VkDeviceAddress deviceAddress)
1705 {
1706     DE_ASSERT(accelerationStructure != NULL);
1707     VkDeviceSize copiedSize = compactCopySize > 0u ?
1708                                   compactCopySize :
1709                                   accelerationStructure->getStructureBuildSizes().accelerationStructureSize;
1710     DE_ASSERT(copiedSize != 0u);
1711 
1712     create(vk, device, allocator, copiedSize, deviceAddress);
1713     copyFrom(vk, device, cmdBuffer, accelerationStructure, compactCopySize > 0u);
1714 }
1715 
createAndDeserializeFrom(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,Allocator & allocator,SerialStorage * storage,VkDeviceAddress deviceAddress)1716 void BottomLevelAccelerationStructure::createAndDeserializeFrom(const DeviceInterface &vk, const VkDevice device,
1717                                                                 const VkCommandBuffer cmdBuffer, Allocator &allocator,
1718                                                                 SerialStorage *storage, VkDeviceAddress deviceAddress)
1719 {
1720     DE_ASSERT(storage != NULL);
1721     DE_ASSERT(storage->getStorageSize() >= SerialStorage::SERIAL_STORAGE_SIZE_MIN);
1722     create(vk, device, allocator, storage->getDeserializedSize(), deviceAddress);
1723     deserialize(vk, device, cmdBuffer, storage);
1724 }
1725 
// Replace the CPU-side geometry description at geometryIndex. Only updates
// m_geometriesData; no Vulkan object is touched here.
void BottomLevelAccelerationStructureKHR::updateGeometry(size_t geometryIndex,
                                                         de::SharedPtr<RaytracedGeometryBase> &raytracedGeometry)
{
    DE_ASSERT(geometryIndex < m_geometriesData.size());
    m_geometriesData[geometryIndex] = raytracedGeometry;
}
1732 
makeBottomLevelAccelerationStructure()1733 de::MovePtr<BottomLevelAccelerationStructure> makeBottomLevelAccelerationStructure()
1734 {
1735     return de::MovePtr<BottomLevelAccelerationStructure>(new BottomLevelAccelerationStructureKHR);
1736 }
1737 
// Forward declaration
struct BottomLevelAccelerationStructurePoolImpl;

// A BLAS whose backing buffers (structure storage, scratch, vertex and index
// data) live inside a shared BottomLevelAccelerationStructurePool rather than
// being allocated per structure. m_info records which pooled buffer, and at
// which offset, each of this member's resources is placed.
class BottomLevelAccelerationStructurePoolMember : public BottomLevelAccelerationStructureKHR
{
public:
    friend class BottomLevelAccelerationStructurePool;

    BottomLevelAccelerationStructurePoolMember(BottomLevelAccelerationStructurePoolImpl &pool);
    BottomLevelAccelerationStructurePoolMember(const BottomLevelAccelerationStructurePoolMember &) = delete;
    BottomLevelAccelerationStructurePoolMember(BottomLevelAccelerationStructurePoolMember &&)      = delete;
    virtual ~BottomLevelAccelerationStructurePoolMember()                                          = default;

    // Pool members are created through the pool's batchCreate*() path,
    // so direct creation is intentionally disabled.
    virtual void create(const DeviceInterface &, const VkDevice, Allocator &, VkDeviceSize, VkDeviceAddress,
                        const void *, const MemoryRequirement &, const VkBuffer, const VkDeviceSize) override
    {
        DE_ASSERT(0); // Silence this method; creation is driven by the pool.
    }
    virtual auto computeBuildSize(const DeviceInterface &vk, const VkDevice device, const VkDeviceSize strSize) const
        //              accStrSize,updateScratch, buildScratch, vertexSize,   indexSize
        -> std::tuple<VkDeviceSize, VkDeviceSize, VkDeviceSize, VkDeviceSize, VkDeviceSize>;

protected:
    struct Info;
    // Records the pool-assigned buffer indices/offsets and the computed sizes
    // before the actual Vulkan object is created.
    virtual void preCreateSetSizesAndOffsets(const Info &info, const VkDeviceSize accStrSize,
                                             const VkDeviceSize updateScratchSize, const VkDeviceSize buildScratchSize);
    virtual void createAccellerationStructure(const DeviceInterface &vk, const VkDevice device,
                                              VkDeviceAddress deviceAddress);

    // Buffer accessors overridden to return the pool's shared buffers,
    // selected by the indices stored in m_info.
    virtual BufferWithMemory *getAccelerationStructureBuffer() const override;
    virtual BufferWithMemory *getDeviceScratchBuffer() const override;
    virtual std::vector<uint8_t> *getHostScratchBuffer() const override;
    virtual BufferWithMemory *getVertexBuffer() const override;
    virtual BufferWithMemory *getIndexBuffer() const override;

    // Offsets into the shared buffers for this particular member.
    virtual VkDeviceSize getAccelerationStructureBufferOffset() const override
    {
        return m_info.accStrOffset;
    }
    virtual VkDeviceSize getDeviceScratchBufferOffset() const override
    {
        return m_info.buildScratchBuffOffset;
    }
    virtual VkDeviceSize getVertexBufferOffset() const override
    {
        return m_info.vertBuffOffset;
    }
    virtual VkDeviceSize getIndexBufferOffset() const override
    {
        return m_info.indexBuffOffset;
    }

    BottomLevelAccelerationStructurePoolImpl &m_pool;

    // Per-member placement data: which pooled buffer (index) and where inside
    // it (offset) each resource lives. Filled in by preCreateSetSizesAndOffsets().
    struct Info
    {
        uint32_t accStrIndex;
        VkDeviceSize accStrOffset;
        uint32_t vertBuffIndex;
        VkDeviceSize vertBuffOffset;
        uint32_t indexBuffIndex;
        VkDeviceSize indexBuffOffset;
        uint32_t buildScratchBuffIndex;
        VkDeviceSize buildScratchBuffOffset;
    } m_info;
};
1804 
// Returns the all-bits-set value of type X; the pool code uses it as a
// "not set" sentinel for indices and sizes.
template <class X>
inline X negz(const X &)
{
    const X zero = static_cast<X>(0);
    return static_cast<X>(~zero);
}
// True when x holds the negz sentinel value for its type.
template <class X>
inline bool isnegz(const X &x)
{
    return (negz(x) == x);
}
// Casts y to the unsigned counterpart of its own type.
template <class Y>
inline auto make_unsigned(const Y &y) -> typename std::make_unsigned<Y>::type
{
    using U = typename std::make_unsigned<Y>::type;
    return static_cast<U>(y);
}
1820 
// Binds the member to its owning pool storage; all placement info (indices
// and offsets) starts zeroed and is assigned later by the pool.
BottomLevelAccelerationStructurePoolMember::BottomLevelAccelerationStructurePoolMember(
    BottomLevelAccelerationStructurePoolImpl &pool)
    : m_pool(pool)
    , m_info{}
{
}
1827 
// Shared backing storage for all members of a BottomLevelAccelerationStructurePool:
// the batched acceleration-structure, vertex and index buffers, plus a single
// device scratch buffer and a single host scratch byte vector shared by all members.
struct BottomLevelAccelerationStructurePoolImpl
{
    BottomLevelAccelerationStructurePoolImpl(BottomLevelAccelerationStructurePoolImpl &&)      = delete;
    BottomLevelAccelerationStructurePoolImpl(const BottomLevelAccelerationStructurePoolImpl &) = delete;
    BottomLevelAccelerationStructurePoolImpl(BottomLevelAccelerationStructurePool &pool);

    BottomLevelAccelerationStructurePool &m_pool;
    std::vector<de::SharedPtr<BufferWithMemory>> m_accellerationStructureBuffers;
    de::SharedPtr<BufferWithMemory> m_deviceScratchBuffer;
    de::UniquePtr<std::vector<uint8_t>> m_hostScratchBuffer;
    std::vector<de::SharedPtr<BufferWithMemory>> m_vertexBuffers;
    std::vector<de::SharedPtr<BufferWithMemory>> m_indexBuffers;
};
// Starts with empty buffer lists; the actual buffers are allocated later by
// the pool's batchCreate*() methods. Only the host scratch vector exists up front.
BottomLevelAccelerationStructurePoolImpl::BottomLevelAccelerationStructurePoolImpl(
    BottomLevelAccelerationStructurePool &pool)
    : m_pool(pool)
    , m_accellerationStructureBuffers()
    , m_deviceScratchBuffer()
    , m_hostScratchBuffer(new std::vector<uint8_t>)
    , m_vertexBuffers()
    , m_indexBuffers()
{
}
getAccelerationStructureBuffer() const1851 BufferWithMemory *BottomLevelAccelerationStructurePoolMember::getAccelerationStructureBuffer() const
1852 {
1853     BufferWithMemory *result = nullptr;
1854     if (m_pool.m_accellerationStructureBuffers.size())
1855     {
1856         DE_ASSERT(!isnegz(m_info.accStrIndex));
1857         result = m_pool.m_accellerationStructureBuffers[m_info.accStrIndex].get();
1858     }
1859     return result;
1860 }
// Returns the single device scratch buffer shared by every pool member;
// members never get a scratch index other than 0.
BufferWithMemory *BottomLevelAccelerationStructurePoolMember::getDeviceScratchBuffer() const
{
    DE_ASSERT(m_info.buildScratchBuffIndex == 0);
    return m_pool.m_deviceScratchBuffer.get();
}
getHostScratchBuffer() const1866 std::vector<uint8_t> *BottomLevelAccelerationStructurePoolMember::getHostScratchBuffer() const
1867 {
1868     return this->m_buildScratchSize ? m_pool.m_hostScratchBuffer.get() : nullptr;
1869 }
1870 
getVertexBuffer() const1871 BufferWithMemory *BottomLevelAccelerationStructurePoolMember::getVertexBuffer() const
1872 {
1873     BufferWithMemory *result = nullptr;
1874     if (m_pool.m_vertexBuffers.size())
1875     {
1876         DE_ASSERT(!isnegz(m_info.vertBuffIndex));
1877         result = m_pool.m_vertexBuffers[m_info.vertBuffIndex].get();
1878     }
1879     return result;
1880 }
getIndexBuffer() const1881 BufferWithMemory *BottomLevelAccelerationStructurePoolMember::getIndexBuffer() const
1882 {
1883     BufferWithMemory *result = nullptr;
1884     if (m_pool.m_indexBuffers.size())
1885     {
1886         DE_ASSERT(!isnegz(m_info.indexBuffIndex));
1887         result = m_pool.m_indexBuffers[m_info.indexBuffIndex].get();
1888     }
1889     return result;
1890 }
1891 
// Thin concrete wrapper so the pool can own a BottomLevelAccelerationStructurePoolImpl
// through a pointer declared in the header without exposing the impl type there.
struct BottomLevelAccelerationStructurePool::Impl : BottomLevelAccelerationStructurePoolImpl
{
    friend class BottomLevelAccelerationStructurePool;
    friend class BottomLevelAccelerationStructurePoolMember;

    Impl(BottomLevelAccelerationStructurePool &pool) : BottomLevelAccelerationStructurePoolImpl(pool)
    {
    }
};
1901 
// Default pool: batches of 4 structures per buffer, no separate geometry batch
// count (falls back to the structure batch count), cached memory tried first.
BottomLevelAccelerationStructurePool::BottomLevelAccelerationStructurePool()
    : m_batchStructCount(4)
    , m_batchGeomCount(0)
    , m_infos()
    , m_structs()
    , m_createOnce(false)
    , m_tryCachedMemory(true)
    , m_structsBuffSize(0)
    , m_updatesScratchSize(0)
    , m_buildsScratchSize(0)
    , m_verticesSize(0)
    , m_indicesSize(0)
    , m_impl(new Impl(*this))
{
}
1917 
BottomLevelAccelerationStructurePool::~BottomLevelAccelerationStructurePool()
{
    // m_impl is a raw owning pointer (Impl is defined only in this .cpp),
    // so it must be deleted manually here.
    delete m_impl;
}
1922 
// Sets how many acceleration structures share one pooled buffer; must be >= 1.
void BottomLevelAccelerationStructurePool::batchStructCount(const uint32_t &value)
{
    DE_ASSERT(value >= 1);
    m_batchStructCount = value;
}
1928 
add(VkDeviceSize structureSize,VkDeviceAddress deviceAddress)1929 auto BottomLevelAccelerationStructurePool::add(VkDeviceSize structureSize, VkDeviceAddress deviceAddress)
1930     -> BottomLevelAccelerationStructurePool::BlasPtr
1931 {
1932     // Prevent a programmer from calling this method after batchCreate(...) method has been called.
1933     if (m_createOnce)
1934         DE_ASSERT(0);
1935 
1936     auto blas = new BottomLevelAccelerationStructurePoolMember(*m_impl);
1937     m_infos.push_back({structureSize, deviceAddress});
1938     m_structs.emplace_back(blas);
1939     return m_structs.back();
1940 }
1941 
adjustBatchCount(const DeviceInterface & vkd,const VkDevice device,const std::vector<BottomLevelAccelerationStructurePool::BlasPtr> & structs,const std::vector<BottomLevelAccelerationStructurePool::BlasInfo> & infos,const VkDeviceSize maxBufferSize,uint32_t (& result)[4])1942 void adjustBatchCount(const DeviceInterface &vkd, const VkDevice device,
1943                       const std::vector<BottomLevelAccelerationStructurePool::BlasPtr> &structs,
1944                       const std::vector<BottomLevelAccelerationStructurePool::BlasInfo> &infos,
1945                       const VkDeviceSize maxBufferSize, uint32_t (&result)[4])
1946 {
1947     tcu::Vector<VkDeviceSize, 4> sizes(0);
1948     tcu::Vector<VkDeviceSize, 4> sums(0);
1949     tcu::Vector<uint32_t, 4> tmps(0);
1950     tcu::Vector<uint32_t, 4> batches(0);
1951 
1952     VkDeviceSize updateScratchSize = 0;
1953     static_cast<void>(updateScratchSize); // not used yet, disabled for future implementation
1954 
1955     auto updateIf = [&](uint32_t c)
1956     {
1957         if (sums[c] + sizes[c] <= maxBufferSize)
1958         {
1959             sums[c] += sizes[c];
1960             tmps[c] += 1;
1961 
1962             batches[c] = std::max(tmps[c], batches[c]);
1963         }
1964         else
1965         {
1966             sums[c] = 0;
1967             tmps[c] = 0;
1968         }
1969     };
1970 
1971     const uint32_t maxIter = static_cast<uint32_t>(structs.size());
1972     for (uint32_t i = 0; i < maxIter; ++i)
1973     {
1974         auto &str = *dynamic_cast<BottomLevelAccelerationStructurePoolMember *>(structs[i].get());
1975         std::tie(sizes[0], updateScratchSize, sizes[1], sizes[2], sizes[3]) =
1976             str.computeBuildSize(vkd, device, infos[i].structureSize);
1977 
1978         updateIf(0);
1979         updateIf(1);
1980         updateIf(2);
1981         updateIf(3);
1982     }
1983 
1984     result[0] = std::max(batches[0], 1u);
1985     result[1] = std::max(batches[1], 1u);
1986     result[2] = std::max(batches[2], 1u);
1987     result[3] = std::max(batches[3], 1u);
1988 }
1989 
getAllocationCount() const1990 size_t BottomLevelAccelerationStructurePool::getAllocationCount() const
1991 {
1992     return m_impl->m_accellerationStructureBuffers.size() + m_impl->m_vertexBuffers.size() +
1993            m_impl->m_indexBuffers.size() + 1 /* for scratch buffer */;
1994 }
1995 
getAllocationCount(const DeviceInterface & vk,const VkDevice device,const VkDeviceSize maxBufferSize) const1996 size_t BottomLevelAccelerationStructurePool::getAllocationCount(const DeviceInterface &vk, const VkDevice device,
1997                                                                 const VkDeviceSize maxBufferSize) const
1998 {
1999     DE_ASSERT(m_structs.size() != 0);
2000 
2001     std::map<uint32_t, VkDeviceSize> accStrSizes;
2002     std::map<uint32_t, VkDeviceSize> vertBuffSizes;
2003     std::map<uint32_t, VkDeviceSize> indexBuffSizes;
2004     std::map<uint32_t, VkDeviceSize> scratchBuffSizes;
2005 
2006     const uint32_t allStructsCount = structCount();
2007 
2008     uint32_t batchStructCount  = m_batchStructCount;
2009     uint32_t batchScratchCount = m_batchStructCount;
2010     uint32_t batchVertexCount  = m_batchGeomCount ? m_batchGeomCount : m_batchStructCount;
2011     uint32_t batchIndexCount   = batchVertexCount;
2012 
2013     if (!isnegz(maxBufferSize))
2014     {
2015         uint32_t batches[4];
2016         adjustBatchCount(vk, device, m_structs, m_infos, maxBufferSize, batches);
2017         batchStructCount  = batches[0];
2018         batchScratchCount = batches[1];
2019         batchVertexCount  = batches[2];
2020         batchIndexCount   = batches[3];
2021     }
2022 
2023     uint32_t iStr     = 0;
2024     uint32_t iScratch = 0;
2025     uint32_t iVertex  = 0;
2026     uint32_t iIndex   = 0;
2027 
2028     VkDeviceSize strSize           = 0;
2029     VkDeviceSize updateScratchSize = 0;
2030     VkDeviceSize buildScratchSize  = 0;
2031     VkDeviceSize vertexSize        = 0;
2032     VkDeviceSize indexSize         = 0;
2033 
2034     for (; iStr < allStructsCount; ++iStr)
2035     {
2036         auto &str = *dynamic_cast<BottomLevelAccelerationStructurePoolMember *>(m_structs[iStr].get());
2037         std::tie(strSize, updateScratchSize, buildScratchSize, vertexSize, indexSize) =
2038             str.computeBuildSize(vk, device, m_infos[iStr].structureSize);
2039 
2040         {
2041             const VkDeviceSize alignedStrSize = deAlign64(strSize, 256);
2042             const uint32_t accStrIndex        = (iStr / batchStructCount);
2043             accStrSizes[accStrIndex] += alignedStrSize;
2044         }
2045 
2046         if (buildScratchSize != 0)
2047         {
2048             const VkDeviceSize alignedBuilsScratchSize = deAlign64(buildScratchSize, 256);
2049             const uint32_t scratchBuffIndex            = (iScratch / batchScratchCount);
2050             scratchBuffSizes[scratchBuffIndex] += alignedBuilsScratchSize;
2051             iScratch += 1;
2052         }
2053 
2054         if (vertexSize != 0)
2055         {
2056             const VkDeviceSize alignedVertBuffSize = deAlign64(vertexSize, 8);
2057             const uint32_t vertBuffIndex           = (iVertex / batchVertexCount);
2058             vertBuffSizes[vertBuffIndex] += alignedVertBuffSize;
2059             iVertex += 1;
2060         }
2061 
2062         if (indexSize != 0)
2063         {
2064             const VkDeviceSize alignedIndexBuffSize = deAlign64(indexSize, 8);
2065             const uint32_t indexBuffIndex           = (iIndex / batchIndexCount);
2066             indexBuffSizes[indexBuffIndex] += alignedIndexBuffSize;
2067             iIndex += 1;
2068         }
2069     }
2070 
2071     return accStrSizes.size() + vertBuffSizes.size() + indexBuffSizes.size() + scratchBuffSizes.size();
2072 }
2073 
getAllocationSizes(const DeviceInterface & vk,const VkDevice device) const2074 tcu::Vector<VkDeviceSize, 4> BottomLevelAccelerationStructurePool::getAllocationSizes(const DeviceInterface &vk,
2075                                                                                       const VkDevice device) const
2076 {
2077     if (m_structsBuffSize)
2078     {
2079         return tcu::Vector<VkDeviceSize, 4>(m_structsBuffSize, m_buildsScratchSize, m_verticesSize, m_indicesSize);
2080     }
2081 
2082     VkDeviceSize strSize           = 0;
2083     VkDeviceSize updateScratchSize = 0;
2084     static_cast<void>(updateScratchSize); // not used yet, disabled for future implementation
2085     VkDeviceSize buildScratchSize     = 0;
2086     VkDeviceSize vertexSize           = 0;
2087     VkDeviceSize indexSize            = 0;
2088     VkDeviceSize sumStrSize           = 0;
2089     VkDeviceSize sumUpdateScratchSize = 0;
2090     static_cast<void>(sumUpdateScratchSize); // not used yet, disabled for future implementation
2091     VkDeviceSize sumBuildScratchSize = 0;
2092     VkDeviceSize sumVertexSize       = 0;
2093     VkDeviceSize sumIndexSize        = 0;
2094     for (size_t i = 0; i < structCount(); ++i)
2095     {
2096         auto &str = *dynamic_cast<BottomLevelAccelerationStructurePoolMember *>(m_structs[i].get());
2097         std::tie(strSize, updateScratchSize, buildScratchSize, vertexSize, indexSize) =
2098             str.computeBuildSize(vk, device, m_infos[i].structureSize);
2099         sumStrSize += deAlign64(strSize, 256);
2100         //sumUpdateScratchSize    += deAlign64(updateScratchSize, 256);    not used yet, disabled for future implementation
2101         sumBuildScratchSize += deAlign64(buildScratchSize, 256);
2102         sumVertexSize += deAlign64(vertexSize, 8);
2103         sumIndexSize += deAlign64(indexSize, 8);
2104     }
2105     return tcu::Vector<VkDeviceSize, 4>(sumStrSize, sumBuildScratchSize, sumVertexSize, sumIndexSize);
2106 }
2107 
batchCreate(const DeviceInterface & vkd,const VkDevice device,Allocator & allocator)2108 void BottomLevelAccelerationStructurePool::batchCreate(const DeviceInterface &vkd, const VkDevice device,
2109                                                        Allocator &allocator)
2110 {
2111     batchCreateAdjust(vkd, device, allocator, negz<VkDeviceSize>(0));
2112 }
2113 
batchCreateAdjust(const DeviceInterface & vkd,const VkDevice device,Allocator & allocator,const VkDeviceSize maxBufferSize)2114 void BottomLevelAccelerationStructurePool::batchCreateAdjust(const DeviceInterface &vkd, const VkDevice device,
2115                                                              Allocator &allocator, const VkDeviceSize maxBufferSize)
2116 {
2117     // Prevent a programmer from calling this method more than once.
2118     if (m_createOnce)
2119         DE_ASSERT(0);
2120 
2121     m_createOnce = true;
2122     DE_ASSERT(m_structs.size() != 0);
2123 
2124     auto createAccellerationStructureBuffer = [&](VkDeviceSize bufferSize) ->
2125         typename std::add_pointer<BufferWithMemory>::type
2126     {
2127         BufferWithMemory *res = nullptr;
2128         const VkBufferCreateInfo bci =
2129             makeBufferCreateInfo(bufferSize, VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_STORAGE_BIT_KHR |
2130                                                  VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
2131 
2132         if (m_tryCachedMemory)
2133             try
2134             {
2135                 res = new BufferWithMemory(vkd, device, allocator, bci,
2136                                            MemoryRequirement::Cached | MemoryRequirement::HostVisible |
2137                                                MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress);
2138             }
2139             catch (const tcu::NotSupportedError &)
2140             {
2141                 res = nullptr;
2142             }
2143 
2144         return (nullptr != res) ? res :
2145                                   (new BufferWithMemory(vkd, device, allocator, bci,
2146                                                         MemoryRequirement::HostVisible | MemoryRequirement::Coherent |
2147                                                             MemoryRequirement::DeviceAddress));
2148     };
2149 
2150     auto createDeviceScratchBuffer = [&](VkDeviceSize bufferSize) -> de::SharedPtr<BufferWithMemory>
2151     {
2152         const VkBufferCreateInfo bci = makeBufferCreateInfo(bufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT |
2153                                                                             VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
2154         BufferWithMemory *p          = new BufferWithMemory(vkd, device, allocator, bci,
2155                                                             MemoryRequirement::HostVisible | MemoryRequirement::Coherent |
2156                                                                 MemoryRequirement::DeviceAddress);
2157         return de::SharedPtr<BufferWithMemory>(p);
2158     };
2159 
2160     std::map<uint32_t, VkDeviceSize> accStrSizes;
2161     std::map<uint32_t, VkDeviceSize> vertBuffSizes;
2162     std::map<uint32_t, VkDeviceSize> indexBuffSizes;
2163 
2164     const uint32_t allStructsCount = structCount();
2165     uint32_t iterKey               = 0;
2166 
2167     uint32_t batchStructCount = m_batchStructCount;
2168     uint32_t batchVertexCount = m_batchGeomCount ? m_batchGeomCount : m_batchStructCount;
2169     uint32_t batchIndexCount  = batchVertexCount;
2170 
2171     if (!isnegz(maxBufferSize))
2172     {
2173         uint32_t batches[4];
2174         adjustBatchCount(vkd, device, m_structs, m_infos, maxBufferSize, batches);
2175         batchStructCount = batches[0];
2176         // batches[1]: batchScratchCount
2177         batchVertexCount = batches[2];
2178         batchIndexCount  = batches[3];
2179     }
2180 
2181     uint32_t iStr    = 0;
2182     uint32_t iVertex = 0;
2183     uint32_t iIndex  = 0;
2184 
2185     VkDeviceSize strSize             = 0;
2186     VkDeviceSize updateScratchSize   = 0;
2187     VkDeviceSize buildScratchSize    = 0;
2188     VkDeviceSize maxBuildScratchSize = 0;
2189     VkDeviceSize vertexSize          = 0;
2190     VkDeviceSize indexSize           = 0;
2191 
2192     VkDeviceSize strOffset    = 0;
2193     VkDeviceSize vertexOffset = 0;
2194     VkDeviceSize indexOffset  = 0;
2195 
2196     uint32_t hostStructCount   = 0;
2197     uint32_t deviceStructCount = 0;
2198 
2199     for (; iStr < allStructsCount; ++iStr)
2200     {
2201         BottomLevelAccelerationStructurePoolMember::Info info{};
2202         auto &str = *dynamic_cast<BottomLevelAccelerationStructurePoolMember *>(m_structs[iStr].get());
2203         std::tie(strSize, updateScratchSize, buildScratchSize, vertexSize, indexSize) =
2204             str.computeBuildSize(vkd, device, m_infos[iStr].structureSize);
2205 
2206         ++(str.getBuildType() == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_HOST_KHR ? hostStructCount : deviceStructCount);
2207 
2208         {
2209             const VkDeviceSize alignedStrSize = deAlign64(strSize, 256);
2210             const uint32_t accStrIndex        = (iStr / batchStructCount);
2211             if (iStr != 0 && (iStr % batchStructCount) == 0)
2212             {
2213                 strOffset = 0;
2214             }
2215 
2216             info.accStrIndex  = accStrIndex;
2217             info.accStrOffset = strOffset;
2218             accStrSizes[accStrIndex] += alignedStrSize;
2219             strOffset += alignedStrSize;
2220             m_structsBuffSize += alignedStrSize;
2221         }
2222 
2223         if (buildScratchSize != 0)
2224         {
2225             maxBuildScratchSize = std::max(maxBuildScratchSize, make_unsigned(deAlign64(buildScratchSize, 256u)));
2226 
2227             info.buildScratchBuffIndex  = 0;
2228             info.buildScratchBuffOffset = 0;
2229         }
2230 
2231         if (vertexSize != 0)
2232         {
2233             const VkDeviceSize alignedVertBuffSize = deAlign64(vertexSize, 8);
2234             const uint32_t vertBuffIndex           = (iVertex / batchVertexCount);
2235             if (iVertex != 0 && (iVertex % batchVertexCount) == 0)
2236             {
2237                 vertexOffset = 0;
2238             }
2239 
2240             info.vertBuffIndex  = vertBuffIndex;
2241             info.vertBuffOffset = vertexOffset;
2242             vertBuffSizes[vertBuffIndex] += alignedVertBuffSize;
2243             vertexOffset += alignedVertBuffSize;
2244             m_verticesSize += alignedVertBuffSize;
2245             iVertex += 1;
2246         }
2247 
2248         if (indexSize != 0)
2249         {
2250             const VkDeviceSize alignedIndexBuffSize = deAlign64(indexSize, 8);
2251             const uint32_t indexBuffIndex           = (iIndex / batchIndexCount);
2252             if (iIndex != 0 && (iIndex % batchIndexCount) == 0)
2253             {
2254                 indexOffset = 0;
2255             }
2256 
2257             info.indexBuffIndex  = indexBuffIndex;
2258             info.indexBuffOffset = indexOffset;
2259             indexBuffSizes[indexBuffIndex] += alignedIndexBuffSize;
2260             indexOffset += alignedIndexBuffSize;
2261             m_indicesSize += alignedIndexBuffSize;
2262             iIndex += 1;
2263         }
2264 
2265         str.preCreateSetSizesAndOffsets(info, strSize, updateScratchSize, buildScratchSize);
2266     }
2267 
2268     for (iterKey = 0; iterKey < static_cast<uint32_t>(accStrSizes.size()); ++iterKey)
2269     {
2270         m_impl->m_accellerationStructureBuffers.emplace_back(
2271             createAccellerationStructureBuffer(accStrSizes.at(iterKey)));
2272     }
2273     for (iterKey = 0; iterKey < static_cast<uint32_t>(vertBuffSizes.size()); ++iterKey)
2274     {
2275         m_impl->m_vertexBuffers.emplace_back(createVertexBuffer(vkd, device, allocator, vertBuffSizes.at(iterKey)));
2276     }
2277     for (iterKey = 0; iterKey < static_cast<uint32_t>(indexBuffSizes.size()); ++iterKey)
2278     {
2279         m_impl->m_indexBuffers.emplace_back(createIndexBuffer(vkd, device, allocator, indexBuffSizes.at(iterKey)));
2280     }
2281 
2282     if (maxBuildScratchSize)
2283     {
2284         if (hostStructCount)
2285             m_impl->m_hostScratchBuffer->resize(static_cast<size_t>(maxBuildScratchSize));
2286         if (deviceStructCount)
2287             m_impl->m_deviceScratchBuffer = createDeviceScratchBuffer(maxBuildScratchSize);
2288 
2289         m_buildsScratchSize = maxBuildScratchSize;
2290     }
2291 
2292     for (iterKey = 0; iterKey < allStructsCount; ++iterKey)
2293     {
2294         auto &str = *dynamic_cast<BottomLevelAccelerationStructurePoolMember *>(m_structs[iterKey].get());
2295         str.createAccellerationStructure(vkd, device, m_infos[iterKey].deviceAddress);
2296     }
2297 }
2298 
batchBuild(const DeviceInterface & vk,const VkDevice device,VkCommandBuffer cmdBuffer)2299 void BottomLevelAccelerationStructurePool::batchBuild(const DeviceInterface &vk, const VkDevice device,
2300                                                       VkCommandBuffer cmdBuffer)
2301 {
2302     for (const auto &str : m_structs)
2303     {
2304         str->build(vk, device, cmdBuffer);
2305     }
2306 }
2307 
batchBuild(const DeviceInterface & vk,const VkDevice device,VkCommandPool cmdPool,VkQueue queue,qpWatchDog * watchDog)2308 void BottomLevelAccelerationStructurePool::batchBuild(const DeviceInterface &vk, const VkDevice device,
2309                                                       VkCommandPool cmdPool, VkQueue queue, qpWatchDog *watchDog)
2310 {
2311     const uint32_t limit = 10000u;
2312     const uint32_t count = structCount();
2313     std::vector<BlasPtr> buildingOnDevice;
2314 
2315     auto buildOnDevice = [&]() -> void
2316     {
2317         Move<VkCommandBuffer> cmd = allocateCommandBuffer(vk, device, cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);
2318 
2319         beginCommandBuffer(vk, *cmd, 0u);
2320         for (const auto &str : buildingOnDevice)
2321             str->build(vk, device, *cmd);
2322         endCommandBuffer(vk, *cmd);
2323 
2324         submitCommandsAndWait(vk, device, queue, *cmd);
2325         vk.resetCommandPool(device, cmdPool, VK_COMMAND_POOL_RESET_RELEASE_RESOURCES_BIT);
2326     };
2327 
2328     buildingOnDevice.reserve(limit);
2329     for (uint32_t i = 0; i < count; ++i)
2330     {
2331         auto str = m_structs[i];
2332 
2333         if (str->getBuildType() == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_HOST_KHR)
2334             str->build(vk, device, DE_NULL);
2335         else
2336             buildingOnDevice.emplace_back(str);
2337 
2338         if (buildingOnDevice.size() == limit || (count - 1) == i)
2339         {
2340             buildOnDevice();
2341             buildingOnDevice.clear();
2342         }
2343 
2344         if ((i % WATCHDOG_INTERVAL) == 0 && watchDog)
2345             qpWatchDog_touch(watchDog);
2346     }
2347 }
2348 
computeBuildSize(const DeviceInterface & vk,const VkDevice device,const VkDeviceSize strSize) const2349 auto BottomLevelAccelerationStructurePoolMember::computeBuildSize(const DeviceInterface &vk, const VkDevice device,
2350                                                                   const VkDeviceSize strSize) const
2351     //              accStrSize,updateScratch,buildScratch, vertexSize, indexSize
2352     -> std::tuple<VkDeviceSize, VkDeviceSize, VkDeviceSize, VkDeviceSize, VkDeviceSize>
2353 {
2354     DE_ASSERT(!m_geometriesData.empty() != !(strSize == 0)); // logical xor
2355 
2356     std::tuple<VkDeviceSize, VkDeviceSize, VkDeviceSize, VkDeviceSize, VkDeviceSize> result(deAlign64(strSize, 256), 0,
2357                                                                                             0, 0, 0);
2358 
2359     if (!m_geometriesData.empty())
2360     {
2361         std::vector<VkAccelerationStructureGeometryKHR> accelerationStructureGeometriesKHR;
2362         std::vector<VkAccelerationStructureGeometryKHR *> accelerationStructureGeometriesKHRPointers;
2363         std::vector<VkAccelerationStructureBuildRangeInfoKHR> accelerationStructureBuildRangeInfoKHR;
2364         std::vector<VkAccelerationStructureTrianglesOpacityMicromapEXT> accelerationStructureGeometryMicromapsEXT;
2365         std::vector<uint32_t> maxPrimitiveCounts;
2366         prepareGeometries(vk, device, accelerationStructureGeometriesKHR, accelerationStructureGeometriesKHRPointers,
2367                           accelerationStructureBuildRangeInfoKHR, accelerationStructureGeometryMicromapsEXT,
2368                           maxPrimitiveCounts);
2369 
2370         const VkAccelerationStructureGeometryKHR *accelerationStructureGeometriesKHRPointer =
2371             accelerationStructureGeometriesKHR.data();
2372         const VkAccelerationStructureGeometryKHR *const *accelerationStructureGeometry =
2373             accelerationStructureGeometriesKHRPointers.data();
2374 
2375         VkAccelerationStructureBuildGeometryInfoKHR accelerationStructureBuildGeometryInfoKHR = {
2376             VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_GEOMETRY_INFO_KHR, //  VkStructureType sType;
2377             DE_NULL,                                                          //  const void* pNext;
2378             VK_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL_KHR,                  //  VkAccelerationStructureTypeKHR type;
2379             m_buildFlags,                                   //  VkBuildAccelerationStructureFlagsKHR flags;
2380             VK_BUILD_ACCELERATION_STRUCTURE_MODE_BUILD_KHR, //  VkBuildAccelerationStructureModeKHR mode;
2381             DE_NULL,                                        //  VkAccelerationStructureKHR srcAccelerationStructure;
2382             DE_NULL,                                        //  VkAccelerationStructureKHR dstAccelerationStructure;
2383             static_cast<uint32_t>(accelerationStructureGeometriesKHR.size()), //  uint32_t geometryCount;
2384             m_useArrayOfPointers ?
2385                 DE_NULL :
2386                 accelerationStructureGeometriesKHRPointer, //  const VkAccelerationStructureGeometryKHR* pGeometries;
2387             m_useArrayOfPointers ? accelerationStructureGeometry :
2388                                    DE_NULL,     //  const VkAccelerationStructureGeometryKHR* const* ppGeometries;
2389             makeDeviceOrHostAddressKHR(DE_NULL) //  VkDeviceOrHostAddressKHR scratchData;
2390         };
2391 
2392         VkAccelerationStructureBuildSizesInfoKHR sizeInfo = {
2393             VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_SIZES_INFO_KHR, //  VkStructureType sType;
2394             DE_NULL,                                                       //  const void* pNext;
2395             0,                                                             //  VkDeviceSize accelerationStructureSize;
2396             0,                                                             //  VkDeviceSize updateScratchSize;
2397             0                                                              //  VkDeviceSize buildScratchSize;
2398         };
2399 
2400         vk.getAccelerationStructureBuildSizesKHR(device, m_buildType, &accelerationStructureBuildGeometryInfoKHR,
2401                                                  maxPrimitiveCounts.data(), &sizeInfo);
2402 
2403         std::get<0>(result) = sizeInfo.accelerationStructureSize;
2404         std::get<1>(result) = sizeInfo.updateScratchSize;
2405         std::get<2>(result) = sizeInfo.buildScratchSize;
2406         std::get<3>(result) = getVertexBufferSize(m_geometriesData);
2407         std::get<4>(result) = getIndexBufferSize(m_geometriesData);
2408     }
2409 
2410     return result;
2411 }
2412 
preCreateSetSizesAndOffsets(const Info & info,const VkDeviceSize accStrSize,const VkDeviceSize updateScratchSize,const VkDeviceSize buildScratchSize)2413 void BottomLevelAccelerationStructurePoolMember::preCreateSetSizesAndOffsets(const Info &info,
2414                                                                              const VkDeviceSize accStrSize,
2415                                                                              const VkDeviceSize updateScratchSize,
2416                                                                              const VkDeviceSize buildScratchSize)
2417 {
2418     m_info              = info;
2419     m_structureSize     = accStrSize;
2420     m_updateScratchSize = updateScratchSize;
2421     m_buildScratchSize  = buildScratchSize;
2422 }
2423 
createAccellerationStructure(const DeviceInterface & vk,const VkDevice device,VkDeviceAddress deviceAddress)2424 void BottomLevelAccelerationStructurePoolMember::createAccellerationStructure(const DeviceInterface &vk,
2425                                                                               const VkDevice device,
2426                                                                               VkDeviceAddress deviceAddress)
2427 {
2428     const VkAccelerationStructureTypeKHR structureType =
2429         (m_createGeneric ? VK_ACCELERATION_STRUCTURE_TYPE_GENERIC_KHR :
2430                            VK_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL_KHR);
2431     const VkAccelerationStructureCreateInfoKHR accelerationStructureCreateInfoKHR{
2432         VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_CREATE_INFO_KHR, //  VkStructureType sType;
2433         DE_NULL,                                                  //  const void* pNext;
2434         m_createFlags,                                            //  VkAccelerationStructureCreateFlagsKHR createFlags;
2435         getAccelerationStructureBuffer()->get(),                  //  VkBuffer buffer;
2436         getAccelerationStructureBufferOffset(),                   //  VkDeviceSize offset;
2437         m_structureSize,                                          //  VkDeviceSize size;
2438         structureType,                                            //  VkAccelerationStructureTypeKHR type;
2439         deviceAddress                                             //  VkDeviceAddress deviceAddress;
2440     };
2441 
2442     m_accelerationStructureKHR =
2443         createAccelerationStructureKHR(vk, device, &accelerationStructureCreateInfoKHR, DE_NULL);
2444 }
2445 
// Out-of-line virtual destructor for the abstract base; no explicit cleanup needed.
TopLevelAccelerationStructure::~TopLevelAccelerationStructure()
{
}
2449 
// Sizes start at zero; they are filled in when the structure is created/built
// (or copied from another structure's build sizes).
TopLevelAccelerationStructure::TopLevelAccelerationStructure()
    : m_structureSize(0u)
    , m_updateScratchSize(0u)
    , m_buildScratchSize(0u)
{
}
2456 
setInstanceCount(const size_t instanceCount)2457 void TopLevelAccelerationStructure::setInstanceCount(const size_t instanceCount)
2458 {
2459     m_bottomLevelInstances.reserve(instanceCount);
2460     m_instanceData.reserve(instanceCount);
2461 }
2462 
addInstance(de::SharedPtr<BottomLevelAccelerationStructure> bottomLevelStructure,const VkTransformMatrixKHR & matrix,uint32_t instanceCustomIndex,uint32_t mask,uint32_t instanceShaderBindingTableRecordOffset,VkGeometryInstanceFlagsKHR flags)2463 void TopLevelAccelerationStructure::addInstance(de::SharedPtr<BottomLevelAccelerationStructure> bottomLevelStructure,
2464                                                 const VkTransformMatrixKHR &matrix, uint32_t instanceCustomIndex,
2465                                                 uint32_t mask, uint32_t instanceShaderBindingTableRecordOffset,
2466                                                 VkGeometryInstanceFlagsKHR flags)
2467 {
2468     m_bottomLevelInstances.push_back(bottomLevelStructure);
2469     m_instanceData.push_back(
2470         InstanceData(matrix, instanceCustomIndex, mask, instanceShaderBindingTableRecordOffset, flags));
2471 }
2472 
getStructureBuildSizes() const2473 VkAccelerationStructureBuildSizesInfoKHR TopLevelAccelerationStructure::getStructureBuildSizes() const
2474 {
2475     return {
2476         VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_SIZES_INFO_KHR, //  VkStructureType sType;
2477         DE_NULL,                                                       //  const void* pNext;
2478         m_structureSize,                                               //  VkDeviceSize accelerationStructureSize;
2479         m_updateScratchSize,                                           //  VkDeviceSize updateScratchSize;
2480         m_buildScratchSize                                             //  VkDeviceSize buildScratchSize;
2481     };
2482 }
2483 
createAndBuild(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,Allocator & allocator,VkDeviceAddress deviceAddress)2484 void TopLevelAccelerationStructure::createAndBuild(const DeviceInterface &vk, const VkDevice device,
2485                                                    const VkCommandBuffer cmdBuffer, Allocator &allocator,
2486                                                    VkDeviceAddress deviceAddress)
2487 {
2488     create(vk, device, allocator, 0u, deviceAddress);
2489     build(vk, device, cmdBuffer);
2490 }
2491 
createAndCopyFrom(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,Allocator & allocator,TopLevelAccelerationStructure * accelerationStructure,VkDeviceSize compactCopySize,VkDeviceAddress deviceAddress)2492 void TopLevelAccelerationStructure::createAndCopyFrom(const DeviceInterface &vk, const VkDevice device,
2493                                                       const VkCommandBuffer cmdBuffer, Allocator &allocator,
2494                                                       TopLevelAccelerationStructure *accelerationStructure,
2495                                                       VkDeviceSize compactCopySize, VkDeviceAddress deviceAddress)
2496 {
2497     DE_ASSERT(accelerationStructure != NULL);
2498     VkDeviceSize copiedSize = compactCopySize > 0u ?
2499                                   compactCopySize :
2500                                   accelerationStructure->getStructureBuildSizes().accelerationStructureSize;
2501     DE_ASSERT(copiedSize != 0u);
2502 
2503     create(vk, device, allocator, copiedSize, deviceAddress);
2504     copyFrom(vk, device, cmdBuffer, accelerationStructure, compactCopySize > 0u);
2505 }
2506 
createAndDeserializeFrom(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,Allocator & allocator,SerialStorage * storage,VkDeviceAddress deviceAddress)2507 void TopLevelAccelerationStructure::createAndDeserializeFrom(const DeviceInterface &vk, const VkDevice device,
2508                                                              const VkCommandBuffer cmdBuffer, Allocator &allocator,
2509                                                              SerialStorage *storage, VkDeviceAddress deviceAddress)
2510 {
2511     DE_ASSERT(storage != NULL);
2512     DE_ASSERT(storage->getStorageSize() >= SerialStorage::SERIAL_STORAGE_SIZE_MIN);
2513     create(vk, device, allocator, storage->getDeserializedSize(), deviceAddress);
2514     if (storage->hasDeepFormat())
2515         createAndDeserializeBottoms(vk, device, cmdBuffer, allocator, storage);
2516     deserialize(vk, device, cmdBuffer, storage);
2517 }
2518 
createInstanceBuffer(const DeviceInterface & vk,const VkDevice device,Allocator & allocator,std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> bottomLevelInstances,std::vector<InstanceData> instanceData,const bool tryCachedMemory)2519 BufferWithMemory *createInstanceBuffer(
2520     const DeviceInterface &vk, const VkDevice device, Allocator &allocator,
2521     std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> bottomLevelInstances,
2522     std::vector<InstanceData> instanceData, const bool tryCachedMemory)
2523 {
2524     DE_ASSERT(bottomLevelInstances.size() != 0);
2525     DE_ASSERT(bottomLevelInstances.size() == instanceData.size());
2526     DE_UNREF(instanceData);
2527 
2528     BufferWithMemory *result           = nullptr;
2529     const VkDeviceSize bufferSizeBytes = bottomLevelInstances.size() * sizeof(VkAccelerationStructureInstanceKHR);
2530     const VkBufferCreateInfo bufferCreateInfo =
2531         makeBufferCreateInfo(bufferSizeBytes, VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_BIT_KHR |
2532                                                   VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
2533     if (tryCachedMemory)
2534         try
2535         {
2536             result = new BufferWithMemory(vk, device, allocator, bufferCreateInfo,
2537                                           MemoryRequirement::Cached | MemoryRequirement::HostVisible |
2538                                               MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress);
2539         }
2540         catch (const tcu::NotSupportedError &)
2541         {
2542             result = nullptr;
2543         }
2544     return result ? result :
2545                     new BufferWithMemory(vk, device, allocator, bufferCreateInfo,
2546                                          MemoryRequirement::HostVisible | MemoryRequirement::Coherent |
2547                                              MemoryRequirement::DeviceAddress);
2548 }
2549 
updateSingleInstance(const DeviceInterface & vk,const VkDevice device,const BottomLevelAccelerationStructure & bottomLevelAccelerationStructure,const InstanceData & instanceData,uint8_t * bufferLocation,VkAccelerationStructureBuildTypeKHR buildType,bool inactiveInstances)2550 void updateSingleInstance(const DeviceInterface &vk, const VkDevice device,
2551                           const BottomLevelAccelerationStructure &bottomLevelAccelerationStructure,
2552                           const InstanceData &instanceData, uint8_t *bufferLocation,
2553                           VkAccelerationStructureBuildTypeKHR buildType, bool inactiveInstances)
2554 {
2555     const VkAccelerationStructureKHR accelerationStructureKHR = *bottomLevelAccelerationStructure.getPtr();
2556 
2557     // This part needs to be fixed once a new version of the VkAccelerationStructureInstanceKHR will be added to vkStructTypes.inl
2558     VkDeviceAddress accelerationStructureAddress;
2559     if (buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
2560     {
2561         VkAccelerationStructureDeviceAddressInfoKHR asDeviceAddressInfo = {
2562             VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_DEVICE_ADDRESS_INFO_KHR, // VkStructureType sType;
2563             DE_NULL,                                                          // const void* pNext;
2564             accelerationStructureKHR // VkAccelerationStructureKHR accelerationStructure;
2565         };
2566         accelerationStructureAddress = vk.getAccelerationStructureDeviceAddressKHR(device, &asDeviceAddressInfo);
2567     }
2568 
2569     uint64_t structureReference;
2570     if (inactiveInstances)
2571     {
2572         // Instances will be marked inactive by making their references VK_NULL_HANDLE or having address zero.
2573         structureReference = 0ull;
2574     }
2575     else
2576     {
2577         structureReference = (buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR) ?
2578                                  uint64_t(accelerationStructureAddress) :
2579                                  uint64_t(accelerationStructureKHR.getInternal());
2580     }
2581 
2582     VkAccelerationStructureInstanceKHR accelerationStructureInstanceKHR = makeVkAccelerationStructureInstanceKHR(
2583         instanceData.matrix,                                 //  VkTransformMatrixKHR transform;
2584         instanceData.instanceCustomIndex,                    //  uint32_t instanceCustomIndex:24;
2585         instanceData.mask,                                   //  uint32_t mask:8;
2586         instanceData.instanceShaderBindingTableRecordOffset, //  uint32_t instanceShaderBindingTableRecordOffset:24;
2587         instanceData.flags,                                  //  VkGeometryInstanceFlagsKHR flags:8;
2588         structureReference                                   //  uint64_t accelerationStructureReference;
2589     );
2590 
2591     deMemcpy(bufferLocation, &accelerationStructureInstanceKHR, sizeof(VkAccelerationStructureInstanceKHR));
2592 }
2593 
updateInstanceBuffer(const DeviceInterface & vk,const VkDevice device,const std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> & bottomLevelInstances,const std::vector<InstanceData> & instanceData,const BufferWithMemory * instanceBuffer,VkAccelerationStructureBuildTypeKHR buildType,bool inactiveInstances)2594 void updateInstanceBuffer(const DeviceInterface &vk, const VkDevice device,
2595                           const std::vector<de::SharedPtr<BottomLevelAccelerationStructure>> &bottomLevelInstances,
2596                           const std::vector<InstanceData> &instanceData, const BufferWithMemory *instanceBuffer,
2597                           VkAccelerationStructureBuildTypeKHR buildType, bool inactiveInstances)
2598 {
2599     DE_ASSERT(bottomLevelInstances.size() != 0);
2600     DE_ASSERT(bottomLevelInstances.size() == instanceData.size());
2601 
2602     auto &instancesAlloc      = instanceBuffer->getAllocation();
2603     auto bufferStart          = reinterpret_cast<uint8_t *>(instancesAlloc.getHostPtr());
2604     VkDeviceSize bufferOffset = 0ull;
2605 
2606     for (size_t instanceNdx = 0; instanceNdx < bottomLevelInstances.size(); ++instanceNdx)
2607     {
2608         const auto &blas = *bottomLevelInstances[instanceNdx];
2609         updateSingleInstance(vk, device, blas, instanceData[instanceNdx], bufferStart + bufferOffset, buildType,
2610                              inactiveInstances);
2611         bufferOffset += sizeof(VkAccelerationStructureInstanceKHR);
2612     }
2613 
2614     flushMappedMemoryRange(vk, device, instancesAlloc.getMemory(), instancesAlloc.getOffset(), VK_WHOLE_SIZE);
2615 }
2616 
// Concrete VK_KHR_acceleration_structure implementation of the abstract
// TopLevelAccelerationStructure interface. Supports device and host builds,
// deferred operations, compaction copies and (de)serialization.
class TopLevelAccelerationStructureKHR : public TopLevelAccelerationStructure
{
public:
    // Worst-case number of distinct device-memory allocations a TLAS may make.
    static uint32_t getRequiredAllocationCount(void);

    TopLevelAccelerationStructureKHR();
    TopLevelAccelerationStructureKHR(const TopLevelAccelerationStructureKHR &other) = delete;
    virtual ~TopLevelAccelerationStructureKHR();

    // Configuration setters; call before create()/build() for them to take effect.
    void setBuildType(const VkAccelerationStructureBuildTypeKHR buildType) override;
    void setCreateFlags(const VkAccelerationStructureCreateFlagsKHR createFlags) override;
    void setCreateGeneric(bool createGeneric) override;
    void setCreationBufferUnbounded(bool creationBufferUnbounded) override;
    void setBuildFlags(const VkBuildAccelerationStructureFlagsKHR buildFlags) override;
    void setBuildWithoutPrimitives(bool buildWithoutPrimitives) override;
    void setInactiveInstances(bool inactiveInstances) override;
    void setDeferredOperation(const bool deferredOperation, const uint32_t workerThreadCount) override;
    void setUseArrayOfPointers(const bool useArrayOfPointers) override;
    void setIndirectBuildParameters(const VkBuffer indirectBuffer, const VkDeviceSize indirectBufferOffset,
                                    const uint32_t indirectBufferStride) override;
    void setUsePPGeometries(const bool usePPGeometries) override;
    void setTryCachedMemory(const bool tryCachedMemory) override;
    VkBuildAccelerationStructureFlagsKHR getBuildFlags() const override;

    // Creation, building, copying and (de)serialization entry points.
    void getCreationSizes(const DeviceInterface &vk, const VkDevice device, const VkDeviceSize structureSize,
                          CreationSizes &sizes) override;
    void create(const DeviceInterface &vk, const VkDevice device, Allocator &allocator, VkDeviceSize structureSize,
                VkDeviceAddress deviceAddress = 0u, const void *pNext = DE_NULL,
                const MemoryRequirement &addMemoryRequirement = MemoryRequirement::Any,
                const VkBuffer creationBuffer = VK_NULL_HANDLE, const VkDeviceSize creationBufferSize = 0u) override;
    void build(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
               TopLevelAccelerationStructure *srcAccelerationStructure = DE_NULL) override;
    void copyFrom(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
                  TopLevelAccelerationStructure *accelerationStructure, bool compactCopy) override;
    void serialize(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
                   SerialStorage *storage) override;
    void deserialize(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
                     SerialStorage *storage) override;

    std::vector<VkDeviceSize> getSerializingSizes(const DeviceInterface &vk, const VkDevice device, const VkQueue queue,
                                                  const uint32_t queueFamilyIndex) override;

    std::vector<uint64_t> getSerializingAddresses(const DeviceInterface &vk, const VkDevice device) const override;

    const VkAccelerationStructureKHR *getPtr(void) const override;

    void updateInstanceMatrix(const DeviceInterface &vk, const VkDevice device, size_t instanceIndex,
                              const VkTransformMatrixKHR &matrix) override;

protected:
    // Build configuration (see the setters above).
    VkAccelerationStructureBuildTypeKHR m_buildType;
    VkAccelerationStructureCreateFlagsKHR m_createFlags;
    bool m_createGeneric;
    bool m_creationBufferUnbounded;
    VkBuildAccelerationStructureFlagsKHR m_buildFlags;
    bool m_buildWithoutPrimitives;
    bool m_inactiveInstances;
    bool m_deferredOperation;
    uint32_t m_workerThreadCount;
    bool m_useArrayOfPointers;
    // Device resources backing the structure and its build inputs.
    de::MovePtr<BufferWithMemory> m_accelerationStructureBuffer;
    de::MovePtr<BufferWithMemory> m_instanceBuffer;
    de::MovePtr<BufferWithMemory> m_instanceAddressBuffer;
    de::MovePtr<BufferWithMemory> m_deviceScratchBuffer;
    // Scratch area used for host builds instead of a device buffer.
    std::vector<uint8_t> m_hostScratchBuffer;
    Move<VkAccelerationStructureKHR> m_accelerationStructureKHR;
    // Optional indirect-build parameters (see setIndirectBuildParameters).
    VkBuffer m_indirectBuffer;
    VkDeviceSize m_indirectBufferOffset;
    uint32_t m_indirectBufferStride;
    bool m_usePPGeometries;
    bool m_tryCachedMemory;

    // Fills in the instances geometry description and the per-geometry
    // maximum primitive counts used for size queries and builds.
    void prepareInstances(const DeviceInterface &vk, const VkDevice device,
                          VkAccelerationStructureGeometryKHR &accelerationStructureGeometryKHR,
                          std::vector<uint32_t> &maxPrimitiveCounts);

    // Serializes the referenced bottom-level structures into deep-format storage.
    void serializeBottoms(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
                          SerialStorage *storage, VkDeferredOperationKHR deferredOperation);

    // Recreates bottom-level structures from deep-format storage prior to
    // deserializing the top-level structure itself.
    void createAndDeserializeBottoms(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
                                     Allocator &allocator, SerialStorage *storage) override;
};
2699 
getRequiredAllocationCount(void)2700 uint32_t TopLevelAccelerationStructureKHR::getRequiredAllocationCount(void)
2701 {
2702     /*
2703         de::MovePtr<BufferWithMemory>                            m_instanceBuffer;
2704         de::MovePtr<Allocation>                                    m_accelerationStructureAlloc;
2705         de::MovePtr<BufferWithMemory>                            m_deviceScratchBuffer;
2706     */
2707     return 3u;
2708 }
2709 
// Defaults: device-side build, no special create/build flags, no deferred
// operation, cached instance memory attempted first (m_tryCachedMemory).
TopLevelAccelerationStructureKHR::TopLevelAccelerationStructureKHR()
    : TopLevelAccelerationStructure()
    , m_buildType(VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
    , m_createFlags(0u)
    , m_createGeneric(false)
    , m_creationBufferUnbounded(false)
    , m_buildFlags(0u)
    , m_buildWithoutPrimitives(false)
    , m_inactiveInstances(false)
    , m_deferredOperation(false)
    , m_workerThreadCount(0)
    , m_useArrayOfPointers(false)
    , m_accelerationStructureBuffer(DE_NULL)
    , m_instanceBuffer(DE_NULL)
    , m_instanceAddressBuffer(DE_NULL)
    , m_deviceScratchBuffer(DE_NULL)
    , m_accelerationStructureKHR()
    , m_indirectBuffer(DE_NULL)
    , m_indirectBufferOffset(0)
    , m_indirectBufferStride(0)
    , m_usePPGeometries(false)
    , m_tryCachedMemory(true)
{
}
2734 
// All owned resources are released by member destructors (RAII).
TopLevelAccelerationStructureKHR::~TopLevelAccelerationStructureKHR()
{
}
2738 
setBuildType(const VkAccelerationStructureBuildTypeKHR buildType)2739 void TopLevelAccelerationStructureKHR::setBuildType(const VkAccelerationStructureBuildTypeKHR buildType)
2740 {
2741     m_buildType = buildType;
2742 }
2743 
setCreateFlags(const VkAccelerationStructureCreateFlagsKHR createFlags)2744 void TopLevelAccelerationStructureKHR::setCreateFlags(const VkAccelerationStructureCreateFlagsKHR createFlags)
2745 {
2746     m_createFlags = createFlags;
2747 }
2748 
setCreateGeneric(bool createGeneric)2749 void TopLevelAccelerationStructureKHR::setCreateGeneric(bool createGeneric)
2750 {
2751     m_createGeneric = createGeneric;
2752 }
2753 
setCreationBufferUnbounded(bool creationBufferUnbounded)2754 void TopLevelAccelerationStructureKHR::setCreationBufferUnbounded(bool creationBufferUnbounded)
2755 {
2756     m_creationBufferUnbounded = creationBufferUnbounded;
2757 }
2758 
setInactiveInstances(bool inactiveInstances)2759 void TopLevelAccelerationStructureKHR::setInactiveInstances(bool inactiveInstances)
2760 {
2761     m_inactiveInstances = inactiveInstances;
2762 }
2763 
setBuildFlags(const VkBuildAccelerationStructureFlagsKHR buildFlags)2764 void TopLevelAccelerationStructureKHR::setBuildFlags(const VkBuildAccelerationStructureFlagsKHR buildFlags)
2765 {
2766     m_buildFlags = buildFlags;
2767 }
2768 
setBuildWithoutPrimitives(bool buildWithoutPrimitives)2769 void TopLevelAccelerationStructureKHR::setBuildWithoutPrimitives(bool buildWithoutPrimitives)
2770 {
2771     m_buildWithoutPrimitives = buildWithoutPrimitives;
2772 }
2773 
setDeferredOperation(const bool deferredOperation,const uint32_t workerThreadCount)2774 void TopLevelAccelerationStructureKHR::setDeferredOperation(const bool deferredOperation,
2775                                                             const uint32_t workerThreadCount)
2776 {
2777     m_deferredOperation = deferredOperation;
2778     m_workerThreadCount = workerThreadCount;
2779 }
2780 
setUseArrayOfPointers(const bool useArrayOfPointers)2781 void TopLevelAccelerationStructureKHR::setUseArrayOfPointers(const bool useArrayOfPointers)
2782 {
2783     m_useArrayOfPointers = useArrayOfPointers;
2784 }
2785 
setUsePPGeometries(const bool usePPGeometries)2786 void TopLevelAccelerationStructureKHR::setUsePPGeometries(const bool usePPGeometries)
2787 {
2788     m_usePPGeometries = usePPGeometries;
2789 }
2790 
setTryCachedMemory(const bool tryCachedMemory)2791 void TopLevelAccelerationStructureKHR::setTryCachedMemory(const bool tryCachedMemory)
2792 {
2793     m_tryCachedMemory = tryCachedMemory;
2794 }
2795 
setIndirectBuildParameters(const VkBuffer indirectBuffer,const VkDeviceSize indirectBufferOffset,const uint32_t indirectBufferStride)2796 void TopLevelAccelerationStructureKHR::setIndirectBuildParameters(const VkBuffer indirectBuffer,
2797                                                                   const VkDeviceSize indirectBufferOffset,
2798                                                                   const uint32_t indirectBufferStride)
2799 {
2800     m_indirectBuffer       = indirectBuffer;
2801     m_indirectBufferOffset = indirectBufferOffset;
2802     m_indirectBufferStride = indirectBufferStride;
2803 }
2804 
getBuildFlags() const2805 VkBuildAccelerationStructureFlagsKHR TopLevelAccelerationStructureKHR::getBuildFlags() const
2806 {
2807     return m_buildFlags;
2808 }
2809 
sum() const2810 VkDeviceSize TopLevelAccelerationStructure::CreationSizes::sum() const
2811 {
2812     return structure + updateScratch + buildScratch + instancePointers + instancesBuffer;
2813 }
2814 
// Computes every size needed to create/build this TLAS without allocating anything.
// If structureSize == 0 the sizes are queried from the driver based on the
// registered instances; otherwise structureSize is taken as-is (copy/compact/
// deserialize path) and the scratch sizes are reported as zero.
void TopLevelAccelerationStructureKHR::getCreationSizes(const DeviceInterface &vk, const VkDevice device,
                                                        const VkDeviceSize structureSize, CreationSizes &sizes)
{
    // AS may be built from geometries using vkCmdBuildAccelerationStructureKHR / vkBuildAccelerationStructureKHR
    // or may be copied/compacted/deserialized from other AS ( in this case AS does not need geometries, but it needs to know its size before creation ).
    DE_ASSERT(!m_bottomLevelInstances.empty() != !(structureSize == 0)); // logical xor

    if (structureSize == 0)
    {
        VkAccelerationStructureGeometryKHR accelerationStructureGeometryKHR;
        const auto accelerationStructureGeometryKHRPtr = &accelerationStructureGeometryKHR;
        std::vector<uint32_t> maxPrimitiveCounts;
        // Describe the instances geometry that the real build would use.
        prepareInstances(vk, device, accelerationStructureGeometryKHR, maxPrimitiveCounts);

        VkAccelerationStructureBuildGeometryInfoKHR accelerationStructureBuildGeometryInfoKHR = {
            VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_GEOMETRY_INFO_KHR, //  VkStructureType sType;
            DE_NULL,                                                          //  const void* pNext;
            VK_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL_KHR,                     //  VkAccelerationStructureTypeKHR type;
            m_buildFlags,                                   //  VkBuildAccelerationStructureFlagsKHR flags;
            VK_BUILD_ACCELERATION_STRUCTURE_MODE_BUILD_KHR, //  VkBuildAccelerationStructureModeKHR mode;
            DE_NULL,                                        //  VkAccelerationStructureKHR srcAccelerationStructure;
            DE_NULL,                                        //  VkAccelerationStructureKHR dstAccelerationStructure;
            1u,                                             //  uint32_t geometryCount;
            (m_usePPGeometries ?
                 nullptr :
                 &accelerationStructureGeometryKHR), //  const VkAccelerationStructureGeometryKHR* pGeometries;
            (m_usePPGeometries ? &accelerationStructureGeometryKHRPtr :
                                 nullptr),      //  const VkAccelerationStructureGeometryKHR* const* ppGeometries;
            makeDeviceOrHostAddressKHR(DE_NULL) //  VkDeviceOrHostAddressKHR scratchData;
        };

        VkAccelerationStructureBuildSizesInfoKHR sizeInfo = {
            VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_SIZES_INFO_KHR, //  VkStructureType sType;
            DE_NULL,                                                       //  const void* pNext;
            0,                                                             //  VkDeviceSize accelerationStructureSize;
            0,                                                             //  VkDeviceSize updateScratchSize;
            0                                                              //  VkDeviceSize buildScratchSize;
        };

        // Ask the driver for the sizes the real build would require.
        vk.getAccelerationStructureBuildSizesKHR(device, m_buildType, &accelerationStructureBuildGeometryInfoKHR,
                                                 maxPrimitiveCounts.data(), &sizeInfo);

        sizes.structure     = sizeInfo.accelerationStructureSize;
        sizes.updateScratch = sizeInfo.updateScratchSize;
        sizes.buildScratch  = sizeInfo.buildScratchSize;
    }
    else
    {
        // Copy/compact/deserialize path: the caller already knows the exact size.
        sizes.structure     = structureSize;
        sizes.updateScratch = 0u;
        sizes.buildScratch  = 0u;
    }

    // Optional pointer array referencing the instances; element width depends
    // on whether addresses are device or host pointers.
    sizes.instancePointers = 0u;
    if (m_useArrayOfPointers)
    {
        const size_t pointerSize = (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR) ?
                                       sizeof(VkDeviceOrHostAddressConstKHR::deviceAddress) :
                                       sizeof(VkDeviceOrHostAddressConstKHR::hostAddress);
        sizes.instancePointers   = static_cast<VkDeviceSize>(m_bottomLevelInstances.size() * pointerSize);
    }

    // One packed instance record per bottom-level structure.
    sizes.instancesBuffer = m_bottomLevelInstances.empty() ?
                                0u :
                                m_bottomLevelInstances.size() * sizeof(VkAccelerationStructureInstanceKHR);
}
2881 
// Creates the top-level acceleration structure object and every buffer it needs.
//
// Either the size is derived from the registered bottom-level instances (structureSize == 0,
// normal build path) or an explicit structureSize is given (copy/compact/deserialize path,
// in which case no instances are required). An external creationBuffer may be supplied to
// back the AS; otherwise a suitable buffer is allocated here.
void TopLevelAccelerationStructureKHR::create(const DeviceInterface &vk, const VkDevice device, Allocator &allocator,
                                              VkDeviceSize structureSize, VkDeviceAddress deviceAddress,
                                              const void *pNext, const MemoryRequirement &addMemoryRequirement,
                                              const VkBuffer creationBuffer, const VkDeviceSize creationBufferSize)
{
    // AS may be built from geometries using vkCmdBuildAccelerationStructureKHR / vkBuildAccelerationStructureKHR
    // or may be copied/compacted/deserialized from other AS ( in this case AS does not need geometries, but it needs to know its size before creation ).
    DE_ASSERT(!m_bottomLevelInstances.empty() != !(structureSize == 0)); // logical xor

    if (structureSize == 0)
    {
        // No explicit size: query the required sizes for a build of the current instances.
        VkAccelerationStructureGeometryKHR accelerationStructureGeometryKHR;
        const auto accelerationStructureGeometryKHRPtr = &accelerationStructureGeometryKHR;
        std::vector<uint32_t> maxPrimitiveCounts;
        prepareInstances(vk, device, accelerationStructureGeometryKHR, maxPrimitiveCounts);

        VkAccelerationStructureBuildGeometryInfoKHR accelerationStructureBuildGeometryInfoKHR = {
            VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_GEOMETRY_INFO_KHR, //  VkStructureType sType;
            DE_NULL,                                                          //  const void* pNext;
            VK_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL_KHR,                     //  VkAccelerationStructureTypeKHR type;
            m_buildFlags,                                   //  VkBuildAccelerationStructureFlagsKHR flags;
            VK_BUILD_ACCELERATION_STRUCTURE_MODE_BUILD_KHR, //  VkBuildAccelerationStructureModeKHR mode;
            DE_NULL,                                        //  VkAccelerationStructureKHR srcAccelerationStructure;
            DE_NULL,                                        //  VkAccelerationStructureKHR dstAccelerationStructure;
            1u,                                             //  uint32_t geometryCount;
            (m_usePPGeometries ?
                 nullptr :
                 &accelerationStructureGeometryKHR), //  const VkAccelerationStructureGeometryKHR* pGeometries;
            (m_usePPGeometries ? &accelerationStructureGeometryKHRPtr :
                                 nullptr),      //  const VkAccelerationStructureGeometryKHR* const* ppGeometries;
            makeDeviceOrHostAddressKHR(DE_NULL) //  VkDeviceOrHostAddressKHR scratchData;
        };

        VkAccelerationStructureBuildSizesInfoKHR sizeInfo = {
            VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_SIZES_INFO_KHR, //  VkStructureType sType;
            DE_NULL,                                                       //  const void* pNext;
            0,                                                             //  VkDeviceSize accelerationStructureSize;
            0,                                                             //  VkDeviceSize updateScratchSize;
            0                                                              //  VkDeviceSize buildScratchSize;
        };

        vk.getAccelerationStructureBuildSizesKHR(device, m_buildType, &accelerationStructureBuildGeometryInfoKHR,
                                                 maxPrimitiveCounts.data(), &sizeInfo);

        m_structureSize     = sizeInfo.accelerationStructureSize;
        m_updateScratchSize = sizeInfo.updateScratchSize;
        m_buildScratchSize  = sizeInfo.buildScratchSize;
    }
    else
    {
        // Explicit size (copy/compact/deserialize target): no build, hence no scratch space.
        m_structureSize     = structureSize;
        m_updateScratchSize = 0u;
        m_buildScratchSize  = 0u;
    }

    const bool externalCreationBuffer = (creationBuffer != VK_NULL_HANDLE);

    if (externalCreationBuffer)
    {
        DE_UNREF(creationBufferSize); // For release builds.
        DE_ASSERT(creationBufferSize >= m_structureSize);
    }

    if (!externalCreationBuffer)
    {
        // Allocate our own backing buffer for the acceleration structure.
        const VkBufferCreateInfo bufferCreateInfo =
            makeBufferCreateInfo(m_structureSize, VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_STORAGE_BIT_KHR |
                                                      VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
        const MemoryRequirement memoryRequirement = addMemoryRequirement | MemoryRequirement::HostVisible |
                                                    MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress;
        const bool bindMemOnCreation = (!m_creationBufferUnbounded);

        try
        {
            // Prefer cached memory first; not all implementations support it.
            m_accelerationStructureBuffer = de::MovePtr<BufferWithMemory>(
                new BufferWithMemory(vk, device, allocator, bufferCreateInfo,
                                     (MemoryRequirement::Cached | memoryRequirement), bindMemOnCreation));
        }
        catch (const tcu::NotSupportedError &)
        {
            // retry without Cached flag
            m_accelerationStructureBuffer = de::MovePtr<BufferWithMemory>(
                new BufferWithMemory(vk, device, allocator, bufferCreateInfo, memoryRequirement, bindMemOnCreation));
        }
    }

    const auto createInfoBuffer = (externalCreationBuffer ? creationBuffer : m_accelerationStructureBuffer->get());
    {
        const VkAccelerationStructureTypeKHR structureType =
            (m_createGeneric ? VK_ACCELERATION_STRUCTURE_TYPE_GENERIC_KHR :
                               VK_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL_KHR);
        const VkAccelerationStructureCreateInfoKHR accelerationStructureCreateInfoKHR = {
            VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_CREATE_INFO_KHR, //  VkStructureType sType;
            pNext,                                                    //  const void* pNext;
            m_createFlags,    //  VkAccelerationStructureCreateFlagsKHR createFlags;
            createInfoBuffer, //  VkBuffer buffer;
            0u,               //  VkDeviceSize offset;
            m_structureSize,  //  VkDeviceSize size;
            structureType,    //  VkAccelerationStructureTypeKHR type;
            deviceAddress     //  VkDeviceAddress deviceAddress;
        };

        m_accelerationStructureKHR =
            createAccelerationStructureKHR(vk, device, &accelerationStructureCreateInfoKHR, DE_NULL);

        // Make sure buffer memory is always bound after creation.
        if (!externalCreationBuffer)
            m_accelerationStructureBuffer->bindMemory();
    }

    if (m_buildScratchSize > 0u || m_updateScratchSize > 0u)
    {
        // One scratch area sized to cover both the initial build and later updates.
        VkDeviceSize scratch_size = de::max(m_buildScratchSize, m_updateScratchSize);
        if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
        {
            const VkBufferCreateInfo bufferCreateInfo = makeBufferCreateInfo(
                scratch_size, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
            m_deviceScratchBuffer = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
                vk, device, allocator, bufferCreateInfo,
                MemoryRequirement::HostVisible | MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress));
        }
        else
        {
            // Host builds only need plain host memory for scratch.
            m_hostScratchBuffer.resize(static_cast<size_t>(scratch_size));
        }
    }

    if (m_useArrayOfPointers)
    {
        // Buffer holding one device address (device build) or host pointer (host build) per instance.
        const size_t pointerSize = (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR) ?
                                       sizeof(VkDeviceOrHostAddressConstKHR::deviceAddress) :
                                       sizeof(VkDeviceOrHostAddressConstKHR::hostAddress);
        const VkBufferCreateInfo bufferCreateInfo =
            makeBufferCreateInfo(static_cast<VkDeviceSize>(m_bottomLevelInstances.size() * pointerSize),
                                 VK_BUFFER_USAGE_ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_BIT_KHR |
                                     VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT);
        m_instanceAddressBuffer = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
            vk, device, allocator, bufferCreateInfo,
            MemoryRequirement::HostVisible | MemoryRequirement::Coherent | MemoryRequirement::DeviceAddress));
    }

    if (!m_bottomLevelInstances.empty())
        m_instanceBuffer = de::MovePtr<BufferWithMemory>(
            createInstanceBuffer(vk, device, allocator, m_bottomLevelInstances, m_instanceData, m_tryCachedMemory));
}
3027 
updateInstanceMatrix(const DeviceInterface & vk,const VkDevice device,size_t instanceIndex,const VkTransformMatrixKHR & matrix)3028 void TopLevelAccelerationStructureKHR::updateInstanceMatrix(const DeviceInterface &vk, const VkDevice device,
3029                                                             size_t instanceIndex, const VkTransformMatrixKHR &matrix)
3030 {
3031     DE_ASSERT(instanceIndex < m_bottomLevelInstances.size());
3032     DE_ASSERT(instanceIndex < m_instanceData.size());
3033 
3034     const auto &blas          = *m_bottomLevelInstances[instanceIndex];
3035     auto &instanceData        = m_instanceData[instanceIndex];
3036     auto &instancesAlloc      = m_instanceBuffer->getAllocation();
3037     auto bufferStart          = reinterpret_cast<uint8_t *>(instancesAlloc.getHostPtr());
3038     VkDeviceSize bufferOffset = sizeof(VkAccelerationStructureInstanceKHR) * instanceIndex;
3039 
3040     instanceData.matrix = matrix;
3041     updateSingleInstance(vk, device, blas, instanceData, bufferStart + bufferOffset, m_buildType, m_inactiveInstances);
3042     flushMappedMemoryRange(vk, device, instancesAlloc.getMemory(), instancesAlloc.getOffset(), VK_WHOLE_SIZE);
3043 }
3044 
// Builds this top-level AS from the registered bottom-level instances, or updates it
// in place when srcAccelerationStructure is provided. Handles device builds (direct
// and indirect) as well as host builds (immediate and via a deferred operation).
void TopLevelAccelerationStructureKHR::build(const DeviceInterface &vk, const VkDevice device,
                                             const VkCommandBuffer cmdBuffer,
                                             TopLevelAccelerationStructure *srcAccelerationStructure)
{
    DE_ASSERT(!m_bottomLevelInstances.empty());
    DE_ASSERT(m_accelerationStructureKHR.get() != DE_NULL);
    DE_ASSERT(m_buildScratchSize != 0);

    // Refresh the instance buffer so it reflects the current instance data.
    updateInstanceBuffer(vk, device, m_bottomLevelInstances, m_instanceData, m_instanceBuffer.get(), m_buildType,
                         m_inactiveInstances);

    VkAccelerationStructureGeometryKHR accelerationStructureGeometryKHR;
    const auto accelerationStructureGeometryKHRPtr = &accelerationStructureGeometryKHR;
    std::vector<uint32_t> maxPrimitiveCounts;
    prepareInstances(vk, device, accelerationStructureGeometryKHR, maxPrimitiveCounts);

    // Device builds scratch in the dedicated device buffer; host builds scratch in host memory.
    VkDeviceOrHostAddressKHR scratchData = (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR) ?
                                               makeDeviceOrHostAddressKHR(vk, device, m_deviceScratchBuffer->get(), 0) :
                                               makeDeviceOrHostAddressKHR(m_hostScratchBuffer.data());

    // A non-null source AS means this is an update rather than a fresh build.
    VkAccelerationStructureKHR srcStructure =
        (srcAccelerationStructure != DE_NULL) ? *(srcAccelerationStructure->getPtr()) : DE_NULL;
    VkBuildAccelerationStructureModeKHR mode = (srcAccelerationStructure != DE_NULL) ?
                                                   VK_BUILD_ACCELERATION_STRUCTURE_MODE_UPDATE_KHR :
                                                   VK_BUILD_ACCELERATION_STRUCTURE_MODE_BUILD_KHR;

    VkAccelerationStructureBuildGeometryInfoKHR accelerationStructureBuildGeometryInfoKHR = {
        VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_BUILD_GEOMETRY_INFO_KHR, //  VkStructureType sType;
        DE_NULL,                                                          //  const void* pNext;
        VK_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL_KHR,                     //  VkAccelerationStructureTypeKHR type;
        m_buildFlags,                     //  VkBuildAccelerationStructureFlagsKHR flags;
        mode,                             //  VkBuildAccelerationStructureModeKHR mode;
        srcStructure,                     //  VkAccelerationStructureKHR srcAccelerationStructure;
        m_accelerationStructureKHR.get(), //  VkAccelerationStructureKHR dstAccelerationStructure;
        1u,                               //  uint32_t geometryCount;
        (m_usePPGeometries ?
             nullptr :
             &accelerationStructureGeometryKHR), //  const VkAccelerationStructureGeometryKHR* pGeometries;
        (m_usePPGeometries ? &accelerationStructureGeometryKHRPtr :
                             nullptr), //  const VkAccelerationStructureGeometryKHR* const* ppGeometries;
        scratchData                    //  VkDeviceOrHostAddressKHR scratchData;
    };

    // Some tests intentionally build with zero primitives.
    const uint32_t primitiveCount =
        (m_buildWithoutPrimitives ? 0u : static_cast<uint32_t>(m_bottomLevelInstances.size()));

    VkAccelerationStructureBuildRangeInfoKHR accelerationStructureBuildRangeInfoKHR = {
        primitiveCount, //  uint32_t primitiveCount;
        0,              //  uint32_t primitiveOffset;
        0,              //  uint32_t firstVertex;
        0               //  uint32_t transformOffset;
    };
    VkAccelerationStructureBuildRangeInfoKHR *accelerationStructureBuildRangeInfoKHRPtr =
        &accelerationStructureBuildRangeInfoKHR;

    if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
    {
        // Device build: direct, or indirect with the build parameters read from m_indirectBuffer.
        if (m_indirectBuffer == DE_NULL)
            vk.cmdBuildAccelerationStructuresKHR(
                cmdBuffer, 1u, &accelerationStructureBuildGeometryInfoKHR,
                (const VkAccelerationStructureBuildRangeInfoKHR **)&accelerationStructureBuildRangeInfoKHRPtr);
        else
        {
            VkDeviceAddress indirectDeviceAddress =
                getBufferDeviceAddress(vk, device, m_indirectBuffer, m_indirectBufferOffset);
            uint32_t *pMaxPrimitiveCounts = maxPrimitiveCounts.data();
            vk.cmdBuildAccelerationStructuresIndirectKHR(cmdBuffer, 1u, &accelerationStructureBuildGeometryInfoKHR,
                                                         &indirectDeviceAddress, &m_indirectBufferStride,
                                                         &pMaxPrimitiveCounts);
        }
    }
    else if (!m_deferredOperation)
    {
        // Immediate host build.
        VK_CHECK(vk.buildAccelerationStructuresKHR(
            device, DE_NULL, 1u, &accelerationStructureBuildGeometryInfoKHR,
            (const VkAccelerationStructureBuildRangeInfoKHR **)&accelerationStructureBuildRangeInfoKHRPtr));
    }
    else
    {
        // Host build through a deferred operation, completed with worker threads.
        const auto deferredOperationPtr = createDeferredOperationKHR(vk, device);
        const auto deferredOperation    = deferredOperationPtr.get();

        VkResult result = vk.buildAccelerationStructuresKHR(
            device, deferredOperation, 1u, &accelerationStructureBuildGeometryInfoKHR,
            (const VkAccelerationStructureBuildRangeInfoKHR **)&accelerationStructureBuildRangeInfoKHRPtr);

        DE_ASSERT(result == VK_OPERATION_DEFERRED_KHR || result == VK_OPERATION_NOT_DEFERRED_KHR ||
                  result == VK_SUCCESS);

        finishDeferredOperation(vk, device, deferredOperation, m_workerThreadCount,
                                result == VK_OPERATION_NOT_DEFERRED_KHR);

        accelerationStructureBuildGeometryInfoKHR.pNext = DE_NULL;
    }

    if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
    {
        // Make the build results visible to all subsequent commands in this command buffer.
        const VkAccessFlags accessMasks =
            VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR | VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR;
        const VkMemoryBarrier memBarrier = makeMemoryBarrier(accessMasks, accessMasks);

        cmdPipelineMemoryBarrier(vk, cmdBuffer, VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR,
                                 VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, &memBarrier);
    }
}
3150 
copyFrom(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,TopLevelAccelerationStructure * accelerationStructure,bool compactCopy)3151 void TopLevelAccelerationStructureKHR::copyFrom(const DeviceInterface &vk, const VkDevice device,
3152                                                 const VkCommandBuffer cmdBuffer,
3153                                                 TopLevelAccelerationStructure *accelerationStructure, bool compactCopy)
3154 {
3155     DE_ASSERT(m_accelerationStructureKHR.get() != DE_NULL);
3156     DE_ASSERT(accelerationStructure != DE_NULL);
3157 
3158     VkCopyAccelerationStructureInfoKHR copyAccelerationStructureInfo = {
3159         VK_STRUCTURE_TYPE_COPY_ACCELERATION_STRUCTURE_INFO_KHR, // VkStructureType sType;
3160         DE_NULL,                                                // const void* pNext;
3161         *(accelerationStructure->getPtr()),                     // VkAccelerationStructureKHR src;
3162         *(getPtr()),                                            // VkAccelerationStructureKHR dst;
3163         compactCopy ? VK_COPY_ACCELERATION_STRUCTURE_MODE_COMPACT_KHR :
3164                       VK_COPY_ACCELERATION_STRUCTURE_MODE_CLONE_KHR // VkCopyAccelerationStructureModeKHR mode;
3165     };
3166 
3167     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
3168     {
3169         vk.cmdCopyAccelerationStructureKHR(cmdBuffer, &copyAccelerationStructureInfo);
3170     }
3171     else if (!m_deferredOperation)
3172     {
3173         VK_CHECK(vk.copyAccelerationStructureKHR(device, DE_NULL, &copyAccelerationStructureInfo));
3174     }
3175     else
3176     {
3177         const auto deferredOperationPtr = createDeferredOperationKHR(vk, device);
3178         const auto deferredOperation    = deferredOperationPtr.get();
3179 
3180         VkResult result = vk.copyAccelerationStructureKHR(device, deferredOperation, &copyAccelerationStructureInfo);
3181 
3182         DE_ASSERT(result == VK_OPERATION_DEFERRED_KHR || result == VK_OPERATION_NOT_DEFERRED_KHR ||
3183                   result == VK_SUCCESS);
3184 
3185         finishDeferredOperation(vk, device, deferredOperation, m_workerThreadCount,
3186                                 result == VK_OPERATION_NOT_DEFERRED_KHR);
3187     }
3188 
3189     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
3190     {
3191         const VkAccessFlags accessMasks =
3192             VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR | VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR;
3193         const VkMemoryBarrier memBarrier = makeMemoryBarrier(accessMasks, accessMasks);
3194 
3195         cmdPipelineMemoryBarrier(vk, cmdBuffer, VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR,
3196                                  VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, &memBarrier);
3197     }
3198 }
3199 
// Serializes this top-level AS into the given storage. When the storage uses the
// "deep" format, the referenced bottom-level structures are serialized as well
// (via serializeBottoms), each into its own bottom storage slot.
void TopLevelAccelerationStructureKHR::serialize(const DeviceInterface &vk, const VkDevice device,
                                                 const VkCommandBuffer cmdBuffer, SerialStorage *storage)
{
    DE_ASSERT(m_accelerationStructureKHR.get() != DE_NULL);
    DE_ASSERT(storage != DE_NULL);

    const VkCopyAccelerationStructureToMemoryInfoKHR copyAccelerationStructureInfo = {
        VK_STRUCTURE_TYPE_COPY_ACCELERATION_STRUCTURE_TO_MEMORY_INFO_KHR, // VkStructureType sType;
        DE_NULL,                                                          // const void* pNext;
        *(getPtr()),                                                      // VkAccelerationStructureKHR src;
        storage->getAddress(vk, device, m_buildType),                     // VkDeviceOrHostAddressKHR dst;
        VK_COPY_ACCELERATION_STRUCTURE_MODE_SERIALIZE_KHR                 // VkCopyAccelerationStructureModeKHR mode;
    };

    if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
    {
        // Device path: record the serialization into the command buffer.
        vk.cmdCopyAccelerationStructureToMemoryKHR(cmdBuffer, &copyAccelerationStructureInfo);
        if (storage->hasDeepFormat())
            serializeBottoms(vk, device, cmdBuffer, storage, DE_NULL);
    }
    else if (!m_deferredOperation)
    {
        // Immediate host path.
        VK_CHECK(vk.copyAccelerationStructureToMemoryKHR(device, DE_NULL, &copyAccelerationStructureInfo));
        if (storage->hasDeepFormat())
            serializeBottoms(vk, device, cmdBuffer, storage, DE_NULL);
    }
    else
    {
        // Host path through a deferred operation, completed with worker threads.
        const auto deferredOperationPtr = createDeferredOperationKHR(vk, device);
        const auto deferredOperation    = deferredOperationPtr.get();

        const VkResult result =
            vk.copyAccelerationStructureToMemoryKHR(device, deferredOperation, &copyAccelerationStructureInfo);

        DE_ASSERT(result == VK_OPERATION_DEFERRED_KHR || result == VK_OPERATION_NOT_DEFERRED_KHR ||
                  result == VK_SUCCESS);
        // Bottom-level structures are serialized before the deferred top-level copy is finished.
        if (storage->hasDeepFormat())
            serializeBottoms(vk, device, cmdBuffer, storage, deferredOperation);

        finishDeferredOperation(vk, device, deferredOperation, m_workerThreadCount,
                                result == VK_OPERATION_NOT_DEFERRED_KHR);
    }
}
3243 
deserialize(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,SerialStorage * storage)3244 void TopLevelAccelerationStructureKHR::deserialize(const DeviceInterface &vk, const VkDevice device,
3245                                                    const VkCommandBuffer cmdBuffer, SerialStorage *storage)
3246 {
3247     DE_ASSERT(m_accelerationStructureKHR.get() != DE_NULL);
3248     DE_ASSERT(storage != DE_NULL);
3249 
3250     const VkCopyMemoryToAccelerationStructureInfoKHR copyAccelerationStructureInfo = {
3251         VK_STRUCTURE_TYPE_COPY_MEMORY_TO_ACCELERATION_STRUCTURE_INFO_KHR, // VkStructureType sType;
3252         DE_NULL,                                                          // const void* pNext;
3253         storage->getAddressConst(vk, device, m_buildType),                // VkDeviceOrHostAddressConstKHR src;
3254         *(getPtr()),                                                      // VkAccelerationStructureKHR dst;
3255         VK_COPY_ACCELERATION_STRUCTURE_MODE_DESERIALIZE_KHR               // VkCopyAccelerationStructureModeKHR mode;
3256     };
3257 
3258     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
3259     {
3260         vk.cmdCopyMemoryToAccelerationStructureKHR(cmdBuffer, &copyAccelerationStructureInfo);
3261     }
3262     else if (!m_deferredOperation)
3263     {
3264         VK_CHECK(vk.copyMemoryToAccelerationStructureKHR(device, DE_NULL, &copyAccelerationStructureInfo));
3265     }
3266     else
3267     {
3268         const auto deferredOperationPtr = createDeferredOperationKHR(vk, device);
3269         const auto deferredOperation    = deferredOperationPtr.get();
3270 
3271         const VkResult result =
3272             vk.copyMemoryToAccelerationStructureKHR(device, deferredOperation, &copyAccelerationStructureInfo);
3273 
3274         DE_ASSERT(result == VK_OPERATION_DEFERRED_KHR || result == VK_OPERATION_NOT_DEFERRED_KHR ||
3275                   result == VK_SUCCESS);
3276 
3277         finishDeferredOperation(vk, device, deferredOperation, m_workerThreadCount,
3278                                 result == VK_OPERATION_NOT_DEFERRED_KHR);
3279     }
3280 
3281     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
3282     {
3283         const VkAccessFlags accessMasks =
3284             VK_ACCESS_ACCELERATION_STRUCTURE_WRITE_BIT_KHR | VK_ACCESS_ACCELERATION_STRUCTURE_READ_BIT_KHR;
3285         const VkMemoryBarrier memBarrier = makeMemoryBarrier(accessMasks, accessMasks);
3286 
3287         cmdPipelineMemoryBarrier(vk, cmdBuffer, VK_PIPELINE_STAGE_ACCELERATION_STRUCTURE_BUILD_BIT_KHR,
3288                                  VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, &memBarrier);
3289     }
3290 }
3291 
serializeBottoms(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,SerialStorage * storage,VkDeferredOperationKHR deferredOperation)3292 void TopLevelAccelerationStructureKHR::serializeBottoms(const DeviceInterface &vk, const VkDevice device,
3293                                                         const VkCommandBuffer cmdBuffer, SerialStorage *storage,
3294                                                         VkDeferredOperationKHR deferredOperation)
3295 {
3296     DE_UNREF(deferredOperation);
3297     DE_ASSERT(storage->hasDeepFormat());
3298 
3299     const std::vector<uint64_t> &addresses = storage->getSerialInfo().addresses();
3300     const std::size_t cbottoms             = m_bottomLevelInstances.size();
3301 
3302     uint32_t storageIndex = 0;
3303     std::vector<uint64_t> matches;
3304 
3305     for (std::size_t i = 0; i < cbottoms; ++i)
3306     {
3307         const uint64_t &lookAddr = addresses[i + 1];
3308         auto end                 = matches.end();
3309         auto match = std::find_if(matches.begin(), end, [&](const uint64_t &item) { return item == lookAddr; });
3310         if (match == end)
3311         {
3312             matches.emplace_back(lookAddr);
3313             m_bottomLevelInstances[i].get()->serialize(vk, device, cmdBuffer,
3314                                                        storage->getBottomStorage(storageIndex).get());
3315             storageIndex += 1;
3316         }
3317     }
3318 }
3319 
createAndDeserializeBottoms(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,Allocator & allocator,SerialStorage * storage)3320 void TopLevelAccelerationStructureKHR::createAndDeserializeBottoms(const DeviceInterface &vk, const VkDevice device,
3321                                                                    const VkCommandBuffer cmdBuffer,
3322                                                                    Allocator &allocator, SerialStorage *storage)
3323 {
3324     DE_ASSERT(storage->hasDeepFormat());
3325     DE_ASSERT(m_bottomLevelInstances.size() == 0);
3326 
3327     const std::vector<uint64_t> &addresses = storage->getSerialInfo().addresses();
3328     const std::size_t cbottoms             = addresses.size() - 1;
3329     uint32_t storageIndex                  = 0;
3330     std::vector<std::pair<uint64_t, std::size_t>> matches;
3331 
3332     for (std::size_t i = 0; i < cbottoms; ++i)
3333     {
3334         const uint64_t &lookAddr = addresses[i + 1];
3335         auto end                 = matches.end();
3336         auto match               = std::find_if(matches.begin(), end,
3337                                                 [&](const std::pair<uint64_t, std::size_t> &item) { return item.first == lookAddr; });
3338         if (match != end)
3339         {
3340             m_bottomLevelInstances.emplace_back(m_bottomLevelInstances[match->second]);
3341         }
3342         else
3343         {
3344             de::MovePtr<BottomLevelAccelerationStructure> blas = makeBottomLevelAccelerationStructure();
3345             blas->createAndDeserializeFrom(vk, device, cmdBuffer, allocator,
3346                                            storage->getBottomStorage(storageIndex).get());
3347             m_bottomLevelInstances.emplace_back(de::SharedPtr<BottomLevelAccelerationStructure>(blas.release()));
3348             matches.emplace_back(lookAddr, i);
3349             storageIndex += 1;
3350         }
3351     }
3352 
3353     std::vector<uint64_t> newAddresses = getSerializingAddresses(vk, device);
3354     DE_ASSERT(addresses.size() == newAddresses.size());
3355 
3356     SerialStorage::AccelerationStructureHeader *header = storage->getASHeader();
3357     DE_ASSERT(cbottoms == header->handleCount);
3358 
3359     // finally update bottom-level AS addresses before top-level AS deserialization
3360     for (std::size_t i = 0; i < cbottoms; ++i)
3361     {
3362         header->handleArray[i] = newAddresses[i + 1];
3363     }
3364 }
3365 
getSerializingSizes(const DeviceInterface & vk,const VkDevice device,const VkQueue queue,const uint32_t queueFamilyIndex)3366 std::vector<VkDeviceSize> TopLevelAccelerationStructureKHR::getSerializingSizes(const DeviceInterface &vk,
3367                                                                                 const VkDevice device,
3368                                                                                 const VkQueue queue,
3369                                                                                 const uint32_t queueFamilyIndex)
3370 {
3371     const uint32_t queryCount(uint32_t(m_bottomLevelInstances.size()) + 1);
3372     std::vector<VkAccelerationStructureKHR> handles(queryCount);
3373     std::vector<VkDeviceSize> sizes(queryCount);
3374 
3375     handles[0] = m_accelerationStructureKHR.get();
3376 
3377     for (uint32_t h = 1; h < queryCount; ++h)
3378         handles[h] = *m_bottomLevelInstances[h - 1].get()->getPtr();
3379 
3380     if (VK_ACCELERATION_STRUCTURE_BUILD_TYPE_HOST_KHR == m_buildType)
3381         queryAccelerationStructureSize(vk, device, DE_NULL, handles, m_buildType, DE_NULL,
3382                                        VK_QUERY_TYPE_ACCELERATION_STRUCTURE_SERIALIZATION_SIZE_KHR, 0u, sizes);
3383     else
3384     {
3385         const Move<VkCommandPool> cmdPool = createCommandPool(vk, device, 0, queueFamilyIndex);
3386         const Move<VkCommandBuffer> cmdBuffer =
3387             allocateCommandBuffer(vk, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);
3388         const Move<VkQueryPool> queryPool =
3389             makeQueryPool(vk, device, VK_QUERY_TYPE_ACCELERATION_STRUCTURE_SERIALIZATION_SIZE_KHR, queryCount);
3390 
3391         beginCommandBuffer(vk, *cmdBuffer);
3392         queryAccelerationStructureSize(vk, device, *cmdBuffer, handles, m_buildType, *queryPool,
3393                                        VK_QUERY_TYPE_ACCELERATION_STRUCTURE_SERIALIZATION_SIZE_KHR, 0u, sizes);
3394         endCommandBuffer(vk, *cmdBuffer);
3395         submitCommandsAndWait(vk, device, queue, cmdBuffer.get());
3396 
3397         VK_CHECK(vk.getQueryPoolResults(device, *queryPool, 0u, queryCount, queryCount * sizeof(VkDeviceSize),
3398                                         sizes.data(), sizeof(VkDeviceSize),
3399                                         VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT));
3400     }
3401 
3402     return sizes;
3403 }
3404 
getSerializingAddresses(const DeviceInterface & vk,const VkDevice device) const3405 std::vector<uint64_t> TopLevelAccelerationStructureKHR::getSerializingAddresses(const DeviceInterface &vk,
3406                                                                                 const VkDevice device) const
3407 {
3408     std::vector<uint64_t> result(m_bottomLevelInstances.size() + 1);
3409 
3410     VkAccelerationStructureDeviceAddressInfoKHR asDeviceAddressInfo = {
3411         VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_DEVICE_ADDRESS_INFO_KHR, // VkStructureType sType;
3412         DE_NULL,                                                          // const void* pNext;
3413         DE_NULL // VkAccelerationStructureKHR accelerationStructure;
3414     };
3415 
3416     if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
3417     {
3418         asDeviceAddressInfo.accelerationStructure = m_accelerationStructureKHR.get();
3419         result[0] = vk.getAccelerationStructureDeviceAddressKHR(device, &asDeviceAddressInfo);
3420     }
3421     else
3422     {
3423         result[0] = uint64_t(getPtr()->getInternal());
3424     }
3425 
3426     for (size_t instanceNdx = 0; instanceNdx < m_bottomLevelInstances.size(); ++instanceNdx)
3427     {
3428         const BottomLevelAccelerationStructure &bottomLevelAccelerationStructure = *m_bottomLevelInstances[instanceNdx];
3429         const VkAccelerationStructureKHR accelerationStructureKHR = *bottomLevelAccelerationStructure.getPtr();
3430 
3431         if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
3432         {
3433             asDeviceAddressInfo.accelerationStructure = accelerationStructureKHR;
3434             result[instanceNdx + 1] = vk.getAccelerationStructureDeviceAddressKHR(device, &asDeviceAddressInfo);
3435         }
3436         else
3437         {
3438             result[instanceNdx + 1] = uint64_t(accelerationStructureKHR.getInternal());
3439         }
3440     }
3441 
3442     return result;
3443 }
3444 
getPtr(void) const3445 const VkAccelerationStructureKHR *TopLevelAccelerationStructureKHR::getPtr(void) const
3446 {
3447     return &m_accelerationStructureKHR.get();
3448 }
3449 
// Fills 'accelerationStructureGeometryKHR' with a VK_GEOMETRY_TYPE_INSTANCES_KHR
// geometry referencing this TLAS' bottom-level instances, and sets
// 'maxPrimitiveCounts' to a single entry holding the instance count.
void TopLevelAccelerationStructureKHR::prepareInstances(
    const DeviceInterface &vk, const VkDevice device,
    VkAccelerationStructureGeometryKHR &accelerationStructureGeometryKHR, std::vector<uint32_t> &maxPrimitiveCounts)
{
    // One instances-type geometry -> one max-primitive-count entry.
    maxPrimitiveCounts.resize(1);
    maxPrimitiveCounts[0] = static_cast<uint32_t>(m_bottomLevelInstances.size());

    VkDeviceOrHostAddressConstKHR instancesData;
    if (m_buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
    {
        if (m_instanceBuffer.get() != DE_NULL)
        {
            if (m_useArrayOfPointers)
            {
                // arrayOfPointers mode: write one device address per instance
                // into the host-visible address buffer; the geometry then
                // points at that array instead of the instance data itself.
                uint8_t *bufferStart = static_cast<uint8_t *>(m_instanceAddressBuffer->getAllocation().getHostPtr());
                VkDeviceSize bufferOffset = 0;
                VkDeviceOrHostAddressConstKHR firstInstance =
                    makeDeviceOrHostAddressConstKHR(vk, device, m_instanceBuffer->get(), 0);
                for (size_t instanceNdx = 0; instanceNdx < m_bottomLevelInstances.size(); ++instanceNdx)
                {
                    VkDeviceOrHostAddressConstKHR currentInstance;
                    currentInstance.deviceAddress =
                        firstInstance.deviceAddress + instanceNdx * sizeof(VkAccelerationStructureInstanceKHR);

                    // Only the address member of the union is copied out.
                    deMemcpy(&bufferStart[bufferOffset], &currentInstance,
                             sizeof(VkDeviceOrHostAddressConstKHR::deviceAddress));
                    bufferOffset += sizeof(VkDeviceOrHostAddressConstKHR::deviceAddress);
                }
                // Make the CPU-side writes visible to the device.
                flushMappedMemoryRange(vk, device, m_instanceAddressBuffer->getAllocation().getMemory(),
                                       m_instanceAddressBuffer->getAllocation().getOffset(), VK_WHOLE_SIZE);

                instancesData = makeDeviceOrHostAddressConstKHR(vk, device, m_instanceAddressBuffer->get(), 0);
            }
            else
                instancesData = makeDeviceOrHostAddressConstKHR(vk, device, m_instanceBuffer->get(), 0);
        }
        else
            instancesData = makeDeviceOrHostAddressConstKHR(DE_NULL);
    }
    else
    {
        // Host build: same structure as the device path, but host pointers
        // are used instead of device addresses.
        if (m_instanceBuffer.get() != DE_NULL)
        {
            if (m_useArrayOfPointers)
            {
                uint8_t *bufferStart = static_cast<uint8_t *>(m_instanceAddressBuffer->getAllocation().getHostPtr());
                VkDeviceSize bufferOffset = 0;
                for (size_t instanceNdx = 0; instanceNdx < m_bottomLevelInstances.size(); ++instanceNdx)
                {
                    VkDeviceOrHostAddressConstKHR currentInstance;
                    currentInstance.hostAddress = (uint8_t *)m_instanceBuffer->getAllocation().getHostPtr() +
                                                  instanceNdx * sizeof(VkAccelerationStructureInstanceKHR);

                    // Only the pointer member of the union is copied out.
                    deMemcpy(&bufferStart[bufferOffset], &currentInstance,
                             sizeof(VkDeviceOrHostAddressConstKHR::hostAddress));
                    bufferOffset += sizeof(VkDeviceOrHostAddressConstKHR::hostAddress);
                }
                instancesData = makeDeviceOrHostAddressConstKHR(m_instanceAddressBuffer->getAllocation().getHostPtr());
            }
            else
                instancesData = makeDeviceOrHostAddressConstKHR(m_instanceBuffer->getAllocation().getHostPtr());
        }
        else
            instancesData = makeDeviceOrHostAddressConstKHR(DE_NULL);
    }

    VkAccelerationStructureGeometryInstancesDataKHR accelerationStructureGeometryInstancesDataKHR = {
        VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_GEOMETRY_INSTANCES_DATA_KHR, //  VkStructureType sType;
        DE_NULL,                                                              //  const void* pNext;
        (VkBool32)(m_useArrayOfPointers ? true : false),                      //  VkBool32 arrayOfPointers;
        instancesData                                                         //  VkDeviceOrHostAddressConstKHR data;
    };

    accelerationStructureGeometryKHR = {
        VK_STRUCTURE_TYPE_ACCELERATION_STRUCTURE_GEOMETRY_KHR, //  VkStructureType sType;
        DE_NULL,                                               //  const void* pNext;
        VK_GEOMETRY_TYPE_INSTANCES_KHR,                        //  VkGeometryTypeKHR geometryType;
        makeVkAccelerationStructureInstancesDataKHR(
            accelerationStructureGeometryInstancesDataKHR), //  VkAccelerationStructureGeometryDataKHR geometry;
        (VkGeometryFlagsKHR)0u                              //  VkGeometryFlagsKHR flags;
    };
}
3532 
getRequiredAllocationCount(void)3533 uint32_t TopLevelAccelerationStructure::getRequiredAllocationCount(void)
3534 {
3535     return TopLevelAccelerationStructureKHR::getRequiredAllocationCount();
3536 }
3537 
makeTopLevelAccelerationStructure()3538 de::MovePtr<TopLevelAccelerationStructure> makeTopLevelAccelerationStructure()
3539 {
3540     return de::MovePtr<TopLevelAccelerationStructure>(new TopLevelAccelerationStructureKHR);
3541 }
3542 
queryAccelerationStructureSizeKHR(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,const std::vector<VkAccelerationStructureKHR> & accelerationStructureHandles,VkAccelerationStructureBuildTypeKHR buildType,const VkQueryPool queryPool,VkQueryType queryType,uint32_t firstQuery,std::vector<VkDeviceSize> & results)3543 bool queryAccelerationStructureSizeKHR(const DeviceInterface &vk, const VkDevice device,
3544                                        const VkCommandBuffer cmdBuffer,
3545                                        const std::vector<VkAccelerationStructureKHR> &accelerationStructureHandles,
3546                                        VkAccelerationStructureBuildTypeKHR buildType, const VkQueryPool queryPool,
3547                                        VkQueryType queryType, uint32_t firstQuery, std::vector<VkDeviceSize> &results)
3548 {
3549     DE_ASSERT(queryType == VK_QUERY_TYPE_ACCELERATION_STRUCTURE_COMPACTED_SIZE_KHR ||
3550               queryType == VK_QUERY_TYPE_ACCELERATION_STRUCTURE_SERIALIZATION_SIZE_KHR);
3551 
3552     if (buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR)
3553     {
3554         // queryPool must be large enough to contain at least (firstQuery + accelerationStructureHandles.size()) queries
3555         vk.cmdResetQueryPool(cmdBuffer, queryPool, firstQuery, uint32_t(accelerationStructureHandles.size()));
3556         vk.cmdWriteAccelerationStructuresPropertiesKHR(cmdBuffer, uint32_t(accelerationStructureHandles.size()),
3557                                                        accelerationStructureHandles.data(), queryType, queryPool,
3558                                                        firstQuery);
3559         // results cannot be retrieved to CPU at the moment - you need to do it using getQueryPoolResults after cmdBuffer is executed. Meanwhile function returns a vector of 0s.
3560         results.resize(accelerationStructureHandles.size(), 0u);
3561         return false;
3562     }
3563     // buildType != VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR
3564     results.resize(accelerationStructureHandles.size(), 0u);
3565     vk.writeAccelerationStructuresPropertiesKHR(
3566         device, uint32_t(accelerationStructureHandles.size()), accelerationStructureHandles.data(), queryType,
3567         sizeof(VkDeviceSize) * accelerationStructureHandles.size(), results.data(), sizeof(VkDeviceSize));
3568     // results will contain proper values
3569     return true;
3570 }
3571 
queryAccelerationStructureSize(const DeviceInterface & vk,const VkDevice device,const VkCommandBuffer cmdBuffer,const std::vector<VkAccelerationStructureKHR> & accelerationStructureHandles,VkAccelerationStructureBuildTypeKHR buildType,const VkQueryPool queryPool,VkQueryType queryType,uint32_t firstQuery,std::vector<VkDeviceSize> & results)3572 bool queryAccelerationStructureSize(const DeviceInterface &vk, const VkDevice device, const VkCommandBuffer cmdBuffer,
3573                                     const std::vector<VkAccelerationStructureKHR> &accelerationStructureHandles,
3574                                     VkAccelerationStructureBuildTypeKHR buildType, const VkQueryPool queryPool,
3575                                     VkQueryType queryType, uint32_t firstQuery, std::vector<VkDeviceSize> &results)
3576 {
3577     return queryAccelerationStructureSizeKHR(vk, device, cmdBuffer, accelerationStructureHandles, buildType, queryPool,
3578                                              queryType, firstQuery, results);
3579 }
3580 
// Starts with an empty pipeline description: no shaders, groups or libraries;
// recursion depth defaults to 1, payload/attribute sizes to 0, and deferred
// (possibly multi-threaded) compilation is disabled.
RayTracingPipeline::RayTracingPipeline()
    : m_shadersModules()
    , m_pipelineLibraries()
    , m_shaderCreateInfos()
    , m_shadersGroupCreateInfos()
    , m_pipelineCreateFlags(0U)
    , m_pipelineCreateFlags2(0U)
    , m_maxRecursionDepth(1U)
    , m_maxPayloadSize(0U)
    , m_maxAttributeSize(0U)
    , m_deferredOperation(false)
    , m_workerThreadCount(0)
{
}
3595 
// Owned shader modules and libraries are released by their smart-pointer members.
RayTracingPipeline::~RayTracingPipeline()
{
}
3599 
// Assigns STAGE into the SHADER slot only if that slot is still unused;
// assigning the same group slot twice is a test-code error and throws.
#define CHECKED_ASSIGN_SHADER(SHADER, STAGE) \
    if (SHADER == VK_SHADER_UNUSED_KHR)      \
        SHADER = STAGE;                      \
    else                                     \
        TCU_THROW(InternalError, "Attempt to reassign shader")
3605 
addShader(VkShaderStageFlagBits shaderStage,Move<VkShaderModule> shaderModule,uint32_t group,const VkSpecializationInfo * specializationInfo,const VkPipelineShaderStageCreateFlags pipelineShaderStageCreateFlags,const void * pipelineShaderStageCreateInfopNext)3606 void RayTracingPipeline::addShader(VkShaderStageFlagBits shaderStage, Move<VkShaderModule> shaderModule, uint32_t group,
3607                                    const VkSpecializationInfo *specializationInfo,
3608                                    const VkPipelineShaderStageCreateFlags pipelineShaderStageCreateFlags,
3609                                    const void *pipelineShaderStageCreateInfopNext)
3610 {
3611     addShader(shaderStage, makeVkSharedPtr(shaderModule), group, specializationInfo, pipelineShaderStageCreateFlags,
3612               pipelineShaderStageCreateInfopNext);
3613 }
3614 
addShader(VkShaderStageFlagBits shaderStage,de::SharedPtr<Move<VkShaderModule>> shaderModule,uint32_t group,const VkSpecializationInfo * specializationInfoPtr,const VkPipelineShaderStageCreateFlags pipelineShaderStageCreateFlags,const void * pipelineShaderStageCreateInfopNext)3615 void RayTracingPipeline::addShader(VkShaderStageFlagBits shaderStage, de::SharedPtr<Move<VkShaderModule>> shaderModule,
3616                                    uint32_t group, const VkSpecializationInfo *specializationInfoPtr,
3617                                    const VkPipelineShaderStageCreateFlags pipelineShaderStageCreateFlags,
3618                                    const void *pipelineShaderStageCreateInfopNext)
3619 {
3620     addShader(shaderStage, **shaderModule, group, specializationInfoPtr, pipelineShaderStageCreateFlags,
3621               pipelineShaderStageCreateInfopNext);
3622     m_shadersModules.push_back(shaderModule);
3623 }
3624 
// Registers 'shaderModule' as a stage of shader group 'group': creates any
// missing groups up to and including 'group', records the stage create info,
// and wires the new stage's index into the matching slot of the group create
// info. Throws InternalError if the slot is already occupied or 'shaderStage'
// is not a ray-tracing stage.
void RayTracingPipeline::addShader(VkShaderStageFlagBits shaderStage, VkShaderModule shaderModule, uint32_t group,
                                   const VkSpecializationInfo *specializationInfoPtr,
                                   const VkPipelineShaderStageCreateFlags pipelineShaderStageCreateFlags,
                                   const void *pipelineShaderStageCreateInfopNext)
{
    // Grow the group array on demand; new groups start with every shader slot
    // unused and a MAX_ENUM sentinel type that is patched below.
    if (group >= m_shadersGroupCreateInfos.size())
    {
        for (size_t groupNdx = m_shadersGroupCreateInfos.size(); groupNdx <= group; ++groupNdx)
        {
            VkRayTracingShaderGroupCreateInfoKHR shaderGroupCreateInfo = {
                VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR, //  VkStructureType sType;
                DE_NULL,                                                    //  const void* pNext;
                VK_RAY_TRACING_SHADER_GROUP_TYPE_MAX_ENUM_KHR,              //  VkRayTracingShaderGroupTypeKHR type;
                VK_SHADER_UNUSED_KHR,                                       //  uint32_t generalShader;
                VK_SHADER_UNUSED_KHR,                                       //  uint32_t closestHitShader;
                VK_SHADER_UNUSED_KHR,                                       //  uint32_t anyHitShader;
                VK_SHADER_UNUSED_KHR,                                       //  uint32_t intersectionShader;
                DE_NULL, //  const void* pShaderGroupCaptureReplayHandle;
            };

            m_shadersGroupCreateInfos.push_back(shaderGroupCreateInfo);
        }
    }

    // The stage index equals the position this stage will occupy in m_shaderCreateInfos.
    const uint32_t shaderStageNdx                               = (uint32_t)m_shaderCreateInfos.size();
    VkRayTracingShaderGroupCreateInfoKHR &shaderGroupCreateInfo = m_shadersGroupCreateInfos[group];

    // Store the stage index in the group slot matching its stage kind.
    switch (shaderStage)
    {
    case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
        CHECKED_ASSIGN_SHADER(shaderGroupCreateInfo.generalShader, shaderStageNdx);
        break;
    case VK_SHADER_STAGE_MISS_BIT_KHR:
        CHECKED_ASSIGN_SHADER(shaderGroupCreateInfo.generalShader, shaderStageNdx);
        break;
    case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
        CHECKED_ASSIGN_SHADER(shaderGroupCreateInfo.generalShader, shaderStageNdx);
        break;
    case VK_SHADER_STAGE_ANY_HIT_BIT_KHR:
        CHECKED_ASSIGN_SHADER(shaderGroupCreateInfo.anyHitShader, shaderStageNdx);
        break;
    case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
        CHECKED_ASSIGN_SHADER(shaderGroupCreateInfo.closestHitShader, shaderStageNdx);
        break;
    case VK_SHADER_STAGE_INTERSECTION_BIT_KHR:
        CHECKED_ASSIGN_SHADER(shaderGroupCreateInfo.intersectionShader, shaderStageNdx);
        break;
    default:
        TCU_THROW(InternalError, "Unacceptable stage");
    }

    // Derive the group type from the stage: raygen/miss/callable form GENERAL
    // groups; hit stages form TRIANGLES hit groups unless an intersection
    // shader makes the group PROCEDURAL.
    switch (shaderStage)
    {
    case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
    case VK_SHADER_STAGE_MISS_BIT_KHR:
    case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
    {
        // A general stage must be the first (and only kind of) stage in its group.
        DE_ASSERT(shaderGroupCreateInfo.type == VK_RAY_TRACING_SHADER_GROUP_TYPE_MAX_ENUM_KHR);
        shaderGroupCreateInfo.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR;

        break;
    }

    case VK_SHADER_STAGE_ANY_HIT_BIT_KHR:
    case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
    case VK_SHADER_STAGE_INTERSECTION_BIT_KHR:
    {
        // Hit stages may not be mixed into a group already marked GENERAL.
        DE_ASSERT(shaderGroupCreateInfo.type != VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR);
        shaderGroupCreateInfo.type = (shaderGroupCreateInfo.intersectionShader == VK_SHADER_UNUSED_KHR) ?
                                         VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR :
                                         VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR;

        break;
    }

    default:
        TCU_THROW(InternalError, "Unacceptable stage");
    }

    // Finally record the per-stage create info; entry point is always "main".
    {
        const VkPipelineShaderStageCreateInfo shaderCreateInfo = {
            VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, //  VkStructureType sType;
            pipelineShaderStageCreateInfopNext,                  //  const void* pNext;
            pipelineShaderStageCreateFlags,                      //  VkPipelineShaderStageCreateFlags flags;
            shaderStage,                                         //  VkShaderStageFlagBits stage;
            shaderModule,                                        //  VkShaderModule module;
            "main",                                              //  const char* pName;
            specializationInfoPtr,                               //  const VkSpecializationInfo* pSpecializationInfo;
        };

        m_shaderCreateInfos.push_back(shaderCreateInfo);
    }
}
3718 
setGroupCaptureReplayHandle(uint32_t group,const void * pShaderGroupCaptureReplayHandle)3719 void RayTracingPipeline::setGroupCaptureReplayHandle(uint32_t group, const void *pShaderGroupCaptureReplayHandle)
3720 {
3721     DE_ASSERT(static_cast<size_t>(group) < m_shadersGroupCreateInfos.size());
3722     m_shadersGroupCreateInfos[group].pShaderGroupCaptureReplayHandle = pShaderGroupCaptureReplayHandle;
3723 }
3724 
addLibrary(de::SharedPtr<de::MovePtr<RayTracingPipeline>> pipelineLibrary)3725 void RayTracingPipeline::addLibrary(de::SharedPtr<de::MovePtr<RayTracingPipeline>> pipelineLibrary)
3726 {
3727     m_pipelineLibraries.push_back(pipelineLibrary);
3728 }
3729 
getShaderGroupCount(void)3730 uint32_t RayTracingPipeline::getShaderGroupCount(void)
3731 {
3732     return de::sizeU32(m_shadersGroupCreateInfos);
3733 }
3734 
getFullShaderGroupCount(void)3735 uint32_t RayTracingPipeline::getFullShaderGroupCount(void)
3736 {
3737     uint32_t totalCount = getShaderGroupCount();
3738 
3739     for (const auto &lib : m_pipelineLibraries)
3740         totalCount += lib->get()->getFullShaderGroupCount();
3741 
3742     return totalCount;
3743 }
3744 
createPipelineKHR(const DeviceInterface & vk,const VkDevice device,const VkPipelineLayout pipelineLayout,const std::vector<VkPipeline> & pipelineLibraries,const VkPipelineCache pipelineCache)3745 Move<VkPipeline> RayTracingPipeline::createPipelineKHR(const DeviceInterface &vk, const VkDevice device,
3746                                                        const VkPipelineLayout pipelineLayout,
3747                                                        const std::vector<VkPipeline> &pipelineLibraries,
3748                                                        const VkPipelineCache pipelineCache)
3749 {
3750     for (size_t groupNdx = 0; groupNdx < m_shadersGroupCreateInfos.size(); ++groupNdx)
3751         DE_ASSERT(m_shadersGroupCreateInfos[groupNdx].sType ==
3752                   VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR);
3753 
3754     VkPipelineLibraryCreateInfoKHR librariesCreateInfo = {
3755         VK_STRUCTURE_TYPE_PIPELINE_LIBRARY_CREATE_INFO_KHR, //  VkStructureType sType;
3756         DE_NULL,                                            //  const void* pNext;
3757         de::sizeU32(pipelineLibraries),                     //  uint32_t libraryCount;
3758         de::dataOrNull(pipelineLibraries)                   //  VkPipeline* pLibraries;
3759     };
3760     const VkRayTracingPipelineInterfaceCreateInfoKHR pipelineInterfaceCreateInfo = {
3761         VK_STRUCTURE_TYPE_RAY_TRACING_PIPELINE_INTERFACE_CREATE_INFO_KHR, //  VkStructureType sType;
3762         DE_NULL,                                                          //  const void* pNext;
3763         m_maxPayloadSize,                                                 //  uint32_t maxPayloadSize;
3764         m_maxAttributeSize                                                //  uint32_t maxAttributeSize;
3765     };
3766     const bool addPipelineInterfaceCreateInfo = m_maxPayloadSize != 0 || m_maxAttributeSize != 0;
3767     const VkRayTracingPipelineInterfaceCreateInfoKHR *pipelineInterfaceCreateInfoPtr =
3768         addPipelineInterfaceCreateInfo ? &pipelineInterfaceCreateInfo : DE_NULL;
3769     const VkPipelineLibraryCreateInfoKHR *librariesCreateInfoPtr =
3770         (pipelineLibraries.empty() ? nullptr : &librariesCreateInfo);
3771 
3772     Move<VkDeferredOperationKHR> deferredOperation;
3773     if (m_deferredOperation)
3774         deferredOperation = createDeferredOperationKHR(vk, device);
3775 
3776     VkPipelineDynamicStateCreateInfo dynamicStateCreateInfo = {
3777         VK_STRUCTURE_TYPE_PIPELINE_DYNAMIC_STATE_CREATE_INFO, // VkStructureType sType;
3778         DE_NULL,                                              // const void* pNext;
3779         0,                                                    // VkPipelineDynamicStateCreateFlags flags;
3780         static_cast<uint32_t>(m_dynamicStates.size()),        // uint32_t dynamicStateCount;
3781         m_dynamicStates.data(),                               // const VkDynamicState* pDynamicStates;
3782     };
3783 
3784     VkRayTracingPipelineCreateInfoKHR pipelineCreateInfo{
3785         VK_STRUCTURE_TYPE_RAY_TRACING_PIPELINE_CREATE_INFO_KHR, //  VkStructureType sType;
3786         DE_NULL,                                                //  const void* pNext;
3787         m_pipelineCreateFlags,                                  //  VkPipelineCreateFlags flags;
3788         de::sizeU32(m_shaderCreateInfos),                       //  uint32_t stageCount;
3789         de::dataOrNull(m_shaderCreateInfos),                    //  const VkPipelineShaderStageCreateInfo* pStages;
3790         de::sizeU32(m_shadersGroupCreateInfos),                 //  uint32_t groupCount;
3791         de::dataOrNull(m_shadersGroupCreateInfos),              //  const VkRayTracingShaderGroupCreateInfoKHR* pGroups;
3792         m_maxRecursionDepth,                                    //  uint32_t maxRecursionDepth;
3793         librariesCreateInfoPtr,                                 //  VkPipelineLibraryCreateInfoKHR* pLibraryInfo;
3794         pipelineInterfaceCreateInfoPtr, //  VkRayTracingPipelineInterfaceCreateInfoKHR* pLibraryInterface;
3795         &dynamicStateCreateInfo,        //  const VkPipelineDynamicStateCreateInfo* pDynamicState;
3796         pipelineLayout,                 //  VkPipelineLayout layout;
3797         (VkPipeline)DE_NULL,            //  VkPipeline basePipelineHandle;
3798         0,                              //  int32_t basePipelineIndex;
3799     };
3800     VkPipeline object = DE_NULL;
3801     VkResult result   = vk.createRayTracingPipelinesKHR(device, deferredOperation.get(), pipelineCache, 1u,
3802                                                         &pipelineCreateInfo, DE_NULL, &object);
3803     const bool allowCompileRequired =
3804         ((m_pipelineCreateFlags & VK_PIPELINE_CREATE_FAIL_ON_PIPELINE_COMPILE_REQUIRED_BIT_EXT) != 0);
3805 
3806     VkPipelineCreateFlags2CreateInfoKHR pipelineFlags2CreateInfo = initVulkanStructure();
3807     if (m_pipelineCreateFlags2)
3808     {
3809         pipelineFlags2CreateInfo.flags = m_pipelineCreateFlags2;
3810         pipelineCreateInfo.pNext       = &pipelineFlags2CreateInfo;
3811         pipelineCreateInfo.flags       = 0;
3812     }
3813 
3814     if (m_deferredOperation)
3815     {
3816         DE_ASSERT(result == VK_OPERATION_DEFERRED_KHR || result == VK_OPERATION_NOT_DEFERRED_KHR ||
3817                   result == VK_SUCCESS || (allowCompileRequired && result == VK_PIPELINE_COMPILE_REQUIRED));
3818         finishDeferredOperation(vk, device, deferredOperation.get(), m_workerThreadCount,
3819                                 result == VK_OPERATION_NOT_DEFERRED_KHR);
3820     }
3821 
3822     if (allowCompileRequired && result == VK_PIPELINE_COMPILE_REQUIRED)
3823         throw CompileRequiredError("createRayTracingPipelinesKHR returned VK_PIPELINE_COMPILE_REQUIRED");
3824 
3825     Move<VkPipeline> pipeline(check<VkPipeline>(object), Deleter<VkPipeline>(vk, device, DE_NULL));
3826     return pipeline;
3827 }
3828 
createPipeline(const DeviceInterface & vk,const VkDevice device,const VkPipelineLayout pipelineLayout,const std::vector<de::SharedPtr<Move<VkPipeline>>> & pipelineLibraries)3829 Move<VkPipeline> RayTracingPipeline::createPipeline(
3830     const DeviceInterface &vk, const VkDevice device, const VkPipelineLayout pipelineLayout,
3831     const std::vector<de::SharedPtr<Move<VkPipeline>>> &pipelineLibraries)
3832 {
3833     std::vector<VkPipeline> rawPipelines;
3834     rawPipelines.reserve(pipelineLibraries.size());
3835     for (const auto &lib : pipelineLibraries)
3836         rawPipelines.push_back(lib.get()->get());
3837 
3838     return createPipelineKHR(vk, device, pipelineLayout, rawPipelines);
3839 }
3840 
createPipeline(const DeviceInterface & vk,const VkDevice device,const VkPipelineLayout pipelineLayout,const std::vector<VkPipeline> & pipelineLibraries,const VkPipelineCache pipelineCache)3841 Move<VkPipeline> RayTracingPipeline::createPipeline(const DeviceInterface &vk, const VkDevice device,
3842                                                     const VkPipelineLayout pipelineLayout,
3843                                                     const std::vector<VkPipeline> &pipelineLibraries,
3844                                                     const VkPipelineCache pipelineCache)
3845 {
3846     return createPipelineKHR(vk, device, pipelineLayout, pipelineLibraries, pipelineCache);
3847 }
3848 
// Builds this pipeline together with all of its (recursively created) child
// library pipelines. The returned vector starts with the pipeline built from
// this object, followed by every library pipeline it links, so callers keep
// all of them alive for the lifetime of the main pipeline.
std::vector<de::SharedPtr<Move<VkPipeline>>> RayTracingPipeline::createPipelineWithLibraries(
    const DeviceInterface &vk, const VkDevice device, const VkPipelineLayout pipelineLayout)
{
    for (size_t groupNdx = 0; groupNdx < m_shadersGroupCreateInfos.size(); ++groupNdx)
        DE_ASSERT(m_shadersGroupCreateInfos[groupNdx].sType ==
                  VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR);

    DE_ASSERT(m_shaderCreateInfos.size() > 0);
    DE_ASSERT(m_shadersGroupCreateInfos.size() > 0);

    // 'firstLibraries' collects the direct children (linked into this
    // pipeline); 'allLibraries' additionally carries every transitive library
    // so the full set can be returned.
    std::vector<de::SharedPtr<Move<VkPipeline>>> result, allLibraries, firstLibraries;
    for (auto it = begin(m_pipelineLibraries), eit = end(m_pipelineLibraries); it != eit; ++it)
    {
        auto childLibraries = (*it)->get()->createPipelineWithLibraries(vk, device, pipelineLayout);
        DE_ASSERT(childLibraries.size() > 0);
        // childLibraries[0] is the child's own pipeline; the rest are its libraries.
        firstLibraries.push_back(childLibraries[0]);
        std::copy(begin(childLibraries), end(childLibraries), std::back_inserter(allLibraries));
    }
    result.push_back(makeVkSharedPtr(createPipeline(vk, device, pipelineLayout, firstLibraries)));
    std::copy(begin(allLibraries), end(allLibraries), std::back_inserter(result));
    return result;
}
3871 
getShaderGroupHandles(const DeviceInterface & vk,const VkDevice device,const VkPipeline pipeline,const uint32_t shaderGroupHandleSize,const uint32_t firstGroup,const uint32_t groupCount) const3872 std::vector<uint8_t> RayTracingPipeline::getShaderGroupHandles(const DeviceInterface &vk, const VkDevice device,
3873                                                                const VkPipeline pipeline,
3874                                                                const uint32_t shaderGroupHandleSize,
3875                                                                const uint32_t firstGroup,
3876                                                                const uint32_t groupCount) const
3877 {
3878     const auto handleArraySizeBytes = groupCount * shaderGroupHandleSize;
3879     std::vector<uint8_t> shaderHandles(handleArraySizeBytes);
3880 
3881     VK_CHECK(getRayTracingShaderGroupHandles(vk, device, pipeline, firstGroup, groupCount,
3882                                              static_cast<uintptr_t>(shaderHandles.size()),
3883                                              de::dataOrNull(shaderHandles)));
3884 
3885     return shaderHandles;
3886 }
3887 
getShaderGroupReplayHandles(const DeviceInterface & vk,const VkDevice device,const VkPipeline pipeline,const uint32_t shaderGroupHandleReplaySize,const uint32_t firstGroup,const uint32_t groupCount) const3888 std::vector<uint8_t> RayTracingPipeline::getShaderGroupReplayHandles(const DeviceInterface &vk, const VkDevice device,
3889                                                                      const VkPipeline pipeline,
3890                                                                      const uint32_t shaderGroupHandleReplaySize,
3891                                                                      const uint32_t firstGroup,
3892                                                                      const uint32_t groupCount) const
3893 {
3894     const auto handleArraySizeBytes = groupCount * shaderGroupHandleReplaySize;
3895     std::vector<uint8_t> shaderHandles(handleArraySizeBytes);
3896 
3897     VK_CHECK(getRayTracingCaptureReplayShaderGroupHandles(vk, device, pipeline, firstGroup, groupCount,
3898                                                           static_cast<uintptr_t>(shaderHandles.size()),
3899                                                           de::dataOrNull(shaderHandles)));
3900 
3901     return shaderHandles;
3902 }
3903 
// Convenience overload: queries the shader group handles for groups
// [firstGroup, firstGroup + groupCount) from the given pipeline and forwards
// to the handle-based overload, which builds and fills the SBT buffer.
de::MovePtr<BufferWithMemory> RayTracingPipeline::createShaderBindingTable(
    const DeviceInterface &vk, const VkDevice device, const VkPipeline pipeline, Allocator &allocator,
    const uint32_t &shaderGroupHandleSize, const uint32_t shaderGroupBaseAlignment, const uint32_t &firstGroup,
    const uint32_t &groupCount, const VkBufferCreateFlags &additionalBufferCreateFlags,
    const VkBufferUsageFlags &additionalBufferUsageFlags, const MemoryRequirement &additionalMemoryRequirement,
    const VkDeviceAddress &opaqueCaptureAddress, const uint32_t shaderBindingTableOffset,
    const uint32_t shaderRecordSize, const void **shaderGroupDataPtrPerGroup, const bool autoAlignRecords)
{
    // groupCount handles of shaderGroupHandleSize bytes each, packed back to back.
    const auto shaderHandles =
        getShaderGroupHandles(vk, device, pipeline, shaderGroupHandleSize, firstGroup, groupCount);
    return createShaderBindingTable(vk, device, allocator, shaderGroupHandleSize, shaderGroupBaseAlignment,
                                    shaderHandles, additionalBufferCreateFlags, additionalBufferUsageFlags,
                                    additionalMemoryRequirement, opaqueCaptureAddress, shaderBindingTableOffset,
                                    shaderRecordSize, shaderGroupDataPtrPerGroup, autoAlignRecords);
}
3919 
createShaderBindingTable(const DeviceInterface & vk,const VkDevice device,Allocator & allocator,const uint32_t shaderGroupHandleSize,const uint32_t shaderGroupBaseAlignment,const std::vector<uint8_t> & shaderHandles,const VkBufferCreateFlags additionalBufferCreateFlags,const VkBufferUsageFlags additionalBufferUsageFlags,const MemoryRequirement & additionalMemoryRequirement,const VkDeviceAddress opaqueCaptureAddress,const uint32_t shaderBindingTableOffset,const uint32_t shaderRecordSize,const void ** shaderGroupDataPtrPerGroup,const bool autoAlignRecords)3920 de::MovePtr<BufferWithMemory> RayTracingPipeline::createShaderBindingTable(
3921     const DeviceInterface &vk, const VkDevice device, Allocator &allocator, const uint32_t shaderGroupHandleSize,
3922     const uint32_t shaderGroupBaseAlignment, const std::vector<uint8_t> &shaderHandles,
3923     const VkBufferCreateFlags additionalBufferCreateFlags, const VkBufferUsageFlags additionalBufferUsageFlags,
3924     const MemoryRequirement &additionalMemoryRequirement, const VkDeviceAddress opaqueCaptureAddress,
3925     const uint32_t shaderBindingTableOffset, const uint32_t shaderRecordSize, const void **shaderGroupDataPtrPerGroup,
3926     const bool autoAlignRecords)
3927 {
3928     DE_ASSERT(shaderGroupBaseAlignment != 0u);
3929     DE_ASSERT((shaderBindingTableOffset % shaderGroupBaseAlignment) == 0);
3930     DE_UNREF(shaderGroupBaseAlignment);
3931 
3932     const auto groupCount = de::sizeU32(shaderHandles) / shaderGroupHandleSize;
3933     const auto totalEntrySize =
3934         (autoAlignRecords ? (deAlign32(shaderGroupHandleSize + shaderRecordSize, shaderGroupHandleSize)) :
3935                             (shaderGroupHandleSize + shaderRecordSize));
3936     const uint32_t sbtSize            = shaderBindingTableOffset + groupCount * totalEntrySize;
3937     const VkBufferUsageFlags sbtFlags = VK_BUFFER_USAGE_TRANSFER_DST_BIT |
3938                                         VK_BUFFER_USAGE_SHADER_BINDING_TABLE_BIT_KHR |
3939                                         VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT | additionalBufferUsageFlags;
3940     VkBufferCreateInfo sbtCreateInfo = makeBufferCreateInfo(sbtSize, sbtFlags);
3941     sbtCreateInfo.flags |= additionalBufferCreateFlags;
3942     VkBufferUsageFlags2CreateInfoKHR bufferUsageFlags2           = vk::initVulkanStructure();
3943     VkBufferOpaqueCaptureAddressCreateInfo sbtCaptureAddressInfo = {
3944         VK_STRUCTURE_TYPE_BUFFER_OPAQUE_CAPTURE_ADDRESS_CREATE_INFO, // VkStructureType sType;
3945         DE_NULL,                                                     // const void* pNext;
3946         uint64_t(opaqueCaptureAddress)                               // uint64_t opaqueCaptureAddress;
3947     };
3948 
3949     // when maintenance5 is tested then m_pipelineCreateFlags2 is non-zero
3950     if (m_pipelineCreateFlags2)
3951     {
3952         bufferUsageFlags2.usage = (VkBufferUsageFlags2KHR)sbtFlags;
3953         sbtCreateInfo.pNext     = &bufferUsageFlags2;
3954         sbtCreateInfo.usage     = 0;
3955     }
3956 
3957     if (opaqueCaptureAddress != 0u)
3958     {
3959         sbtCreateInfo.pNext = &sbtCaptureAddressInfo;
3960         sbtCreateInfo.flags |= VK_BUFFER_CREATE_DEVICE_ADDRESS_CAPTURE_REPLAY_BIT;
3961     }
3962     const MemoryRequirement sbtMemRequirements = MemoryRequirement::HostVisible | MemoryRequirement::Coherent |
3963                                                  MemoryRequirement::DeviceAddress | additionalMemoryRequirement;
3964     de::MovePtr<BufferWithMemory> sbtBuffer =
3965         de::MovePtr<BufferWithMemory>(new BufferWithMemory(vk, device, allocator, sbtCreateInfo, sbtMemRequirements));
3966     vk::Allocation &sbtAlloc = sbtBuffer->getAllocation();
3967 
3968     // Copy handles to table, leaving space for ShaderRecordKHR after each handle.
3969     uint8_t *shaderBegin = (uint8_t *)sbtAlloc.getHostPtr() + shaderBindingTableOffset;
3970     for (uint32_t idx = 0; idx < groupCount; ++idx)
3971     {
3972         const uint8_t *shaderSrcPos = shaderHandles.data() + idx * shaderGroupHandleSize;
3973         uint8_t *shaderDstPos       = shaderBegin + idx * totalEntrySize;
3974         deMemcpy(shaderDstPos, shaderSrcPos, shaderGroupHandleSize);
3975 
3976         if (shaderGroupDataPtrPerGroup != nullptr && shaderGroupDataPtrPerGroup[idx] != nullptr)
3977         {
3978             DE_ASSERT(sbtSize >= static_cast<uint32_t>(shaderDstPos - shaderBegin) + shaderGroupHandleSize);
3979 
3980             deMemcpy(shaderDstPos + shaderGroupHandleSize, shaderGroupDataPtrPerGroup[idx], shaderRecordSize);
3981         }
3982     }
3983 
3984     flushMappedMemoryRange(vk, device, sbtAlloc.getMemory(), sbtAlloc.getOffset(), VK_WHOLE_SIZE);
3985 
3986     return sbtBuffer;
3987 }
3988 
// Sets the VkPipelineCreateFlags to be used when creating the ray tracing pipeline.
void RayTracingPipeline::setCreateFlags(const VkPipelineCreateFlags &pipelineCreateFlags)
{
    m_pipelineCreateFlags = pipelineCreateFlags;
}
3993 
// Sets the 64-bit VkPipelineCreateFlags2KHR variant; a non-zero value signals
// maintenance5 usage elsewhere in this class (e.g. SBT buffer creation).
void RayTracingPipeline::setCreateFlags2(const VkPipelineCreateFlags2KHR &pipelineCreateFlags2)
{
    m_pipelineCreateFlags2 = pipelineCreateFlags2;
}
3998 
// Sets the maximum ray recursion depth requested for the pipeline.
void RayTracingPipeline::setMaxRecursionDepth(const uint32_t &maxRecursionDepth)
{
    m_maxRecursionDepth = maxRecursionDepth;
}
4003 
// Sets the maximum ray payload size (bytes) requested for the pipeline.
void RayTracingPipeline::setMaxPayloadSize(const uint32_t &maxPayloadSize)
{
    m_maxPayloadSize = maxPayloadSize;
}
4008 
// Sets the maximum hit attribute size (bytes) requested for the pipeline.
void RayTracingPipeline::setMaxAttributeSize(const uint32_t &maxAttributeSize)
{
    m_maxAttributeSize = maxAttributeSize;
}
4013 
// Enables/disables deferred (host-side) pipeline creation and sets how many
// worker threads should be used for it.
void RayTracingPipeline::setDeferredOperation(const bool deferredOperation, const uint32_t workerThreadCount)
{
    m_deferredOperation = deferredOperation;
    m_workerThreadCount = workerThreadCount;
}
4019 
// Appends a dynamic state to the list used when creating the pipeline.
void RayTracingPipeline::addDynamicState(const VkDynamicState &dynamicState)
{
    m_dynamicStates.push_back(dynamicState);
}
4024 
// KHR-extension-backed implementation of the RayTracingProperties interface.
// Caches the acceleration structure and ray tracing pipeline property structs
// (plus the device's maxMemoryAllocationCount limit) at construction time and
// serves all getters from those cached copies.
class RayTracingPropertiesKHR : public RayTracingProperties
{
public:
    RayTracingPropertiesKHR() = delete;
    RayTracingPropertiesKHR(const InstanceInterface &vki, const VkPhysicalDevice physicalDevice);
    virtual ~RayTracingPropertiesKHR();

    // --- VkPhysicalDeviceRayTracingPipelinePropertiesKHR accessors ---
    uint32_t getShaderGroupHandleSize(void) override
    {
        return m_rayTracingPipelineProperties.shaderGroupHandleSize;
    }
    uint32_t getShaderGroupHandleAlignment(void) override
    {
        return m_rayTracingPipelineProperties.shaderGroupHandleAlignment;
    }
    uint32_t getShaderGroupHandleCaptureReplaySize(void) override
    {
        return m_rayTracingPipelineProperties.shaderGroupHandleCaptureReplaySize;
    }
    uint32_t getMaxRecursionDepth(void) override
    {
        return m_rayTracingPipelineProperties.maxRayRecursionDepth;
    }
    uint32_t getMaxShaderGroupStride(void) override
    {
        return m_rayTracingPipelineProperties.maxShaderGroupStride;
    }
    uint32_t getShaderGroupBaseAlignment(void) override
    {
        return m_rayTracingPipelineProperties.shaderGroupBaseAlignment;
    }
    // --- VkPhysicalDeviceAccelerationStructurePropertiesKHR accessors ---
    uint64_t getMaxGeometryCount(void) override
    {
        return m_accelerationStructureProperties.maxGeometryCount;
    }
    uint64_t getMaxInstanceCount(void) override
    {
        return m_accelerationStructureProperties.maxInstanceCount;
    }
    uint64_t getMaxPrimitiveCount(void) override
    {
        return m_accelerationStructureProperties.maxPrimitiveCount;
    }
    uint32_t getMaxDescriptorSetAccelerationStructures(void) override
    {
        return m_accelerationStructureProperties.maxDescriptorSetAccelerationStructures;
    }
    uint32_t getMaxRayDispatchInvocationCount(void) override
    {
        return m_rayTracingPipelineProperties.maxRayDispatchInvocationCount;
    }
    uint32_t getMaxRayHitAttributeSize(void) override
    {
        return m_rayTracingPipelineProperties.maxRayHitAttributeSize;
    }
    // Core device limit cached alongside the extension properties.
    uint32_t getMaxMemoryAllocationCount(void) override
    {
        return m_maxMemoryAllocationCount;
    }

protected:
    VkPhysicalDeviceAccelerationStructurePropertiesKHR m_accelerationStructureProperties;
    VkPhysicalDeviceRayTracingPipelinePropertiesKHR m_rayTracingPipelineProperties;
    uint32_t m_maxMemoryAllocationCount;
};
4090 
~RayTracingPropertiesKHR()4091 RayTracingPropertiesKHR::~RayTracingPropertiesKHR()
4092 {
4093 }
4094 
// Queries and caches the ray tracing related properties for the given physical
// device. getPhysicalDeviceExtensionProperties selects which extension struct
// to fill based on the assignment target's type (presumably via a conversion
// wrapper around vkGetPhysicalDeviceProperties2 — confirm in vkQueryUtil).
RayTracingPropertiesKHR::RayTracingPropertiesKHR(const InstanceInterface &vki, const VkPhysicalDevice physicalDevice)
    : RayTracingProperties(vki, physicalDevice)
{
    m_accelerationStructureProperties = getPhysicalDeviceExtensionProperties(vki, physicalDevice);
    m_rayTracingPipelineProperties    = getPhysicalDeviceExtensionProperties(vki, physicalDevice);
    m_maxMemoryAllocationCount = getPhysicalDeviceProperties(vki, physicalDevice).limits.maxMemoryAllocationCount;
}
4102 
makeRayTracingProperties(const InstanceInterface & vki,const VkPhysicalDevice physicalDevice)4103 de::MovePtr<RayTracingProperties> makeRayTracingProperties(const InstanceInterface &vki,
4104                                                            const VkPhysicalDevice physicalDevice)
4105 {
4106     return de::MovePtr<RayTracingProperties>(new RayTracingPropertiesKHR(vki, physicalDevice));
4107 }
4108 
cmdTraceRaysKHR(const DeviceInterface & vk,VkCommandBuffer commandBuffer,const VkStridedDeviceAddressRegionKHR * raygenShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * missShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * hitShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * callableShaderBindingTableRegion,uint32_t width,uint32_t height,uint32_t depth)4109 static inline void cmdTraceRaysKHR(const DeviceInterface &vk, VkCommandBuffer commandBuffer,
4110                                    const VkStridedDeviceAddressRegionKHR *raygenShaderBindingTableRegion,
4111                                    const VkStridedDeviceAddressRegionKHR *missShaderBindingTableRegion,
4112                                    const VkStridedDeviceAddressRegionKHR *hitShaderBindingTableRegion,
4113                                    const VkStridedDeviceAddressRegionKHR *callableShaderBindingTableRegion,
4114                                    uint32_t width, uint32_t height, uint32_t depth)
4115 {
4116     return vk.cmdTraceRaysKHR(commandBuffer, raygenShaderBindingTableRegion, missShaderBindingTableRegion,
4117                               hitShaderBindingTableRegion, callableShaderBindingTableRegion, width, height, depth);
4118 }
4119 
cmdTraceRays(const DeviceInterface & vk,VkCommandBuffer commandBuffer,const VkStridedDeviceAddressRegionKHR * raygenShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * missShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * hitShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * callableShaderBindingTableRegion,uint32_t width,uint32_t height,uint32_t depth)4120 void cmdTraceRays(const DeviceInterface &vk, VkCommandBuffer commandBuffer,
4121                   const VkStridedDeviceAddressRegionKHR *raygenShaderBindingTableRegion,
4122                   const VkStridedDeviceAddressRegionKHR *missShaderBindingTableRegion,
4123                   const VkStridedDeviceAddressRegionKHR *hitShaderBindingTableRegion,
4124                   const VkStridedDeviceAddressRegionKHR *callableShaderBindingTableRegion, uint32_t width,
4125                   uint32_t height, uint32_t depth)
4126 {
4127     DE_ASSERT(raygenShaderBindingTableRegion != DE_NULL);
4128     DE_ASSERT(missShaderBindingTableRegion != DE_NULL);
4129     DE_ASSERT(hitShaderBindingTableRegion != DE_NULL);
4130     DE_ASSERT(callableShaderBindingTableRegion != DE_NULL);
4131 
4132     return cmdTraceRaysKHR(vk, commandBuffer, raygenShaderBindingTableRegion, missShaderBindingTableRegion,
4133                            hitShaderBindingTableRegion, callableShaderBindingTableRegion, width, height, depth);
4134 }
4135 
cmdTraceRaysIndirectKHR(const DeviceInterface & vk,VkCommandBuffer commandBuffer,const VkStridedDeviceAddressRegionKHR * raygenShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * missShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * hitShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * callableShaderBindingTableRegion,VkDeviceAddress indirectDeviceAddress)4136 static inline void cmdTraceRaysIndirectKHR(const DeviceInterface &vk, VkCommandBuffer commandBuffer,
4137                                            const VkStridedDeviceAddressRegionKHR *raygenShaderBindingTableRegion,
4138                                            const VkStridedDeviceAddressRegionKHR *missShaderBindingTableRegion,
4139                                            const VkStridedDeviceAddressRegionKHR *hitShaderBindingTableRegion,
4140                                            const VkStridedDeviceAddressRegionKHR *callableShaderBindingTableRegion,
4141                                            VkDeviceAddress indirectDeviceAddress)
4142 {
4143     DE_ASSERT(raygenShaderBindingTableRegion != DE_NULL);
4144     DE_ASSERT(missShaderBindingTableRegion != DE_NULL);
4145     DE_ASSERT(hitShaderBindingTableRegion != DE_NULL);
4146     DE_ASSERT(callableShaderBindingTableRegion != DE_NULL);
4147     DE_ASSERT(indirectDeviceAddress != 0);
4148 
4149     return vk.cmdTraceRaysIndirectKHR(commandBuffer, raygenShaderBindingTableRegion, missShaderBindingTableRegion,
4150                                       hitShaderBindingTableRegion, callableShaderBindingTableRegion,
4151                                       indirectDeviceAddress);
4152 }
4153 
cmdTraceRaysIndirect(const DeviceInterface & vk,VkCommandBuffer commandBuffer,const VkStridedDeviceAddressRegionKHR * raygenShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * missShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * hitShaderBindingTableRegion,const VkStridedDeviceAddressRegionKHR * callableShaderBindingTableRegion,VkDeviceAddress indirectDeviceAddress)4154 void cmdTraceRaysIndirect(const DeviceInterface &vk, VkCommandBuffer commandBuffer,
4155                           const VkStridedDeviceAddressRegionKHR *raygenShaderBindingTableRegion,
4156                           const VkStridedDeviceAddressRegionKHR *missShaderBindingTableRegion,
4157                           const VkStridedDeviceAddressRegionKHR *hitShaderBindingTableRegion,
4158                           const VkStridedDeviceAddressRegionKHR *callableShaderBindingTableRegion,
4159                           VkDeviceAddress indirectDeviceAddress)
4160 {
4161     return cmdTraceRaysIndirectKHR(vk, commandBuffer, raygenShaderBindingTableRegion, missShaderBindingTableRegion,
4162                                    hitShaderBindingTableRegion, callableShaderBindingTableRegion,
4163                                    indirectDeviceAddress);
4164 }
4165 
cmdTraceRaysIndirect2KHR(const DeviceInterface & vk,VkCommandBuffer commandBuffer,VkDeviceAddress indirectDeviceAddress)4166 static inline void cmdTraceRaysIndirect2KHR(const DeviceInterface &vk, VkCommandBuffer commandBuffer,
4167                                             VkDeviceAddress indirectDeviceAddress)
4168 {
4169     DE_ASSERT(indirectDeviceAddress != 0);
4170 
4171     return vk.cmdTraceRaysIndirect2KHR(commandBuffer, indirectDeviceAddress);
4172 }
4173 
cmdTraceRaysIndirect2(const DeviceInterface & vk,VkCommandBuffer commandBuffer,VkDeviceAddress indirectDeviceAddress)4174 void cmdTraceRaysIndirect2(const DeviceInterface &vk, VkCommandBuffer commandBuffer,
4175                            VkDeviceAddress indirectDeviceAddress)
4176 {
4177     return cmdTraceRaysIndirect2KHR(vk, commandBuffer, indirectDeviceAddress);
4178 }
4179 
// Out-of-range marker value borrowed from the SPIR-V headers' "Max" enumerant;
// presumably used to mean "no intersection value" — confirm at the use sites.
constexpr uint32_t NO_INT_VALUE = spv::RayQueryCommittedIntersectionTypeMax;
4181 
generateRayQueryShaders(SourceCollections & programCollection,RayQueryTestParams params,std::string rayQueryPart,float max_t)4182 void generateRayQueryShaders(SourceCollections &programCollection, RayQueryTestParams params, std::string rayQueryPart,
4183                              float max_t)
4184 {
4185     std::stringstream genericMiss;
4186     genericMiss << "#version 460\n"
4187                    "#extension GL_EXT_ray_tracing : require\n"
4188                    "#extension GL_EXT_ray_query : require\n"
4189                    "layout(location = 0) rayPayloadInEXT vec4 payload;\n"
4190                    "void main()\n"
4191                    "{\n"
4192                    "  payload.x = 2000;\n"
4193                    "  payload.y = 2000;\n"
4194                    "  payload.z = 2000;\n"
4195                    "  payload.w = 2000;\n"
4196                    "}\n";
4197 
4198     std::stringstream genericIsect;
4199     genericIsect << "#version 460\n"
4200                     "#extension GL_EXT_ray_tracing : require\n"
4201                     "hitAttributeEXT uvec4 hitValue;\n"
4202                     "void main()\n"
4203                     "{\n"
4204                     "  reportIntersectionEXT(0.5f, 0);\n"
4205                     "}\n";
4206 
4207     std::stringstream rtChit;
4208     rtChit << "#version 460    \n"
4209               "#extension GL_EXT_ray_tracing : require\n"
4210               "#extension GL_EXT_ray_query : require\n"
4211               "layout(location = 0) rayPayloadInEXT vec4 payload;\n"
4212               "void main()\n"
4213               "{\n"
4214               "  uint index = (gl_LaunchIDEXT.z * gl_LaunchSizeEXT.x * gl_LaunchSizeEXT.y) + (gl_LaunchIDEXT.y * "
4215               "gl_LaunchSizeEXT.x) + gl_LaunchIDEXT.x;\n"
4216               "  payload.x = index;\n"
4217               "  payload.y = gl_HitTEXT;\n"
4218               "  payload.z = 1000;\n"
4219               "  payload.w = 1000;\n"
4220               "}\n";
4221 
4222     std::stringstream genericChit;
4223     genericChit << "#version 460    \n"
4224                    "#extension GL_EXT_ray_tracing : require\n"
4225                    "#extension GL_EXT_ray_query : require\n"
4226                    "layout(location = 0) rayPayloadInEXT vec4 payload;\n"
4227                    "void main()\n"
4228                    "{\n"
4229                    "  payload.x = 1000;\n"
4230                    "  payload.y = 1000;\n"
4231                    "  payload.z = 1000;\n"
4232                    "  payload.w = 1000;\n"
4233                    "}\n";
4234 
4235     std::stringstream genericRayTracingSetResultsShader;
4236     genericRayTracingSetResultsShader << "#version 460    \n"
4237                                          "#extension GL_EXT_ray_tracing : require\n"
4238                                          "#extension GL_EXT_ray_query : require\n"
4239                                          "layout(location = 0) rayPayloadInEXT vec4 payload;\n"
4240                                          "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4241                                          "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4242                                          "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4243                                       << params.shaderFunctions
4244                                       << "void main()\n"
4245                                          "{\n"
4246                                          "  uint index = (gl_LaunchIDEXT.z * gl_LaunchSizeEXT.x * gl_LaunchSizeEXT.y) "
4247                                          "+ (gl_LaunchIDEXT.y * gl_LaunchSizeEXT.x) + gl_LaunchIDEXT.x;\n"
4248                                       << rayQueryPart
4249                                       << "  payload.x = x;\n"
4250                                          "  payload.y = y;\n"
4251                                          "  payload.z = z;\n"
4252                                          "  payload.w = w;\n"
4253                                          "}\n";
4254 
4255     const vk::ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_5, 0u, true);
4256 
4257     switch (params.pipelineType)
4258     {
4259     case RayQueryShaderSourcePipeline::COMPUTE:
4260     {
4261         std::ostringstream compute;
4262         compute << "#version 460\n"
4263                    "#extension GL_EXT_ray_tracing : enable\n"
4264                    "#extension GL_EXT_ray_query : require\n"
4265                    "\n"
4266                    "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4267                    "struct ResultType { float x; float y; float z; float w; };\n"
4268                    "layout(std430, set = 0, binding = 0) buffer Results { ResultType results[]; };\n"
4269                    "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4270                    "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4271                    "layout (local_size_x = 1, local_size_y = 1, local_size_z = 1) in;\n"
4272                 << params.shaderFunctions
4273                 << "void main() {\n"
4274                    "   uint index = (gl_NumWorkGroups.x * gl_WorkGroupSize.x) * gl_GlobalInvocationID.y + "
4275                    "gl_GlobalInvocationID.x;\n"
4276                 << rayQueryPart
4277                 << "   results[index].x = x;\n"
4278                    "   results[index].y = y;\n"
4279                    "   results[index].z = z;\n"
4280                    "   results[index].w = w;\n"
4281                    "}";
4282 
4283         programCollection.glslSources.add("comp", &buildOptions) << glu::ComputeSource(compute.str());
4284 
4285         break;
4286     }
4287     case RayQueryShaderSourcePipeline::GRAPHICS:
4288     {
4289         std::ostringstream vertex;
4290 
4291         if (params.shaderSourceType == RayQueryShaderSourceType::VERTEX)
4292         {
4293             vertex << "#version 460\n"
4294                       "#extension GL_EXT_ray_tracing : enable\n"
4295                       "#extension GL_EXT_ray_query : require\n"
4296                       "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4297                       "layout(location = 0) in vec4 in_position;\n"
4298                       "layout(rgba32f, set = 0, binding = 0) uniform image3D resultImage;\n"
4299                       "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4300                       "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4301                    << params.shaderFunctions
4302                    << "void main(void)\n"
4303                       "{\n"
4304                       "  const int  vertId = int(gl_VertexIndex % 3);\n"
4305                       "  if (vertId == 0)\n"
4306                       "  {\n"
4307                       "    ivec3 sz = imageSize(resultImage);\n"
4308                       "    int index = int(in_position.z);\n"
4309                       "    int idx = int(index % sz.x);\n"
4310                       "    int idy = int(index / sz.y);\n"
4311                    << rayQueryPart
4312                    << "     imageStore(resultImage, ivec3(idx, idy, 0), vec4(x, y, z, w));\n"
4313                       "  }\n"
4314                       "}\n";
4315         }
4316         else
4317         {
4318             vertex << "#version 460\n"
4319                       "layout(location = 0) in highp vec3 position;\n"
4320                       "\n"
4321                       "out gl_PerVertex {\n"
4322                       "   vec4 gl_Position;\n"
4323                       "};\n"
4324                       "\n"
4325                       "void main (void)\n"
4326                       "{\n"
4327                       "    gl_Position = vec4(position, 1.0);\n"
4328                       "}\n";
4329         }
4330 
4331         programCollection.glslSources.add("vert", &buildOptions) << glu::VertexSource(vertex.str());
4332 
4333         if (params.shaderSourceType == RayQueryShaderSourceType::FRAGMENT)
4334         {
4335             std::ostringstream frag;
4336             frag << "#version 460\n"
4337                     "#extension GL_EXT_ray_tracing : enable\n"
4338                     "#extension GL_EXT_ray_query : require\n"
4339                     "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4340                     "layout(rgba32f, set = 0, binding = 0) uniform image3D resultImage;\n"
4341                     "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4342                     "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4343                  << params.shaderFunctions
4344                  << "void main() {\n"
4345                     "    ivec3 sz = imageSize(resultImage);\n"
4346                     "    uint index = uint(gl_FragCoord.x) + sz.x * uint(gl_FragCoord.y);\n"
4347                  << rayQueryPart
4348                  << "    imageStore(resultImage, ivec3(gl_FragCoord.xy, 0), vec4(x, y, z, w));\n"
4349                     "}";
4350 
4351             programCollection.glslSources.add("frag", &buildOptions) << glu::FragmentSource(frag.str());
4352         }
4353         else if (params.shaderSourceType == RayQueryShaderSourceType::GEOMETRY)
4354         {
4355             std::stringstream geom;
4356             geom << "#version 460\n"
4357                     "#extension GL_EXT_ray_tracing : enable\n"
4358                     "#extension GL_EXT_ray_query : require\n"
4359                     "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4360                     "layout(triangles) in;\n"
4361                     "layout (triangle_strip, max_vertices = 3) out;\n"
4362                     "layout(rgba32f, set = 0, binding = 0) uniform image3D resultImage;\n"
4363                     "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4364                     "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4365                     "\n"
4366                     "in gl_PerVertex {\n"
4367                     "  vec4  gl_Position;\n"
4368                     "} gl_in[];\n"
4369                     "out gl_PerVertex {\n"
4370                     "  vec4 gl_Position;\n"
4371                     "};\n"
4372                  << params.shaderFunctions
4373                  << "void main (void)\n"
4374                     "{\n"
4375                     "  ivec3 sz = imageSize(resultImage);\n"
4376                     "  int index = int(gl_in[0].gl_Position.z);\n"
4377                     "  int idx = int(index % sz.x);\n"
4378                     "  int idy = int(index / sz.y);\n"
4379                  << rayQueryPart
4380                  << "  imageStore(resultImage, ivec3(idx, idy, 0), vec4(x, y, z, w));\n"
4381                     "  for (int i = 0; i < gl_in.length(); ++i)\n"
4382                     "  {\n"
4383                     "        gl_Position      = gl_in[i].gl_Position;\n"
4384                     "        EmitVertex();\n"
4385                     "  }\n"
4386                     "  EndPrimitive();\n"
4387                     "}\n";
4388 
4389             programCollection.glslSources.add("geom", &buildOptions) << glu::GeometrySource(geom.str());
4390         }
4391         else if (params.shaderSourceType == RayQueryShaderSourceType::TESSELLATION_EVALUATION)
4392         {
4393             {
4394                 std::stringstream tesc;
4395                 tesc << "#version 460\n"
4396                         "#extension GL_EXT_tessellation_shader : require\n"
4397                         "in gl_PerVertex\n"
4398                         "{\n"
4399                         "  vec4 gl_Position;\n"
4400                         "} gl_in[];\n"
4401                         "layout(vertices = 4) out;\n"
4402                         "out gl_PerVertex\n"
4403                         "{\n"
4404                         "  vec4 gl_Position;\n"
4405                         "} gl_out[];\n"
4406                         "\n"
4407                         "void main (void)\n"
4408                         "{\n"
4409                         "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
4410                         "  gl_TessLevelInner[0] = 1;\n"
4411                         "  gl_TessLevelInner[1] = 1;\n"
4412                         "  gl_TessLevelOuter[gl_InvocationID] = 1;\n"
4413                         "}\n";
4414                 programCollection.glslSources.add("tesc", &buildOptions) << glu::TessellationControlSource(tesc.str());
4415             }
4416 
4417             {
4418                 std::ostringstream tese;
4419                 tese << "#version 460\n"
4420                         "#extension GL_EXT_ray_tracing : enable\n"
4421                         "#extension GL_EXT_tessellation_shader : require\n"
4422                         "#extension GL_EXT_ray_query : require\n"
4423                         "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4424                         "layout(rgba32f, set = 0, binding = 0) uniform image3D resultImage;\n"
4425                         "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4426                         "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4427                         "layout(quads, equal_spacing, ccw) in;\n"
4428                         "in gl_PerVertex\n"
4429                         "{\n"
4430                         "  vec4 gl_Position;\n"
4431                         "} gl_in[];\n"
4432                      << params.shaderFunctions
4433                      << "void main(void)\n"
4434                         "{\n"
4435                         "  ivec3 sz = imageSize(resultImage);\n"
4436                         "  int index = int(gl_in[0].gl_Position.z);\n"
4437                         "  int idx = int(index % sz.x);\n"
4438                         "  int idy = int(index / sz.y);\n"
4439                      << rayQueryPart
4440                      << "  imageStore(resultImage, ivec3(idx, idy, 0), vec4(x, y, z, w));\n"
4441                         "  gl_Position = gl_in[0].gl_Position;\n"
4442                         "}\n";
4443 
4444                 programCollection.glslSources.add("tese", &buildOptions)
4445                     << glu::TessellationEvaluationSource(tese.str());
4446             }
4447         }
4448         else if (params.shaderSourceType == RayQueryShaderSourceType::TESSELLATION_CONTROL)
4449         {
4450             {
4451                 std::ostringstream tesc;
4452                 tesc << "#version 460\n"
4453                         "#extension GL_EXT_ray_tracing : enable\n"
4454                         "#extension GL_EXT_tessellation_shader : require\n"
4455                         "#extension GL_EXT_ray_query : require\n"
4456                         "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4457                         "layout(rgba32f, set = 0, binding = 0) uniform image3D resultImage;\n"
4458                         "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4459                         "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4460                         "in gl_PerVertex\n"
4461                         "{\n"
4462                         "  vec4 gl_Position;\n"
4463                         "} gl_in[];\n"
4464                         "layout(vertices = 4) out;\n"
4465                         "out gl_PerVertex\n"
4466                         "{\n"
4467                         "  vec4 gl_Position;\n"
4468                         "} gl_out[];\n"
4469                         "\n"
4470                      << params.shaderFunctions
4471                      << "void main(void)\n"
4472                         "{\n"
4473                         "  ivec3 sz = imageSize(resultImage);\n"
4474                         "  int index = int(gl_in[0].gl_Position.z);\n"
4475                         "  int idx = int(index % sz.x);\n"
4476                         "  int idy = int(index / sz.y);\n"
4477                      << rayQueryPart
4478                      << "  imageStore(resultImage, ivec3(idx, idy, 0), vec4(x, y, z, w));\n"
4479                         "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
4480                         "  gl_TessLevelInner[0] = 1;\n"
4481                         "  gl_TessLevelInner[1] = 1;\n"
4482                         "  gl_TessLevelOuter[gl_InvocationID] = 1;\n"
4483                         "}\n";
4484 
4485                 programCollection.glslSources.add("tesc", &buildOptions) << glu::TessellationControlSource(tesc.str());
4486             }
4487 
4488             {
4489                 std::ostringstream tese;
4490                 tese << "#version 460\n"
4491                         "#extension GL_EXT_tessellation_shader : require\n"
4492                         "layout(quads, equal_spacing, ccw) in;\n"
4493                         "in gl_PerVertex\n"
4494                         "{\n"
4495                         "  vec4 gl_Position;\n"
4496                         "} gl_in[];\n"
4497                         "\n"
4498                         "void main(void)\n"
4499                         "{\n"
4500                         "  gl_Position = gl_in[0].gl_Position;\n"
4501                         "}\n";
4502 
4503                 programCollection.glslSources.add("tese", &buildOptions)
4504                     << glu::TessellationEvaluationSource(tese.str());
4505             }
4506         }
4507 
4508         break;
4509     }
4510     case RayQueryShaderSourcePipeline::RAYTRACING:
4511     {
4512         std::stringstream rayGen;
4513 
4514         if (params.shaderSourceType == RayQueryShaderSourceType::RAY_GENERATION_RT)
4515         {
4516             rayGen << "#version 460\n"
4517                       "#extension GL_EXT_ray_tracing : enable\n"
4518                       "#extension GL_EXT_ray_query : require\n"
4519                       "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4520                       "struct ResultType { float x; float y; float z; float w; };\n"
4521                       "layout(std430, set = 0, binding = 0) buffer Results { ResultType results[]; };\n"
4522                       "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4523                       "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4524                       "layout(location = 0) rayPayloadEXT vec4 payload;\n"
4525                    << params.shaderFunctions
4526                    << "void main() {\n"
4527                       "   payload = vec4("
4528                    << NO_INT_VALUE << "," << max_t * 2
4529                    << ",0,0);\n"
4530                       "   uint index = (gl_LaunchIDEXT.z * gl_LaunchSizeEXT.x * gl_LaunchSizeEXT.y) + "
4531                       "(gl_LaunchIDEXT.y * gl_LaunchSizeEXT.x) + gl_LaunchIDEXT.x;\n"
4532                    << rayQueryPart
4533                    << "   results[index].x = x;\n"
4534                       "   results[index].y = y;\n"
4535                       "   results[index].z = z;\n"
4536                       "   results[index].w = w;\n"
4537                       "}";
4538 
4539             programCollection.glslSources.add("isect_rt", &buildOptions)
4540                 << glu::IntersectionSource(updateRayTracingGLSL(genericIsect.str()));
4541             programCollection.glslSources.add("chit_rt", &buildOptions) << glu::ClosestHitSource(rtChit.str());
4542             programCollection.glslSources.add("ahit_rt", &buildOptions) << glu::AnyHitSource(genericChit.str());
4543             programCollection.glslSources.add("miss_rt", &buildOptions) << glu::MissSource(genericMiss.str());
4544         }
4545         else if (params.shaderSourceType == RayQueryShaderSourceType::RAY_GENERATION)
4546         {
4547             rayGen << "#version 460\n"
4548                       "#extension GL_EXT_ray_tracing : enable\n"
4549                       "#extension GL_EXT_ray_query : require\n"
4550                       "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4551                       "struct ResultType { float x; float y; float z; float w; };\n"
4552                       "layout(std430, set = 0, binding = 0) buffer Results { ResultType results[]; };\n"
4553                       "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4554                       "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4555                    << params.shaderFunctions
4556                    << "void main() {\n"
4557                       "   uint index = (gl_LaunchIDEXT.z * gl_LaunchSizeEXT.x * gl_LaunchSizeEXT.y) + "
4558                       "(gl_LaunchIDEXT.y * gl_LaunchSizeEXT.x) + gl_LaunchIDEXT.x;\n"
4559                    << rayQueryPart
4560                    << "   results[index].x = x;\n"
4561                       "   results[index].y = y;\n"
4562                       "   results[index].z = z;\n"
4563                       "   results[index].w = w;\n"
4564                       "}";
4565         }
4566         else if (params.shaderSourceType == RayQueryShaderSourceType::CALLABLE)
4567         {
4568             rayGen << "#version 460\n"
4569                       "#extension GL_EXT_ray_tracing : require\n"
4570                       "struct CallValue\n{\n"
4571                       "  uint index;\n"
4572                       "  vec4 hitAttrib;\n"
4573                       "};\n"
4574                       "layout(location = 0) callableDataEXT CallValue param;\n"
4575                       "struct ResultType { float x; float y; float z; float w; };\n"
4576                       "layout(std430, set = 0, binding = 0) buffer Results { ResultType results[]; };\n"
4577                       "void main()\n"
4578                       "{\n"
4579                       "  uint index = (gl_LaunchIDEXT.z * gl_LaunchSizeEXT.x * gl_LaunchSizeEXT.y) + (gl_LaunchIDEXT.y "
4580                       "* gl_LaunchSizeEXT.x) + gl_LaunchIDEXT.x;\n"
4581                       "  param.index = index;\n"
4582                       "  param.hitAttrib = vec4(0, 0, 0, 0);\n"
4583                       "  executeCallableEXT(0, 0);\n"
4584                       "  results[index].x = param.hitAttrib.x;\n"
4585                       "  results[index].y = param.hitAttrib.y;\n"
4586                       "  results[index].z = param.hitAttrib.z;\n"
4587                       "  results[index].w = param.hitAttrib.w;\n"
4588                       "}\n";
4589         }
4590         else
4591         {
4592             rayGen << "#version 460\n"
4593                       "#extension GL_EXT_ray_tracing : require\n"
4594                       "#extension GL_EXT_ray_query : require\n"
4595                       "layout(location = 0) rayPayloadEXT vec4 payload;\n"
4596                       "struct ResultType { float x; float y; float z; float w; };\n"
4597                       "layout(std430, set = 0, binding = 0) buffer Results { ResultType results[]; };\n"
4598                       "layout(set = 0, binding = 3) uniform accelerationStructureEXT traceEXTAccel;\n"
4599                       "void main()\n"
4600                       "{\n"
4601                       "  payload = vec4("
4602                    << NO_INT_VALUE << "," << max_t * 2
4603                    << ",0,0);\n"
4604                       "  uint index = (gl_LaunchIDEXT.z * gl_LaunchSizeEXT.x * gl_LaunchSizeEXT.y) + (gl_LaunchIDEXT.y "
4605                       "* gl_LaunchSizeEXT.x) + gl_LaunchIDEXT.x;\n"
4606                       "  traceRayEXT(traceEXTAccel, 0, 0xFF, 0, 0, 0, vec3(0.1, 0.1, 0.0), 0.0, vec3(0.0, 0.0, 1.0), "
4607                       "500.0, 0);\n"
4608                       "  results[index].x = payload.x;\n"
4609                       "  results[index].y = payload.y;\n"
4610                       "  results[index].z = payload.z;\n"
4611                       "  results[index].w = payload.w;\n"
4612                       "}\n";
4613         }
4614 
4615         programCollection.glslSources.add("rgen", &buildOptions) << glu::RaygenSource(rayGen.str());
4616 
4617         if (params.shaderSourceType == RayQueryShaderSourceType::CLOSEST_HIT)
4618         {
4619             programCollection.glslSources.add("chit", &buildOptions)
4620                 << glu::ClosestHitSource(genericRayTracingSetResultsShader.str());
4621             programCollection.glslSources.add("miss", &buildOptions) << glu::MissSource(genericMiss.str());
4622             programCollection.glslSources.add("isect", &buildOptions)
4623                 << glu::IntersectionSource(updateRayTracingGLSL(genericIsect.str()));
4624         }
4625         else if (params.shaderSourceType == RayQueryShaderSourceType::ANY_HIT)
4626         {
4627             programCollection.glslSources.add("ahit", &buildOptions)
4628                 << glu::AnyHitSource(genericRayTracingSetResultsShader.str());
4629             programCollection.glslSources.add("miss", &buildOptions) << glu::MissSource(genericMiss.str());
4630             programCollection.glslSources.add("isect", &buildOptions)
4631                 << glu::IntersectionSource(updateRayTracingGLSL(genericIsect.str()));
4632         }
4633         else if (params.shaderSourceType == RayQueryShaderSourceType::MISS)
4634         {
4635 
4636             programCollection.glslSources.add("chit", &buildOptions) << glu::ClosestHitSource(genericChit.str());
4637             programCollection.glslSources.add("miss_1", &buildOptions)
4638                 << glu::MissSource(genericRayTracingSetResultsShader.str());
4639             programCollection.glslSources.add("isect", &buildOptions)
4640                 << glu::IntersectionSource(updateRayTracingGLSL(genericIsect.str()));
4641         }
4642         else if (params.shaderSourceType == RayQueryShaderSourceType::INTERSECTION)
4643         {
4644             {
4645                 std::stringstream chit;
4646                 chit << "#version 460    \n"
4647                         "#extension GL_EXT_ray_tracing : require\n"
4648                         "#extension GL_EXT_ray_query : require\n"
4649                         "layout(location = 0) rayPayloadInEXT vec4 payload;\n"
4650                         "hitAttributeEXT vec4 hitAttrib;\n"
4651                         "void main()\n"
4652                         "{\n"
4653                         "  payload = hitAttrib;\n"
4654                         "}\n";
4655 
4656                 programCollection.glslSources.add("chit", &buildOptions) << glu::ClosestHitSource(chit.str());
4657             }
4658 
4659             programCollection.glslSources.add("miss", &buildOptions) << glu::MissSource(genericMiss.str());
4660 
4661             {
4662                 std::stringstream isect;
4663                 isect << "#version 460\n"
4664                          "#extension GL_EXT_ray_tracing : require\n"
4665                          "#extension GL_EXT_ray_query : require\n"
4666                          "hitAttributeEXT vec4 hitValue;\n"
4667                          "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4668                          "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4669                          "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4670                       << params.shaderFunctions
4671                       << "void main()\n"
4672                          "{\n"
4673                          "  uint index = (gl_LaunchIDEXT.z * gl_LaunchSizeEXT.x * gl_LaunchSizeEXT.y) + "
4674                          "(gl_LaunchIDEXT.y * gl_LaunchSizeEXT.x) + gl_LaunchIDEXT.x;\n"
4675                       << rayQueryPart
4676                       << "  hitValue.x = x;\n"
4677                          "  hitValue.y = y;\n"
4678                          "  hitValue.z = z;\n"
4679                          "  hitValue.w = w;\n"
4680                          "  reportIntersectionEXT(0.5f, 0);\n"
4681                          "}\n";
4682 
4683                 programCollection.glslSources.add("isect_1", &buildOptions)
4684                     << glu::IntersectionSource(updateRayTracingGLSL(isect.str()));
4685             }
4686         }
4687         else if (params.shaderSourceType == RayQueryShaderSourceType::CALLABLE)
4688         {
4689             {
4690                 std::stringstream call;
4691                 call << "#version 460\n"
4692                         "#extension GL_EXT_ray_tracing : require\n"
4693                         "#extension GL_EXT_ray_query : require\n"
4694                         "struct CallValue\n{\n"
4695                         "  uint index;\n"
4696                         "  vec4 hitAttrib;\n"
4697                         "};\n"
4698                         "layout(location = 0) callableDataInEXT CallValue result;\n"
4699                         "struct Ray { vec3 pos; float tmin; vec3 dir; float tmax; };\n"
4700                         "layout(set = 0, binding = 1) uniform accelerationStructureEXT scene;\n"
4701                         "layout(std430, set = 0, binding = 2) buffer Rays { Ray rays[]; };\n"
4702                      << params.shaderFunctions
4703                      << "void main()\n"
4704                         "{\n"
4705                         "  uint index = result.index;\n"
4706                      << rayQueryPart
4707                      << "  result.hitAttrib.x = x;\n"
4708                         "  result.hitAttrib.y = y;\n"
4709                         "  result.hitAttrib.z = z;\n"
4710                         "  result.hitAttrib.w = w;\n"
4711                         "}\n";
4712 
4713                 programCollection.glslSources.add("call", &buildOptions)
4714                     << glu::CallableSource(updateRayTracingGLSL(call.str()));
4715             }
4716 
4717             programCollection.glslSources.add("chit", &buildOptions) << glu::ClosestHitSource(genericChit.str());
4718             programCollection.glslSources.add("miss", &buildOptions) << glu::MissSource(genericMiss.str());
4719         }
4720 
4721         break;
4722     }
4723     default:
4724     {
4725         TCU_FAIL("Shader type not valid.");
4726     }
4727     }
4728 }
4729 
4730 #else
4731 
// Placeholder definition for Vulkan SC builds (CTS_USES_VULKANSC) — presumably
// present so this translation unit exports at least one symbol when the ray
// tracing utilities above are compiled out; TODO confirm. Always returns 0.
uint32_t rayTracingDefineAnything()
{
    constexpr uint32_t kNothing = 0u;
    return kNothing;
}
4736 
4737 #endif // CTS_USES_VULKANSC
4738 
4739 } // namespace vk
4740