//
// Copyright (C) 2009-2021 Intel Corporation
//
// SPDX-License-Identifier: MIT
//
//

#include "api_interface.h"
#include "d3d12.h"
#include "common.h"
#include "mem_utils.h"
#include "misc_shared.h"

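// OpenCL C has no <stddef.h>, so define a minimal stand-in for the standard
// offsetof; e.g. offsetof(InternalNode, childOffset) evaluates to the byte
// offset of the childOffset member within an InternalNode.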
#define offsetof(TYPE, ELEMENT) ((size_t)&(((TYPE *)0)->ELEMENT))

GRL_INLINE
uint GroupCountForCopySize(uint size)
{
    return (size >> 8) + 4;
}
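
// Sizing heuristic for the copy kernels: one work group per 256 bytes of BVH
// plus a small fixed surplus, e.g. a 64 KiB BVH yields (65536 >> 8) + 4 = 260
// groups.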

GRL_INLINE
uint GroupCountForCopy(BVHBase* base)
{
    return GroupCountForCopySize(base->Meta.allocationSize);
}

GRL_INLINE void copyInstanceDescs(InstanceDesc* instances, D3D12_RAYTRACING_INSTANCE_DESC* descs, uint64_t numInstances)
{
    for (uint64_t instanceIndex = get_local_id(0); instanceIndex < numInstances; instanceIndex += get_local_size(0))
    {
        for (uint row = 0; row < 3; row++)
        {
            for (uint column = 0; column < 4; column++)
            {
                D3D12_set_transform(&descs[instanceIndex], row, column, InstanceDesc_get_transform(&instances[instanceIndex], row, column));
            }
        }
        D3D12_set_instanceID(&descs[instanceIndex], InstanceDesc_get_instanceID(&instances[instanceIndex]));
        D3D12_set_InstanceMask(&descs[instanceIndex], InstanceDesc_get_InstanceMask(&instances[instanceIndex]));
        D3D12_set_InstanceContributionToHitGroupIndex(&descs[instanceIndex], InstanceDesc_get_InstanceContributionToHitGroupIndex(&instances[instanceIndex]));
        D3D12_set_InstanceFlags(&descs[instanceIndex], InstanceDesc_get_InstanceFlags(&instances[instanceIndex]));
        D3D12_set_AccelerationStructure(&descs[instanceIndex], InstanceDesc_get_AccelerationStructure(&instances[instanceIndex]));
    }
}

GRL_INLINE void createGeoDescs(GeoMetaData* geoMetaData, D3D12_RAYTRACING_GEOMETRY_DESC* descs, uint64_t numGeos, const uint64_t dataBufferStart)
{
    if (get_local_id(0) == 0)
    {
        uint64_t previousGeoDataBufferEnd = dataBufferStart;
        for (uint64_t geoIndex = 0; geoIndex < numGeos; geoIndex += 1)
        {
            D3D12_set_Type(&descs[geoIndex], (uint8_t)(0xffff & geoMetaData[geoIndex].Type));
            D3D12_set_Flags(&descs[geoIndex], (uint8_t)(0xffff & geoMetaData[geoIndex].Flags));
            if (geoMetaData[geoIndex].Type == GEOMETRY_TYPE_TRIANGLES)
            {
                // Every triangle is stored separately
                uint64_t vertexBufferSize = 9 * sizeof(float) * geoMetaData[geoIndex].PrimitiveCount;
                D3D12_set_triangles_Transform(&descs[geoIndex], 0);
                D3D12_set_triangles_IndexFormat(&descs[geoIndex], INDEX_FORMAT_NONE);
                D3D12_set_triangles_VertexFormat(&descs[geoIndex], VERTEX_FORMAT_R32G32B32_FLOAT);
                D3D12_set_triangles_IndexCount(&descs[geoIndex], 0);
                D3D12_set_triangles_VertexCount(&descs[geoIndex], geoMetaData[geoIndex].PrimitiveCount * 3);
                D3D12_set_triangles_IndexBuffer(&descs[geoIndex], (D3D12_GPU_VIRTUAL_ADDRESS)previousGeoDataBufferEnd);
                D3D12_set_triangles_VertexBuffer_StartAddress(&descs[geoIndex], (D3D12_GPU_VIRTUAL_ADDRESS)previousGeoDataBufferEnd);
                D3D12_set_triangles_VertexBuffer_StrideInBytes(&descs[geoIndex], 3 * sizeof(float));
                previousGeoDataBufferEnd += vertexBufferSize;
            }
            else
            {
                D3D12_set_procedurals_AABBCount(&descs[geoIndex], geoMetaData[geoIndex].PrimitiveCount);
                D3D12_set_procedurals_AABBs_StartAddress(&descs[geoIndex], (D3D12_GPU_VIRTUAL_ADDRESS)previousGeoDataBufferEnd);
                D3D12_set_procedurals_AABBs_StrideInBytes(&descs[geoIndex], sizeof(D3D12_RAYTRACING_AABB));
                previousGeoDataBufferEnd += sizeof(D3D12_RAYTRACING_AABB) * geoMetaData[geoIndex].PrimitiveCount;
            }
        }
    }
}
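
// Resulting layout of the decoded data buffer, as implied by the loop above:
// per-geometry blobs packed back to back from dataBufferStart, each either
// PrimitiveCount * 36 bytes of raw vertices (9 floats per triangle, no index
// buffer) or PrimitiveCount * 24 bytes of D3D12_RAYTRACING_AABBs.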

GRL_INLINE void copyIndiciesAndVerticies(D3D12_RAYTRACING_GEOMETRY_DESC* desc, QuadLeaf* quad)
{
    float* vertices = (float*)D3D12_get_triangles_VertexBuffer_StartAddress(desc);
    uint64_t firstTriangleIndex = quad->primIndex0;
    uint64_t numTriangles = QuadLeaf_IsSingleTriangle(quad) ? 1 : 2;

    vertices[firstTriangleIndex * 9] = quad->v[0][0];
    vertices[firstTriangleIndex * 9 + 1] = quad->v[0][1];
    vertices[firstTriangleIndex * 9 + 2] = quad->v[0][2];

    vertices[firstTriangleIndex * 9 + 3] = quad->v[1][0];
    vertices[firstTriangleIndex * 9 + 4] = quad->v[1][1];
    vertices[firstTriangleIndex * 9 + 5] = quad->v[1][2];

    vertices[firstTriangleIndex * 9 + 6] = quad->v[2][0];
    vertices[firstTriangleIndex * 9 + 7] = quad->v[2][1];
    vertices[firstTriangleIndex * 9 + 8] = quad->v[2][2];

    if (numTriangles == 2)
    {
        uint64_t secondTriangleIndex = firstTriangleIndex + QuadLeaf_GetPrimIndexDelta(quad);
        uint32_t packed_indices = QuadLeaf_GetSecondTriangleIndices(quad);
        for (size_t i = 0; i < 3; i++)
        {
            uint32_t idx = packed_indices & 3; packed_indices >>= 2;
            for (size_t j = 0; j < 3; j++)
                vertices[secondTriangleIndex * 9 + i * 3 + j] = quad->v[idx][j];
        }
    }
}
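
// Worked example for the second-triangle decode above, assuming the 2-bit
// packing read by QuadLeaf_GetSecondTriangleIndices: packed_indices == 0b111001
// unpacks (lowest pair first) to quad-vertex indices 1, 2, 3, so the second
// triangle is written as (v[1], v[2], v[3]).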

GRL_INLINE
void storeProceduralDesc(
    struct AABB     procAABB,
    uint32_t        primId,
    D3D12_RAYTRACING_GEOMETRY_DESC* geoDesc)
{
    D3D12_RAYTRACING_AABB* proceduralDescs = (D3D12_RAYTRACING_AABB*)D3D12_get_procedurals_AABBs_StartAddress(geoDesc);
    D3D12_set_raytracing_aabb(&proceduralDescs[primId], &procAABB);
}

GRL_INLINE
void copyDataFromLProcedurals(
    BVHBase* base,
    D3D12_RAYTRACING_GEOMETRY_DESC* descs)
{
    unsigned numProcedurals = BVHBase_GetNumProcedurals(base);
    InternalNode* innerNodes = BVHBase_GetInternalNodes(base);
    unsigned numInnerNodes = BVHBase_GetNumInternalNodes(base);

    if (numProcedurals > 0) //< there's no point entering here if there are no procedurals
    {

        // iterate over all inner nodes to find those with procedural children; the child AABBs must be taken from these nodes
        for (uint32_t nodeI = get_local_id(0); nodeI < numInnerNodes; nodeI += get_local_size(0))
        {
            InternalNode* innerNode = innerNodes + nodeI;

            if (innerNode->nodeType == NODE_TYPE_PROCEDURAL)
            {
                float* origin = innerNode->lower;

                global struct ProceduralLeaf* leaf = (global struct ProceduralLeaf*)QBVHNodeN_childrenPointer((struct QBVHNodeN*)innerNode);

                for (uint k = 0; k < 6; k++)
                {
                    if (InternalNode_IsChildValid(innerNode, k))
                    {
                        struct AABB3f qbounds = {
                            (float)(innerNode->lower_x[k]), (float)(innerNode->lower_y[k]), (float)(innerNode->lower_z[k]),
                            (float)(innerNode->upper_x[k]), (float)(innerNode->upper_y[k]), (float)(innerNode->upper_z[k]) };

                        struct AABB dequantizedAABB;

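                        // dequantize: each child bound is an 8-bit offset from
                        // the node origin, scaled per axis by 2^(exp - 8), i.e.
                        // real = origin + ldexp((float)q, exp - 8)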
                        dequantizedAABB.lower[0] = origin[0] + bitShiftLdexp(qbounds.lower[0], innerNode->exp_x - 8);
                        dequantizedAABB.lower[1] = origin[1] + bitShiftLdexp(qbounds.lower[1], innerNode->exp_y - 8);
                        dequantizedAABB.lower[2] = origin[2] + bitShiftLdexp(qbounds.lower[2], innerNode->exp_z - 8);
                        dequantizedAABB.upper[0] = origin[0] + bitShiftLdexp(qbounds.upper[0], innerNode->exp_x - 8);
                        dequantizedAABB.upper[1] = origin[1] + bitShiftLdexp(qbounds.upper[1], innerNode->exp_y - 8);
                        dequantizedAABB.upper[2] = origin[2] + bitShiftLdexp(qbounds.upper[2], innerNode->exp_z - 8);

                        dequantizedAABB = conservativeAABB(&dequantizedAABB);
                        /* extract geomID and primID from leaf */
                        const uint startPrim = QBVHNodeN_startPrim((struct QBVHNodeN*) innerNode, k);
                        const uint geomID = ProceduralLeaf_geomIndex(leaf);
                        const uint primID = ProceduralLeaf_primIndex(leaf, startPrim); // FIXME: have to iterate over all primitives of leaf!

                        storeProceduralDesc(dequantizedAABB, primID, descs + geomID);
                    }
                    /* advance leaf pointer to next child */
                    leaf += QBVHNodeN_blockIncr((struct QBVHNodeN*)innerNode, k);
                }

            }
            else if (innerNode->nodeType == NODE_TYPE_MIXED) { ERROR(); }
            else { /* other internal node types cannot have a (direct) procedural child; nothing to do */ }
        }
    }
}

GRL_INLINE
void copyDataFromQuadLeaves(BVHBase* base,
    D3D12_RAYTRACING_GEOMETRY_DESC* descs)
{
    QuadLeaf* quads = BVHBase_GetQuadLeaves(base);
    uint64_t numQuads = BVHBase_GetNumQuads(base);
    for (uint64_t quadIdx = get_local_id(0); quadIdx < numQuads; quadIdx += get_local_size(0))
    {
        uint64_t descIdx = PrimLeaf_GetGeoIndex(&quads[quadIdx].leafDesc);
        copyIndiciesAndVerticies(&descs[descIdx], &quads[quadIdx]);
    }
}

GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(MAX_HW_SIMD_WIDTH, 1, 1)))
__attribute__((intel_reqd_sub_group_size(MAX_HW_SIMD_WIDTH)))
void kernel clone_indirect(global char* dest,
    global char* src)
{
    BVHBase* base = (BVHBase*)src;
    uint64_t bvhSize = base->Meta.allocationSize;

    uint numGroups = GroupCountForCopy(base);
    CopyMemory(dest, src, bvhSize, numGroups);
}
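
// Note: clone copies the source BVH verbatim - all allocationSize bytes,
// including any unused tail - whereas compactT below rebuilds the section
// offsets so the destination only needs the compacted size.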

GRL_INLINE void compactT(global char* dest, global char* src, uint64_t compactedSize, uint skipCopy, uint groupCnt)
{
    global BVHBase* baseSrc = (global BVHBase*)src;
    global BVHBase* baseDest = (global BVHBase*)dest;

    uint32_t offset = sizeof(BVHBase);
    uint32_t numNodes = BVHBase_GetNumInternalNodes(baseSrc);
    uint32_t nodeSize = numNodes * sizeof(InternalNode);
    offset += nodeSize;

    int quadChildFix = baseSrc->quadLeafStart;
    int procChildFix = baseSrc->proceduralDataStart;
    int instChildFix = baseSrc->instanceLeafStart;

    // serialization has already copied the first skipCopy bytes of the BVH base, so skip that part here
    CopyMemory(dest + skipCopy, src + skipCopy, sizeof(BVHBase) - skipCopy, groupCnt);
    baseDest->Meta.allocationSize = compactedSize;

    if (baseSrc->Meta.instanceCount)
    {
        const uint32_t instLeafsSize = BVHBase_GetNumHWInstanceLeaves(baseSrc) * sizeof(HwInstanceLeaf);
        CopyMemory(dest + offset, (global char*)BVHBase_GetHWInstanceLeaves(baseSrc), instLeafsSize, groupCnt);
        const uint instanceLeafStart = (uint)(offset / 64);
        baseDest->instanceLeafStart = instanceLeafStart;
        instChildFix -= instanceLeafStart;
        offset += instLeafsSize;
        baseDest->instanceLeafEnd = (uint)(offset / 64);
    }
    if (baseSrc->Meta.geoCount)
    {
        const uint quadLeafsSize = BVHBase_GetNumQuads(baseSrc) * sizeof(QuadLeaf);
        if (quadLeafsSize)
        {
            CopyMemory(dest + offset, (global char*)BVHBase_GetQuadLeaves(baseSrc), quadLeafsSize, groupCnt);
            const uint quadLeafStart = (uint)(offset / 64);
            baseDest->quadLeafStart = quadLeafStart;
            quadChildFix -= quadLeafStart;
            offset += quadLeafsSize;
            baseDest->quadLeafCur = (uint)(offset / 64);
        }

        const uint procLeafsSize = BVHBase_GetNumProcedurals(baseSrc) * sizeof(ProceduralLeaf);
        if (procLeafsSize)
        {
            CopyMemory(dest + offset, (global char*)BVHBase_GetProceduralLeaves(baseSrc), procLeafsSize, groupCnt);
            const uint proceduralDataStart = (uint)(offset / 64);
            baseDest->proceduralDataStart = proceduralDataStart;
            procChildFix -= proceduralDataStart;
            offset += procLeafsSize;
            baseDest->proceduralDataCur = (uint)(offset / 64);
        }
    }
    // copy nodes, fixing up child offsets for the compacted layout
    global uint* nodeDest = (global uint*)(dest + sizeof(BVHBase));
    global InternalNode* nodeSrc = (global InternalNode*)BVHBase_GetInternalNodes(baseSrc);
    // used in the mixed-node case
    char* instanceLeavesBegin = (char*)BVHBase_GetHWInstanceLeaves(baseSrc);
    char* instanceLeavesEnd = (char*)BVHBase_GetHWInstanceLeaves_End(baseSrc);
    uint localId = get_sub_group_local_id();
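    // An InternalNode is exactly one 64B cacheline. CacheLineSubgroupRead hands
    // each of the 16 lanes one dword of it (lane i holds bytes 4i..4i+3); only
    // the lane whose dword is the childOffset field patches its value, then all
    // lanes write the (possibly patched) line out via nodeDest[i * 16 + localId].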
    for (uint i = get_group_id(0); i < numNodes; i += groupCnt)
    {
        uint nodePart = CacheLineSubgroupRead((const global char*)&nodeSrc[i]);
        char nodeType = as_char4(sub_group_broadcast(nodePart, offsetof(InternalNode, nodeType) / 4))[0];
        if (localId * 4 == offsetof(InternalNode, childOffset))
        {
            int childOffset = as_int(nodePart);
            if (nodeType == NODE_TYPE_MIXED)
            {
                char* childPtr = (char*)&nodeSrc[i] + 64 * childOffset;
                if (childPtr > instanceLeavesBegin && childPtr < instanceLeavesEnd)
                    nodePart = as_int(childOffset - instChildFix);
            }
            else if (nodeType == NODE_TYPE_INSTANCE)
                nodePart = as_int(childOffset - instChildFix);
            else if (nodeType == NODE_TYPE_QUAD)
                nodePart = as_int(childOffset - quadChildFix);
            else if (nodeType == NODE_TYPE_PROCEDURAL)
                nodePart = as_int(childOffset - procChildFix);
        }
        nodeDest[i * 16 + localId] = nodePart;
    }

    if (baseSrc->Meta.instanceCount)
    {
        const uint32_t instanceDescSize = baseSrc->Meta.instanceCount * sizeof(InstanceDesc);
        CopyMemory(dest + offset, src + baseSrc->Meta.instanceDescsStart, instanceDescSize, groupCnt);
        baseDest->Meta.instanceDescsStart = offset;
        offset += instanceDescSize;
    }
    if (baseSrc->Meta.geoCount)
    {
        const uint32_t geoMetaSize = baseSrc->Meta.geoCount * sizeof(GeoMetaData);
        CopyMemory(dest + offset, src + baseSrc->Meta.geoDescsStart, geoMetaSize, groupCnt);
        baseDest->Meta.geoDescsStart = offset;
        offset += (geoMetaSize + 63) & ~63; // align to 64
    }

    uint backPointerDataStart     = offset / 64;
    uint refitTreeletsDataStart   = backPointerDataStart;
    uint refitStartPointDataStart = backPointerDataStart;
    uint dataEnd                  = backPointerDataStart;
    uint fatLeafTableStart = dataEnd;
    uint fatLeafCount      = baseSrc->fatLeafCount;
    uint innerTableStart   = dataEnd;
    uint innerCount        = baseSrc->innerCount;

    uint quadLeftoversCountNewAtomicUpdate = baseSrc->quadLeftoversCountNewAtomicUpdate;
    uint quadTableSizeNewAtomicUpdate = baseSrc->quadTableSizeNewAtomicUpdate;
    uint quadIndicesDataStart = dataEnd;

    if (BVHBase_HasBackPointers(baseSrc))
    {
#if 0 //
        const uint oldbackpontersDataStart = baseSrc->backPointerDataStart;
        const uint shift = oldbackpontersDataStart - backPointerDataStart;
        const uint refitStructsSize = ((BVHBase_GetRefitStructsDataSize(baseSrc)) + 63) & ~63;

        CopyMemory(dest + offset, (global char*)BVHBase_GetBackPointers(baseSrc), refitStructsSize, groupCnt);

        refitTreeletsDataStart   = baseSrc->refitTreeletsDataStart - shift;
        refitStartPointDataStart = baseSrc->refitStartPointDataStart - shift;
        dataEnd                  = baseSrc->BVHDataEnd - shift;
#else // compacting version
        const uint backpointersSize = ((numNodes*sizeof(uint)) + 63) & ~63;
        CopyMemory(dest + offset, (global char*)BVHBase_GetBackPointers(baseSrc), backpointersSize, groupCnt);
        offset += backpointersSize;

        refitTreeletsDataStart = offset / 64;
        refitStartPointDataStart = offset / 64;

        // TODO: remove treelets from .... everywhere
        const uint treeletExecutedCnt = *BVHBase_GetRefitTreeletCntPtr(baseSrc);

        if (treeletExecutedCnt)
        {
            const uint treeletCnt = treeletExecutedCnt > 1 ? treeletExecutedCnt + 1 : 1;

            refitTreeletsDataStart = offset / 64;
            const uint treeletsSize = ((treeletCnt * sizeof(RefitTreelet)) + 63) & ~63;
            RefitTreelet* destTreelets = (RefitTreelet*)(dest + offset);
            RefitTreelet* srcTreelets = BVHBase_GetRefitTreeletDescs(baseSrc);

            uint numThreads = groupCnt * get_local_size(0);
            uint globalID = (get_group_id(0) * get_local_size(0)) + get_local_id(0);

            for (uint i = globalID; i < treeletCnt; i += numThreads)
            {
                RefitTreelet dsc = srcTreelets[i];
                RefitTreeletTrivial* trivial_dsc = (RefitTreeletTrivial*)&dsc;
                if (trivial_dsc->numStartpoints == 1 && trivial_dsc->childrenOffsetOfTheNode > numNodes) {
                    trivial_dsc->childrenOffsetOfTheNode -= quadChildFix;
                }
                destTreelets[i] = dsc;
            }

            offset += treeletsSize;

            refitStartPointDataStart = offset / 64;
            const uint startPointsSize = (BVHBase_GetRefitStartPointsSize(baseSrc) + 63) & ~63;
            CopyMemory(dest + offset, (global char*)BVHBase_GetRefitStartPoints(baseSrc), startPointsSize, groupCnt);
            offset += startPointsSize;
            dataEnd = offset / 64;
        }

        uint fatleafEntriesSize = ((fatLeafCount * sizeof(LeafTableEntry) + 63) & ~63);
        fatLeafTableStart = offset / 64;
        if (fatleafEntriesSize) {
            CopyMemory(dest + offset, (global char*)BVHBase_GetFatLeafTable(baseSrc), fatleafEntriesSize, groupCnt);
        }
        offset += fatleafEntriesSize;

        // New atomic update
        if (baseSrc->quadIndicesDataStart > baseSrc->backPointerDataStart)
        {
            uint numQuads = BVHBase_GetNumQuads(baseSrc);
            uint quadTableMainBufferSize = (numQuads + 255) & ~255;
            uint quadLeftoversSize = (quadLeftoversCountNewAtomicUpdate + 255) & ~255;
            uint quadTableEntriesSize = (((quadTableMainBufferSize + quadLeftoversSize) * sizeof(LeafTableEntry) + 63) & ~63);
            if (quadTableEntriesSize) {
                CopyMemory(dest + offset, (global char*)BVHBase_GetFatLeafTable(baseSrc), quadTableEntriesSize, groupCnt);
            }
            offset += quadTableEntriesSize;

            uint quadIndicesDataSize = ((numQuads * sizeof(QuadDataIndices) + 63) & ~63);
            quadIndicesDataStart = offset / 64;
            if (quadIndicesDataSize) {
                CopyMemory(dest + offset, (global char*)BVHBase_GetQuadDataIndicesTable(baseSrc), quadIndicesDataSize, groupCnt);
            }
            offset += quadIndicesDataSize;
        }

        uint innerEntriesSize = ((innerCount * sizeof(InnerNodeTableEntry) + 63) & ~63);
        innerTableStart = offset / 64;
        if (innerEntriesSize) {
            CopyMemory(dest + offset, (global char*)BVHBase_GetInnerNodeTable(baseSrc), innerEntriesSize, groupCnt);
        }
        offset += innerEntriesSize;

        dataEnd = offset / 64;
#endif
    }

    baseDest->backPointerDataStart = backPointerDataStart;
    baseDest->refitTreeletsDataStart = refitTreeletsDataStart;
    baseDest->refitStartPointDataStart = refitStartPointDataStart;
    baseDest->fatLeafTableStart = fatLeafTableStart;
    baseDest->fatLeafCount = fatLeafCount;
    baseDest->innerTableStart = innerTableStart;
    baseDest->innerCount = innerCount;

    baseDest->quadLeftoversCountNewAtomicUpdate = quadLeftoversCountNewAtomicUpdate;
    baseDest->quadTableSizeNewAtomicUpdate = quadTableSizeNewAtomicUpdate;
    baseDest->quadIndicesDataStart = quadIndicesDataStart;
    baseDest->BVHDataEnd = dataEnd;
}

GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((intel_reqd_sub_group_size(MAX_HW_SIMD_WIDTH)))
__attribute__((reqd_work_group_size(MAX_HW_SIMD_WIDTH, 1, 1)))
void kernel compact(global char* dest,
    global char* src,
    uint groupCnt)
{
    uint64_t compactedSize = compute_compacted_size((BVHBase*)src);
    compactT(dest, src, compactedSize, 0, groupCnt);
}

// builds the serialization header across the lanes of a subgroup: each lane returns one dword of the header, with the last two lanes carrying 64 bits of trailing payload ('reminder')
GRL_INLINE
unsigned prepare_header(
    uint64_t headerSize,
    uint64_t instancePtrSize,
    uint64_t numInstances,
    uint64_t bvhSize,
    uint8_t* driverID,
    uint64_t reminder)
{

    unsigned loc_id = get_sub_group_local_id();

    uint64_t SerializedSizeInBytesIncludingHeader = headerSize + instancePtrSize * numInstances + bvhSize;
    uint64_t DeserializedSizeInBytes = bvhSize;
    uint64_t InstanceHandleCount = numInstances;

    char bvh_magic_str[] = BVH_MAGIC_MACRO;
    uint* bvh_magic_uint = (uint*)bvh_magic_str;

    unsigned headerTempLanePiece;
    if (loc_id < 4) { headerTempLanePiece = *((unsigned*)&driverID[4*loc_id]); }
    else if (loc_id == 4) { headerTempLanePiece = bvh_magic_uint[0]; }
    else if (loc_id == 5) { headerTempLanePiece = bvh_magic_uint[1]; }
    else if (loc_id == 6) { headerTempLanePiece = bvh_magic_uint[2]; }
    else if (loc_id == 7) { headerTempLanePiece = bvh_magic_uint[3]; }
    else if (loc_id == 8) { headerTempLanePiece = (uint)SerializedSizeInBytesIncludingHeader; }
    else if (loc_id == 9) { headerTempLanePiece = (uint)(SerializedSizeInBytesIncludingHeader >> 32ul); }
    else if (loc_id == 10) { headerTempLanePiece = (uint)DeserializedSizeInBytes; }
    else if (loc_id == 11) { headerTempLanePiece = (uint)(DeserializedSizeInBytes >> 32ul); }
    else if (loc_id == 12) { headerTempLanePiece = (uint)InstanceHandleCount; }
    else if (loc_id == 13) { headerTempLanePiece = (uint)(InstanceHandleCount >> 32ul); }
    else if (loc_id == 14) { headerTempLanePiece = (uint)reminder; }
    else if (loc_id == 15) { headerTempLanePiece = (uint)(reminder >> 32ul); }

    return headerTempLanePiece;
}
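
// Resulting 64B header cacheline, one dword per lane of a 16-wide subgroup:
//   lanes 0-3:   driverID (16 bytes)
//   lanes 4-7:   BVH_MAGIC_MACRO (16 bytes)
//   lanes 8-9:   SerializedSizeInBytesIncludingHeader (uint64)
//   lanes 10-11: DeserializedSizeInBytes (uint64)
//   lanes 12-13: InstanceHandleCount (uint64)
//   lanes 14-15: 'reminder' - the 8 bytes of payload that pad the 56B header
//                out to a full cacheline (see TRICK A and B in serializeT)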
GRL_INLINE
void serializeT(
    global byte_align64B* dest,
    global byte_align64B* src,
    global uint8_t* driverID,
    uint groups_count)
{
    SerializationHeader* header = (SerializationHeader*)dest;
    BVHBase* base = (BVHBase*)src;

    const uint headerSize = sizeof(SerializationHeader);
    const uint numInstances = base->Meta.instanceCount;
    const uint instancePtrSize = sizeof(gpuva_t);
    const uint compactedSize = compute_compacted_size(base);
    uint local_id = get_sub_group_local_id();

    // this is not 64byte aligned :(
    const uint offsetToBvh = headerSize + instancePtrSize * numInstances;

    global InstanceDesc* src_instances = 0;

    if (numInstances) {
        src_instances = (global InstanceDesc*)((uint64_t)base + base->Meta.instanceDescsStart);
    }

    // effectively this part should end up as one 64B-aligned 64B write
    if (get_group_id(0) == groups_count - 1)
    {
        Block64B headerPlus;

        // the header itself is only 56B; we patch the missing 8B with the first
        // instance pointer, or with the bvh beginning (TRICK A and B)
        global uint64_t* srcPiece = (numInstances != 0) ? &src_instances[0].AccelerationStructureGPUVA : (global uint64_t*)src;

        unsigned headerTemp;

        headerTemp = prepare_header(
            headerSize,
            instancePtrSize,
            numInstances,
            compactedSize,
            driverID,
            *srcPiece);

        CacheLineSubgroupWrite((global byte_align64B*)dest, headerTemp);
    }

    if (numInstances > 0)
    {
        uint instancesOffset = headerSize;
        uint aligned_instance_ptrs_offset = ((instancesOffset + 63) >> 6) << 6;
        uint unaligned_prefixing_instance_cnt = (aligned_instance_ptrs_offset - instancesOffset) >> 3;
        unaligned_prefixing_instance_cnt = min(unaligned_prefixing_instance_cnt, numInstances);

        global uint64_t* dst_instances = (global uint64_t*)(dest + instancesOffset);

        // the first instance pointer has already been written as part of the header cacheline (see TRICK A),
        // so only the instance pointers that start at aligned memory remain
        uint numAlignedInstances = numInstances - unaligned_prefixing_instance_cnt;
        dst_instances += unaligned_prefixing_instance_cnt;
        src_instances += unaligned_prefixing_instance_cnt;

        if (numAlignedInstances)
        {
            // every 8 instance pointers form one cacheline
            uint numCachelines = numAlignedInstances >> 3; // qwords -> 64B blocks
            // leftover qwords beyond the last multiple of 8
            uint startReminder = numAlignedInstances & ~((1 << 3) - 1);
            uint numreminder = numAlignedInstances & ((1 << 3) - 1);

            uint task_id = get_group_id(0);

            while (task_id < numCachelines)
            {
                uint src_id = task_id * 8 + (local_id >> 1);
                uint* src_uncorected = (uint*)&src_instances[src_id].AccelerationStructureGPUVA;
                uint* src = ((local_id & 1) != 0) ? src_uncorected + 1 : src_uncorected;
                uint data = *src;

                global char* dst = (global byte_align64B*)(dst_instances + (8 * task_id));
                CacheLineSubgroupWrite(dst, data);
                task_id += groups_count;
            }

            if (task_id == numCachelines && local_id < 8 && numreminder > 0)
            {
                // this should write a full cacheline

                uint index = startReminder + local_id;
                // lanes with local_id < numreminder take their data from the instances;
                // the remaining lanes copy the beginning of the source bvh as uint64_t (TRICK B)
                global uint64_t* srcData = (local_id < numreminder) ?
                    &src_instances[index].AccelerationStructureGPUVA :
                    ((global uint64_t*)src) + (local_id - numreminder);
                dst_instances[index] = *srcData;
            }
        }
    }

    // the code above already copied the unaligned beginning of the destination bvh (see TRICK B)
    uint32_t unalignedPartCopiedElsewhere = (64u - (offsetToBvh & (64u - 1u))) & (64u - 1u);
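    // e.g. with the 56B header and 3 instance pointers, offsetToBvh = 56 + 24 = 80;
    // 80 & 63 = 16 bytes into a cacheline, so 64 - 16 = 48 bytes of the BVH base
    // were already written by the remainder write above (TRICK B) and are skipped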

    compactT(dest + offsetToBvh, src, compactedSize, unalignedPartCopiedElsewhere, groups_count);
}

GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((intel_reqd_sub_group_size(MAX_HW_SIMD_WIDTH)))
__attribute__((reqd_work_group_size(MAX_HW_SIMD_WIDTH, 1, 1)))
void kernel serialize_indirect(
    global char* dest,
    global char* src,
    global uint8_t* driverID)
{
    BVHBase* base = (BVHBase*)src;
    uint groups_count = GroupCountForCopy(base);
    serializeT(dest, src, driverID, groups_count);
}

GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((intel_reqd_sub_group_size(MAX_HW_SIMD_WIDTH)))
__attribute__((reqd_work_group_size(MAX_HW_SIMD_WIDTH, 1, 1)))
void kernel serialize_for_input_dump_indirect(
    global struct OutputBatchPtrs* batchPtrs,
    global dword* dstOffset,
    global char* src,
    global uint8_t* driverID)
{
    BVHBase* base = (BVHBase*)src;
    uint groups_count = GroupCountForCopy(base);
    global char* dest = (global char*)(batchPtrs->dataStart + *dstOffset);
    dest += (sizeof(OutputData) + 127) & ~127;
    serializeT(dest, src, driverID, groups_count);
}

GRL_INLINE
void deserializeT(
    global char* dest,
    global char* src,
    unsigned groupCnt)
{
    SerializationHeader* header = (SerializationHeader*)src;

    const uint64_t headerSize = sizeof(struct SerializationHeader);
    const uint64_t instancePtrSize = sizeof(gpuva_t);
    const uint64_t numInstances = header->InstanceHandleCount;
    const uint64_t offsetToBvh = headerSize + instancePtrSize * numInstances;
    const uint64_t bvhSize = header->DeserializedSizeInBytes;

    if (numInstances)
    {
        const bool instances_mixed_with_inner_nodes = false;
        if (instances_mixed_with_inner_nodes)
        {
            // not implemented!
            // would copy each node with 64-byte granularity and, if a node is an instance, patch it mid-copy
        }
        else
        {
            BVHBase* srcBvhBase = (BVHBase*)(src + offsetToBvh);

            // numHWInstances can be bigger (because of rebraiding) or smaller (because of inactive instances) than
            // numInstances (the count of pointers and descriptors).
            uint offsetToHwInstances = srcBvhBase->instanceLeafStart << 6;
            uint numHwInstances = (srcBvhBase->instanceLeafEnd - srcBvhBase->instanceLeafStart) >> 1;
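            // instanceLeafStart/End count 64B blocks (hence << 6 to get bytes);
            // a HwInstanceLeaf spans two such blocks, hence >> 1 for the leaf count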

            //
            // the instance leaves and instance descriptors live in separate memory intervals;
            // everything else can be copied verbatim
            //
            uint nodesEnd = srcBvhBase->Meta.instanceDescsStart;
            // copy everything before the instance leaves
            CopyMemory(dest, (global char*)(src + offsetToBvh), offsetToHwInstances, groupCnt);

            uint offsetPostInstances = srcBvhBase->instanceLeafEnd << 6;
            uint instanceDescStart = srcBvhBase->Meta.instanceDescsStart;
            uint sizePostInstances = instanceDescStart - offsetPostInstances;
            // copy everything between the instance leaves and the instance descriptors
            CopyMemory(dest + offsetPostInstances, (global char*)(src + offsetToBvh + offsetPostInstances), sizePostInstances, groupCnt);

            uint instanceDescEnd = instanceDescStart + numInstances * sizeof(InstanceDesc);
            uint sizePostInstanceDescs = bvhSize - instanceDescEnd;
            // copy everything after the instance descriptors
            CopyMemory(dest + instanceDescEnd, (global char*)(src + offsetToBvh + instanceDescEnd), sizePostInstanceDescs, groupCnt);

            global gpuva_t* newInstancePtrs = (global gpuva_t*)(src + headerSize);
            global InstanceDesc* dstDesc = (global InstanceDesc*)(dest + instanceDescStart);
            global InstanceDesc* srcDesc = (global InstanceDesc*)(src + offsetToBvh + instanceDescStart);

            // copy and patch instance descriptors
            for (uint64_t instanceIndex = get_group_id(0); instanceIndex < numInstances; instanceIndex += groupCnt)
            {
                InstanceDesc desc = srcDesc[instanceIndex];
                uint64_t newInstancePtr = newInstancePtrs[instanceIndex];
                desc.AccelerationStructureGPUVA = newInstancePtr; // patch it with the new ptr

                dstDesc[instanceIndex] = desc;
            }

            // copy and patch hw instance leaves
            global HwInstanceLeaf* dstInstleafs = (global HwInstanceLeaf*)(dest + offsetToHwInstances);
            global HwInstanceLeaf* srcInstleafs = (global HwInstanceLeaf*)(src + offsetToBvh + offsetToHwInstances);

            for (uint hwLeafIndex = get_group_id(0); hwLeafIndex < numHwInstances; hwLeafIndex += groupCnt)
            {
                // pull the instance from the source BVH
                HwInstanceLeaf tmpInstleaf = srcInstleafs[hwLeafIndex];

                uint swInstanceIndex = HwInstanceLeaf_GetInstanceIndex(&tmpInstleaf);
                uint64_t childBvhPtr = (uint64_t)newInstancePtrs[swInstanceIndex];
                uint64_t originalBvhPtr = (uint64_t)HwInstanceLeaf_GetBVH(&tmpInstleaf);

                HwInstanceLeaf_SetBVH(&tmpInstleaf, childBvhPtr);
                uint64_t startNode = HwInstanceLeaf_GetStartNode(&tmpInstleaf);

                if (startNode != 0) {
                    uint64_t rootNodeOffset = startNode - originalBvhPtr;
                    HwInstanceLeaf_SetStartNode(&tmpInstleaf, childBvhPtr + rootNodeOffset);
                }

                dstInstleafs[hwLeafIndex] = tmpInstleaf;
            }
        }
    }
    else
    {
        CopyMemory(dest, (global char*)(src + offsetToBvh), bvhSize, groupCnt);
    }
}

GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(MAX_HW_SIMD_WIDTH, 1, 1)))
__attribute__((intel_reqd_sub_group_size(MAX_HW_SIMD_WIDTH)))
void kernel deserialize_indirect(
    global char* dest,
    global char* src)
{
    SerializationHeader* header = (SerializationHeader*)src;
    const uint64_t bvhSize = header->DeserializedSizeInBytes;
    unsigned groupCnt = GroupCountForCopySize(bvhSize);
    deserializeT(dest, src, groupCnt);
}

GRL_ANNOTATE_IGC_DO_NOT_SPILL
__attribute__((reqd_work_group_size(MAX_HW_SIMD_WIDTH, 1, 1)))
void kernel dxr_decode(global char* dest,
    global char* src)
{
    DecodeHeader* header = (DecodeHeader*)dest;
    BVHBase* base = (BVHBase*)src;

    uint32_t numGeos = base->Meta.geoCount;
    uint32_t numInstances = base->Meta.instanceCount;

    if (numInstances > 0)
    {
        header->Type = TOP_LEVEL;
        header->NumDesc = numInstances;

        D3D12_RAYTRACING_INSTANCE_DESC* instanceDesc = (D3D12_RAYTRACING_INSTANCE_DESC*)(dest + sizeof(DecodeHeader));
        copyInstanceDescs((InstanceDesc*)((uint64_t)base + (uint64_t)base->Meta.instanceDescsStart),
            instanceDesc,
            numInstances);
    }
    else if (numGeos > 0)
    {
        header->Type = BOTTOM_LEVEL;
        header->NumDesc = numGeos;

        D3D12_RAYTRACING_GEOMETRY_DESC* geomDescs = (D3D12_RAYTRACING_GEOMETRY_DESC*)(dest + sizeof(DecodeHeader));
        uint64_t data = (uint64_t)geomDescs + sizeof(D3D12_RAYTRACING_GEOMETRY_DESC) * numGeos;
        createGeoDescs((GeoMetaData*)((uint64_t)base + (uint64_t)base->Meta.geoDescsStart),
            geomDescs,
            numGeos,
            data);

        work_group_barrier(CLK_GLOBAL_MEM_FENCE);

        copyDataFromQuadLeaves(base,
            geomDescs);

        copyDataFromLProcedurals(base,
            geomDescs);
    }
    else
    {
        header->Type = BOTTOM_LEVEL;
        header->NumDesc = 0;
    }
}
764