xref: /aosp_15_r20/external/angle/src/libANGLE/renderer/metal/ProvokingVertexHelper.mm (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1//
2// Copyright 2021 The ANGLE Project Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style license that can be
4// found in the LICENSE file.
5//
6// ProvokingVertexHelper.mm:
7//    Implements the class methods for ProvokingVertexHelper.
8//
9
10#include "libANGLE/renderer/metal/ProvokingVertexHelper.h"
11#import <Foundation/Foundation.h>
12#include "libANGLE/Display.h"
13#include "libANGLE/renderer/metal/ContextMtl.h"
14#include "libANGLE/renderer/metal/DisplayMtl.h"
15#include "libANGLE/renderer/metal/mtl_common.h"
16#include "libANGLE/renderer/metal/shaders/rewrite_indices_shared.h"
17namespace rx
18{
19
namespace
{
// Initial size passed to mIndexBuffers.initialize() by the ProvokingVertexHelper constructor.
// NOTE(review): 0xFFFF is 65535, i.e. one less than 64 * 1024 -- confirm this is intentional.
constexpr size_t kInitialIndexBufferSize = 0xFFFF;  // Initial 64k pool.
}
// Computes how many primitives |indexCount| indices produce for the primitive mode packed into
// |fixIndexBufferKey|. Line-strip, line-loop, triangle-strip and triangle-fan counts are clamped
// so that they never drop below zero. Unknown modes assert and yield 0.
static inline uint primCountForIndexCount(const uint fixIndexBufferKey, const uint indexCount)
{
    const uint mode =
        (fixIndexBufferKey >> MtlFixIndexBufferKeyModeShift) & MtlFixIndexBufferKeyModeMask;
    // The clamped cases subtract on a signed copy so the result can be compared against zero.
    const int signedCount = (int)indexCount;

    switch (mode)
    {
        // List-style modes: a fixed number of indices per primitive.
        case MtlFixIndexBufferKeyPoints:
            return indexCount;
        case MtlFixIndexBufferKeyLines:
            return indexCount / 2;
        case MtlFixIndexBufferKeyTriangles:
            return indexCount / 3;

        // Connected line modes.
        case MtlFixIndexBufferKeyLineStrip:
            return (uint)MAX(0, signedCount - 1);
        case MtlFixIndexBufferKeyLineLoop:
            // One primitive per index (no "- 1", unlike the strip case).
            return (uint)MAX(0, signedCount);

        // Connected triangle modes share the same formula.
        case MtlFixIndexBufferKeyTriangleStrip:
        case MtlFixIndexBufferKeyTriangleFan:
            return (uint)MAX(0, signedCount - 2);

        default:
            ASSERT(false);
            return 0;
    }
}
50
// Returns the number of indices needed to express |primCount| primitives once the mode packed
// into |fixIndexBufferKey| has been expanded: 1 index per point, 2 per line (lines, strips and
// loops) and 3 per triangle (triangles, strips and fans). Unknown modes assert and yield 0.
static inline uint indexCountForPrimCount(const uint fixIndexBufferKey, const uint primCount)
{
    const uint mode =
        (fixIndexBufferKey >> MtlFixIndexBufferKeyModeShift) & MtlFixIndexBufferKeyModeMask;

    // Stays 0 for an unrecognized mode so the function returns 0 in that case.
    uint indicesPerPrimitive = 0;
    switch (mode)
    {
        case MtlFixIndexBufferKeyPoints:
            indicesPerPrimitive = 1;
            break;
        case MtlFixIndexBufferKeyLines:
        case MtlFixIndexBufferKeyLineStrip:
        case MtlFixIndexBufferKeyLineLoop:
            indicesPerPrimitive = 2;
            break;
        case MtlFixIndexBufferKeyTriangles:
        case MtlFixIndexBufferKeyTriangleStrip:
        case MtlFixIndexBufferKeyTriangleFan:
            indicesPerPrimitive = 3;
            break;
        default:
            ASSERT(false);
            break;
    }
    return primCount * indicesPerPrimitive;
}
76
// Maps the primitive mode packed into |fixIndexBufferKey| to the mode the rewritten index buffer
// must be drawn with: points stay points, every line mode becomes Lines and every triangle mode
// becomes Triangles. Unknown modes assert and yield InvalidEnum.
static inline gl::PrimitiveMode getNewPrimitiveMode(const uint fixIndexBufferKey)
{
    const uint mode =
        (fixIndexBufferKey >> MtlFixIndexBufferKeyModeShift) & MtlFixIndexBufferKeyModeMask;

    switch (mode)
    {
        case MtlFixIndexBufferKeyPoints:
            return gl::PrimitiveMode::Points;

        case MtlFixIndexBufferKeyLines:
        case MtlFixIndexBufferKeyLineStrip:
        case MtlFixIndexBufferKeyLineLoop:
            return gl::PrimitiveMode::Lines;

        case MtlFixIndexBufferKeyTriangles:
        case MtlFixIndexBufferKeyTriangleStrip:
        case MtlFixIndexBufferKeyTriangleFan:
            return gl::PrimitiveMode::Triangles;

        default:
            ASSERT(false);
            return gl::PrimitiveMode::InvalidEnum;
    }
}
// Constructs the helper and initializes the mIndexBuffers pool that backs every rewritten or
// generated index buffer handed out by this class.
// NOTE(review): the meaning of the `false` constructor argument and of the trailing `0` passed
// to initialize() is defined by the pool type, which is not visible in this file -- confirm.
ProvokingVertexHelper::ProvokingVertexHelper(ContextMtl *context) : mIndexBuffers(false)
{
    mIndexBuffers.initialize(context, kInitialIndexBufferSize, mtl::kIndexBufferOffsetAlignment, 0);
}
106
// Tears down the helper by forwarding to mIndexBuffers.destroy() with |context|.
void ProvokingVertexHelper::onDestroy(ContextMtl *context)
{
    mIndexBuffers.destroy(context);
}
111
// Forwards to mIndexBuffers.releaseInFlightBuffers() with |contextMtl|.
void ProvokingVertexHelper::releaseInFlightBuffers(ContextMtl *contextMtl)
{
    mIndexBuffers.releaseInFlightBuffers(contextMtl);
}
116
// Packs |pipelineDesc| into the bitfield key the index-rewriting kernels are specialized on:
//   - input and output index width (both set to the descriptor's element type),
//   - the primitive mode,
//   - the primitive-restart flag,
//   - the "provoking vertex last" flag, which is always set.
static uint buildIndexBufferKey(const mtl::ProvokingVertexComputePipelineDesc &pipelineDesc)
{
    uint key = 0;

    // Input and output indices always use the same width.
    switch (static_cast<gl::DrawElementsType>(pipelineDesc.elementType))
    {
        case gl::DrawElementsType::UnsignedShort:
            key |= (MtlFixIndexBufferKeyUint16 << MtlFixIndexBufferKeyInShift) |
                   (MtlFixIndexBufferKeyUint16 << MtlFixIndexBufferKeyOutShift);
            break;
        case gl::DrawElementsType::UnsignedInt:
            key |= (MtlFixIndexBufferKeyUint32 << MtlFixIndexBufferKeyInShift) |
                   (MtlFixIndexBufferKeyUint32 << MtlFixIndexBufferKeyOutShift);
            break;
        default:
            ASSERT(false);  // Index type should only be short or int.
            break;
    }

    key |= static_cast<uint>(pipelineDesc.primitiveMode) << MtlFixIndexBufferKeyModeShift;

    if (pipelineDesc.primitiveRestartEnabled)
    {
        key |= MtlFixIndexBufferKeyPrimRestart;
    }

    // We only rewrite indices if we're switching the provoking vertex mode.
    key |= MtlFixIndexBufferKeyProvokingVertexLast;

    return key;
}
143
// Looks up (or lazily creates) the compute function specialized for |desc| and asks the context's
// pipeline cache for the matching compute pipeline state, returned through |outComputePipeline|.
//
// Functions are cached per descriptor in mComputeFunctions. On a miss, the function is created
// from the display's default shader library, specialized through the "fixIndexBufferKey"
// function constant, using the "genIndexBuffer" kernel when desc.generateIndices is set and
// "fixIndexBuffer" otherwise.
angle::Result ProvokingVertexHelper::getComputePipleineState(
    ContextMtl *context,
    const mtl::ProvokingVertexComputePipelineDesc &desc,
    mtl::AutoObjCPtr<id<MTLComputePipelineState>> *outComputePipeline)
{
    // Fast path: the specialized function already exists.
    const auto cachedFunction = mComputeFunctions.find(desc);
    if (cachedFunction != mComputeFunctions.end())
    {
        return context->getPipelineCache().getComputePipeline(context, cachedFunction->second,
                                                              outComputePipeline);
    }

    // Slow path: specialize the kernel for this descriptor.
    id<MTLLibrary> shaderLibrary = context->getDisplay()->getDefaultShadersLib();
    uint specializationKey       = buildIndexBufferKey(desc);
    auto constantValues = mtl::adoptObjCObj([[MTLFunctionConstantValues alloc] init]);
    [constantValues setConstantValue:&specializationKey
                                type:MTLDataTypeUInt
                            withName:@"fixIndexBufferKey"];

    NSString *kernelName = desc.generateIndices ? @"genIndexBuffer" : @"fixIndexBuffer";

    mtl::AutoObjCPtr<id<MTLFunction>> specializedFunction;
    ANGLE_TRY(CreateMslShader(context, shaderLibrary, kernelName, constantValues.get(),
                              &specializedFunction));

    // Remember the function so later requests for the same descriptor take the fast path.
    mComputeFunctions[desc] = specializedFunction;

    return context->getPipelineCache().getComputePipeline(context, specializedFunction,
                                                          outComputePipeline);
}
177
// Resolves the compute pipeline state for |desc| via getComputePipleineState() and binds it on
// |encoder|. Propagates any failure from the pipeline lookup; otherwise returns Continue.
angle::Result ProvokingVertexHelper::prepareCommandEncoderForDescriptor(
    ContextMtl *context,
    mtl::ComputeCommandEncoder *encoder,
    mtl::ProvokingVertexComputePipelineDesc desc)
{
    mtl::AutoObjCPtr<id<MTLComputePipelineState>> computePipeline;
    ANGLE_TRY(getComputePipleineState(context, desc, &computePipeline));

    encoder->setComputePipelineState(computePipeline);
    return angle::Result::Continue;
}
190
// Encodes a compute pass that rewrites |indexCount| indices of |indexBuffer| (bound at byte
// offset |indexOffset|) into a buffer allocated from mIndexBuffers. The pipeline is the
// non-generating ("fixIndexBuffer") variant specialized on |elementsType|, |primitiveMode| and
// |primitiveRestartEnabled|.
//
// Outputs:
//   outIndexCount    - number of indices in the rewritten buffer.
//   outIndexOffset   - offset returned by the pool allocation. The kernel's output binding is
//                      placed at |indexOffset| + this value, and the allocation is sized with
//                      |indexOffset| extra bytes to make room for that.
//                      NOTE(review): callers presumably add |indexOffset| back themselves when
//                      drawing -- confirm at the call site.
//   outPrimitiveMode - primitive mode to draw the rewritten indices with.
//   outNewBuffer     - buffer that receives the rewritten indices.
//
// If the input describes zero primitives, the allocation and outputs are produced as usual but
// no compute work is encoded.
angle::Result ProvokingVertexHelper::preconditionIndexBuffer(ContextMtl *context,
                                                             mtl::BufferRef indexBuffer,
                                                             size_t indexCount,
                                                             size_t indexOffset,
                                                             bool primitiveRestartEnabled,
                                                             gl::PrimitiveMode primitiveMode,
                                                             gl::DrawElementsType elementsType,
                                                             size_t &outIndexCount,
                                                             size_t &outIndexOffset,
                                                             gl::PrimitiveMode &outPrimitiveMode,
                                                             mtl::BufferRef &outNewBuffer)
{
    // Describe the specialized kernel variant needed for this draw.
    mtl::ProvokingVertexComputePipelineDesc pipelineDesc;
    pipelineDesc.elementType             = (uint8_t)elementsType;
    pipelineDesc.primitiveMode           = primitiveMode;
    pipelineDesc.primitiveRestartEnabled = primitiveRestartEnabled;
    pipelineDesc.generateIndices         = false;

    // Work out how many primitives the input holds and how many indices the rewritten buffer
    // needs.
    uint indexBufferKey = buildIndexBufferKey(pipelineDesc);
    uint primCount      = primCountForIndexCount(indexBufferKey, (uint32_t)indexCount);
    uint newIndexCount  = indexCountForPrimCount(indexBufferKey, primCount);
    size_t indexSize    = gl::GetDrawElementsTypeSize(elementsType);

    // Allocate the destination. |indexOffset| extra bytes are requested because the output
    // binding below is shifted by |indexOffset| past the allocation's offset.
    size_t newOffset = 0;
    mtl::BufferRef newBuffer;
    ANGLE_TRY(mIndexBuffers.allocate(context, newIndexCount * indexSize + indexOffset, nullptr,
                                     &newBuffer, &newOffset));

    // With zero primitives the threadgroup width would be 0 and the threadgroup-count
    // computation below would divide by zero, so only encode when there is work to do.
    if (primCount > 0)
    {
        uint indexCountEncoded            = (uint)indexCount;
        const MTLSize threadsPerThreadgroup = MTLSizeMake(MIN(primCount, 64u), 1, 1);
        // One thread per primitive, rounded up to whole threadgroups.
        const MTLSize threadgroupCount = MTLSizeMake(
            (primCount + threadsPerThreadgroup.width - 1) / threadsPerThreadgroup.width, 1, 1);

        mtl::ComputeCommandEncoder *encoder =
            context->getComputeCommandEncoderWithoutEndingRenderEncoder();
        ANGLE_TRY(prepareCommandEncoderForDescriptor(context, encoder, pipelineDesc));
        encoder->setBuffer(indexBuffer, static_cast<uint32_t>(indexOffset), 0);
        encoder->setBufferForWrite(
            newBuffer, static_cast<uint32_t>(indexOffset) + static_cast<uint32_t>(newOffset), 1);
        // Pass the scalars by value, exactly as generateIndexBuffer() does for the same slots.
        // NOTE(review): if setData() copies sizeof(argument) bytes from a reference to its
        // argument, passing `&value` would upload the pointer bits instead of the count --
        // confirm against ComputeCommandEncoder::setData.
        encoder->setData(indexCountEncoded, 2);
        encoder->setData(primCount, 3);
        encoder->dispatch(threadgroupCount, threadsPerThreadgroup);
    }

    outIndexCount    = newIndexCount;
    outIndexOffset   = newOffset;
    outPrimitiveMode = getNewPrimitiveMode(indexBufferKey);
    outNewBuffer     = newBuffer;
    return angle::Result::Continue;
}
240
// Encodes a compute pass that generates an index buffer for a non-indexed draw of |indexCount|
// vertices starting at vertex |first|. The pipeline is the generating ("genIndexBuffer") variant
// specialized on |elementsType| and |primitiveMode|, with primitive restart disabled.
//
// Outputs:
//   outIndexCount    - number of indices in the generated buffer.
//   outIndexOffset   - offset of the generated indices inside |outNewBuffer|, as returned by
//                      the pool allocation.
//   outPrimitiveMode - primitive mode to draw the generated indices with.
//   outNewBuffer     - buffer allocated from mIndexBuffers that receives the indices.
//
// If the input describes zero primitives, the allocation and outputs are produced as usual but
// no compute work is encoded.
angle::Result ProvokingVertexHelper::generateIndexBuffer(ContextMtl *context,
                                                         size_t first,
                                                         size_t indexCount,
                                                         gl::PrimitiveMode primitiveMode,
                                                         gl::DrawElementsType elementsType,
                                                         size_t &outIndexCount,
                                                         size_t &outIndexOffset,
                                                         gl::PrimitiveMode &outPrimitiveMode,
                                                         mtl::BufferRef &outNewBuffer)
{
    // Describe the specialized kernel variant needed for this draw.
    mtl::ProvokingVertexComputePipelineDesc pipelineDesc;
    pipelineDesc.elementType             = (uint8_t)elementsType;
    pipelineDesc.primitiveMode           = primitiveMode;
    pipelineDesc.primitiveRestartEnabled = false;
    pipelineDesc.generateIndices         = true;

    // Work out how many primitives the draw holds and how many indices must be generated.
    uint indexBufferKey = buildIndexBufferKey(pipelineDesc);
    uint primCount      = primCountForIndexCount(indexBufferKey, (uint32_t)indexCount);
    uint newIndexCount  = indexCountForPrimCount(indexBufferKey, primCount);
    size_t indexSize    = gl::GetDrawElementsTypeSize(elementsType);

    size_t newIndexOffset = 0;
    mtl::BufferRef newBuffer;
    ANGLE_TRY(mIndexBuffers.allocate(context, newIndexCount * indexSize, nullptr, &newBuffer,
                                     &newIndexOffset));

    // With zero primitives the threadgroup width would be 0 and the threadgroup-count
    // computation below would divide by zero, so only encode when there is work to do.
    if (primCount > 0)
    {
        uint indexCountEncoded              = static_cast<uint>(indexCount);
        uint firstVertexEncoded             = static_cast<uint>(first);
        uint indexOffsetEncoded             = static_cast<uint>(newIndexOffset);
        const MTLSize threadsPerThreadgroup = MTLSizeMake(MIN(primCount, 64u), 1, 1);
        // One thread per primitive, rounded up to whole threadgroups.
        const MTLSize threadgroupCount = MTLSizeMake(
            (primCount + threadsPerThreadgroup.width - 1) / threadsPerThreadgroup.width, 1, 1);

        mtl::ComputeCommandEncoder *encoder =
            context->getComputeCommandEncoderWithoutEndingRenderEncoder();
        ANGLE_TRY(prepareCommandEncoderForDescriptor(context, encoder, pipelineDesc));
        encoder->setBufferForWrite(newBuffer, indexOffsetEncoded, 1);
        encoder->setData(indexCountEncoded, 2);
        encoder->setData(primCount, 3);
        encoder->setData(firstVertexEncoded, 4);
        encoder->dispatch(threadgroupCount, threadsPerThreadgroup);
    }

    outIndexCount    = newIndexCount;
    outIndexOffset   = newIndexOffset;
    outPrimitiveMode = getNewPrimitiveMode(indexBufferKey);
    outNewBuffer     = newBuffer;
    return angle::Result::Continue;
}
289
290}  // namespace rx
291