xref: /aosp_15_r20/external/mesa3d/src/gfxstream/codegen/scripts/cereal/counting.py (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1# Copyright 2023 Google LLC
2# SPDX-License-Identifier: MIT
3
4from copy import copy
5
6from .common.codegen import CodeGen
7from .common.vulkantypes import \
8        VulkanAPI, makeVulkanTypeSimple, iterateVulkanType, VulkanTypeIterator, Atom, FuncExpr, FuncExprVal, FuncLambda
9
10from .wrapperdefs import VulkanWrapperGenerator
11from .wrapperdefs import ROOT_TYPE_VAR_NAME, ROOT_TYPE_PARAM
12from .wrapperdefs import STRUCT_EXTENSION_PARAM, STRUCT_EXTENSION_PARAM_FOR_WRITE, EXTENSION_SIZE_WITH_STREAM_FEATURES_API_NAME
13
14class VulkanCountingCodegen(VulkanTypeIterator):
15    def __init__(self, cgen, featureBitsVar, toCountVar, countVar, rootTypeVar, prefix, forApiOutput=False, mapHandles=True, handleMapOverwrites=False, doFiltering=True):
16        self.cgen = cgen
17        self.featureBitsVar = featureBitsVar
18        self.toCountVar = toCountVar
19        self.rootTypeVar = rootTypeVar
20        self.countVar = countVar
21        self.prefix = prefix
22        self.forApiOutput = forApiOutput
23
24        self.exprAccessor = lambda t: self.cgen.generalAccess(t, parentVarName = self.toCountVar, asPtr = True)
25        self.exprValueAccessor = lambda t: self.cgen.generalAccess(t, parentVarName = self.toCountVar, asPtr = False)
26        self.exprPrimitiveValueAccessor = lambda t: self.cgen.generalAccess(t, parentVarName = self.toCountVar, asPtr = False)
27
28        self.lenAccessor = lambda t: self.cgen.generalLengthAccess(t, parentVarName = self.toCountVar)
29        self.lenAccessorGuard = lambda t: self.cgen.generalLengthAccessGuard(t, parentVarName = self.toCountVar)
30        self.filterVarAccessor = lambda t: self.cgen.filterVarAccess(t, parentVarName = self.toCountVar)
31
32        self.checked = False
33
34        self.mapHandles = mapHandles
35        self.handleMapOverwrites = handleMapOverwrites
36        self.doFiltering = doFiltering
37
38    def getTypeForStreaming(self, vulkanType):
39        res = copy(vulkanType)
40
41        if not vulkanType.accessibleAsPointer():
42            res = res.getForAddressAccess()
43
44        if vulkanType.staticArrExpr:
45            res = res.getForAddressAccess()
46
47        return res
48
49    def makeCastExpr(self, vulkanType):
50        return "(%s)" % (
51            self.cgen.makeCTypeDecl(vulkanType, useParamName=False))
52
53    def genCount(self, sizeExpr):
54        self.cgen.stmt("*%s += %s" % (self.countVar, sizeExpr))
55
56    def genPrimitiveStreamCall(self, vulkanType):
57        self.genCount(str(self.cgen.countPrimitive(
58            self.typeInfo,
59            vulkanType)))
60
61    def genHandleMappingCall(self, vulkanType, access, lenAccess):
62
63        if lenAccess is None:
64            lenAccess = "1"
65            handle64Bytes = "8"
66        else:
67            handle64Bytes = "%s * 8" % lenAccess
68
69        handle64Var = self.cgen.var()
70        if lenAccess != "1":
71            self.cgen.beginIf(lenAccess)
72            # self.cgen.stmt("uint64_t* %s" % handle64Var)
73            # self.cgen.stmt(
74                # "%s->alloc((void**)&%s, %s * 8)" % \
75                # (self.streamVarName, handle64Var, lenAccess))
76            handle64VarAccess = handle64Var
77            handle64VarType = \
78                makeVulkanTypeSimple(False, "uint64_t", 1, paramName=handle64Var)
79        else:
80            self.cgen.stmt("uint64_t %s" % handle64Var)
81            handle64VarAccess = "&%s" % handle64Var
82            handle64VarType = \
83                makeVulkanTypeSimple(False, "uint64_t", 0, paramName=handle64Var)
84
85        if self.handleMapOverwrites:
86            # self.cgen.stmt(
87                # "static_assert(8 == sizeof(%s), \"handle map overwrite requres %s to be 8 bytes long\")" % \
88                        # (vulkanType.typeName, vulkanType.typeName))
89            # self.cgen.stmt(
90                # "%s->handleMapping()->mapHandles_%s((%s*)%s, %s)" %
91                # (self.streamVarName, vulkanType.typeName, vulkanType.typeName,
92                # access, lenAccess))
93            self.genCount("8 * %s" % lenAccess)
94        else:
95            # self.cgen.stmt(
96                # "%s->handleMapping()->mapHandles_%s_u64(%s, %s, %s)" %
97                # (self.streamVarName, vulkanType.typeName,
98                # access,
99                # handle64VarAccess, lenAccess))
100            self.genCount(handle64Bytes)
101
102        if lenAccess != "1":
103            self.cgen.endIf()
104
105    def doAllocSpace(self, vulkanType):
106        pass
107
108    def getOptionalStringFeatureExpr(self, vulkanType):
109        feature = vulkanType.getProtectStreamFeature()
110        if feature is None:
111            return None
112        return "%s & %s" % (self.featureBitsVar, feature)
113
114    def onCheck(self, vulkanType):
115
116        if self.forApiOutput:
117            return
118
119        featureExpr = self.getOptionalStringFeatureExpr(vulkanType);
120
121        self.checked = True
122
123        access = self.exprAccessor(vulkanType)
124
125        needConsistencyCheck = False
126
127        self.cgen.line("// WARNING PTR CHECK")
128        checkAccess = self.exprAccessor(vulkanType)
129        addrExpr = "&" + checkAccess
130        sizeExpr = self.cgen.sizeofExpr(vulkanType)
131
132        if featureExpr is not None:
133            self.cgen.beginIf(featureExpr)
134
135        self.genPrimitiveStreamCall(
136            vulkanType)
137
138        if featureExpr is not None:
139            self.cgen.endIf()
140
141        if featureExpr is not None:
142            self.cgen.beginIf("(!(%s) || %s)" % (featureExpr, access))
143        else:
144            self.cgen.beginIf(access)
145
146        if needConsistencyCheck and featureExpr is None:
147            self.cgen.beginIf("!(%s)" % checkName)
148            self.cgen.stmt(
149                "fprintf(stderr, \"fatal: %s inconsistent between guest and host\\n\")" % (access))
150            self.cgen.endIf()
151
152
153    def onCheckWithNullOptionalStringFeature(self, vulkanType):
154        self.cgen.beginIf("%s & VULKAN_STREAM_FEATURE_NULL_OPTIONAL_STRINGS_BIT" % self.featureBitsVar)
155        self.onCheck(vulkanType)
156
157    def endCheckWithNullOptionalStringFeature(self, vulkanType):
158        self.endCheck(vulkanType)
159        self.cgen.endIf()
160        self.cgen.beginElse()
161
162    def finalCheckWithNullOptionalStringFeature(self, vulkanType):
163        self.cgen.endElse()
164
165    def endCheck(self, vulkanType):
166
167        if self.checked:
168            self.cgen.endIf()
169            self.checked = False
170
171    def genFilterFunc(self, filterfunc, env):
172
173        def loop(expr, lambdaEnv={}):
174            def do_func(expr):
175                fnamestr = expr.name.name
176                if "not" == fnamestr:
177                    return "!(%s)" % (loop(expr.args[0], lambdaEnv))
178                if "eq" == fnamestr:
179                    return "(%s == %s)" % (loop(expr.args[0], lambdaEnv), loop(expr.args[1], lambdaEnv))
180                if "and" == fnamestr:
181                    return "(%s && %s)" % (loop(expr.args[0], lambdaEnv), loop(expr.args[1], lambdaEnv))
182                if "or" == fnamestr:
183                    return "(%s || %s)" % (loop(expr.args[0], lambdaEnv), loop(expr.args[1], lambdaEnv))
184                if "bitwise_and" == fnamestr:
185                    return "(%s & %s)" % (loop(expr.args[0], lambdaEnv), loop(expr.args[1], lambdaEnv))
186                if "getfield" == fnamestr:
187                    ptrlevels = get_ptrlevels(expr.args[0].val.name)
188                    if ptrlevels == 0:
189                        return "%s.%s" % (loop(expr.args[0], lambdaEnv), expr.args[1].val)
190                    else:
191                        return "(%s(%s)).%s" % ("*" * ptrlevels, loop(expr.args[0], lambdaEnv), expr.args[1].val)
192
193                if "if" == fnamestr:
194                    return "((%s) ? (%s) : (%s))" % (loop(expr.args[0], lambdaEnv), loop(expr.args[1], lambdaEnv), loop(expr.args[2], lambdaEnv))
195
196                return "%s(%s)" % (fnamestr, ", ".join(map(lambda e: loop(e, lambdaEnv), expr.args)))
197
198            def do_expratom(atomname, lambdaEnv= {}):
199                if lambdaEnv.get(atomname, None) is not None:
200                    return atomname
201
202                enventry = env.get(atomname, None)
203                if None != enventry:
204                    return self.getEnvAccessExpr(atomname)
205                return atomname
206
207            def get_ptrlevels(atomname, lambdaEnv= {}):
208                if lambdaEnv.get(atomname, None) is not None:
209                    return 0
210
211                enventry = env.get(atomname, None)
212                if None != enventry:
213                    return self.getPointerIndirectionLevels(atomname)
214
215                return 0
216
217            def do_exprval(expr, lambdaEnv= {}):
218                expratom = expr.val
219
220                if Atom == type(expratom):
221                    return do_expratom(expratom.name, lambdaEnv)
222
223                return "%s" % expratom
224
225            def do_lambda(expr, lambdaEnv= {}):
226                params = expr.vs
227                body = expr.body
228                newEnv = {}
229
230                for (k, v) in lambdaEnv.items():
231                    newEnv[k] = v
232
233                for p in params:
234                    newEnv[p.name] = p.typ
235
236                return "[](%s) { return %s; }" % (", ".join(list(map(lambda p: "%s %s" % (p.typ, p.name), params))), loop(body, lambdaEnv=newEnv))
237
238            if FuncExpr == type(expr):
239                return do_func(expr)
240            if FuncLambda == type(expr):
241                return do_lambda(expr)
242            elif FuncExprVal == type(expr):
243                return do_exprval(expr)
244
245        return loop(filterfunc)
246
247    def beginFilterGuard(self, vulkanType):
248        if vulkanType.filterVar == None:
249            return
250
251        if self.doFiltering == False:
252            return
253
254        filterVarAccess = self.getEnvAccessExpr(vulkanType.filterVar)
255
256        filterValsExpr = None
257        filterFuncExpr = None
258        filterExpr = None
259
260        filterFeature = "%s & VULKAN_STREAM_FEATURE_IGNORED_HANDLES_BIT" % self.featureBitsVar
261
262        if None != vulkanType.filterVals:
263            filterValsExpr = " || ".join(map(lambda filterval: "(%s == %s)" % (filterval, filterVarAccess), vulkanType.filterVals))
264
265        if None != vulkanType.filterFunc:
266            filterFuncExpr = self.genFilterFunc(vulkanType.filterFunc, self.currentStructInfo.environment)
267
268        if None != filterValsExpr and None != filterFuncExpr:
269            filterExpr = "%s || %s" % (filterValsExpr, filterFuncExpr)
270        elif None == filterValsExpr and None == filterFuncExpr:
271            # Assume is bool
272            self.cgen.beginIf(filterVarAccess)
273        elif None != filterValsExpr:
274            self.cgen.beginIf("(!(%s) || (%s))" % (filterFeature, filterValsExpr))
275        elif None != filterFuncExpr:
276            self.cgen.beginIf("(!(%s) || (%s))" % (filterFeature, filterFuncExpr))
277
278    def endFilterGuard(self, vulkanType, cleanupExpr=None):
279        if vulkanType.filterVar == None:
280            return
281
282        if self.doFiltering == False:
283            return
284
285        if cleanupExpr == None:
286            self.cgen.endIf()
287        else:
288            self.cgen.endIf()
289            self.cgen.beginElse()
290            self.cgen.stmt(cleanupExpr)
291            self.cgen.endElse()
292
293    def getEnvAccessExpr(self, varName):
294        parentEnvEntry = self.currentStructInfo.environment.get(varName, None)
295
296        if parentEnvEntry != None:
297            isParentMember = parentEnvEntry["structmember"]
298
299            if isParentMember:
300                envAccess = self.exprValueAccessor(list(filter(lambda member: member.paramName == varName, self.currentStructInfo.members))[0])
301            else:
302                envAccess = varName
303            return envAccess
304
305        return None
306
307    def getPointerIndirectionLevels(self, varName):
308        parentEnvEntry = self.currentStructInfo.environment.get(varName, None)
309
310        if parentEnvEntry != None:
311            isParentMember = parentEnvEntry["structmember"]
312
313            if isParentMember:
314                return list(filter(lambda member: member.paramName == varName, self.currentStructInfo.members))[0].pointerIndirectionLevels
315            else:
316                return 0
317            return 0
318
319        return 0
320
321
322    def onCompoundType(self, vulkanType):
323
324        access = self.exprAccessor(vulkanType)
325        lenAccess = self.lenAccessor(vulkanType)
326        lenAccessGuard = self.lenAccessorGuard(vulkanType)
327
328        self.beginFilterGuard(vulkanType)
329
330        if vulkanType.pointerIndirectionLevels > 0:
331            self.doAllocSpace(vulkanType)
332
333        if lenAccess is not None:
334            if lenAccessGuard is not None:
335                self.cgen.beginIf(lenAccessGuard)
336            loopVar = "i"
337            access = "%s + %s" % (access, loopVar)
338            forInit = "uint32_t %s = 0" % loopVar
339            forCond = "%s < (uint32_t)%s" % (loopVar, lenAccess)
340            forIncr = "++%s" % loopVar
341            self.cgen.beginFor(forInit, forCond, forIncr)
342
343        accessWithCast = "%s(%s)" % (self.makeCastExpr(
344            self.getTypeForStreaming(vulkanType)), access)
345
346        callParams = [self.featureBitsVar,
347                      self.rootTypeVar, accessWithCast, self.countVar]
348
349        for (bindName, localName) in vulkanType.binds.items():
350            callParams.append(self.getEnvAccessExpr(localName))
351
352        self.cgen.funcCall(None, self.prefix + vulkanType.typeName,
353                           callParams)
354
355        if lenAccess is not None:
356            self.cgen.endFor()
357            if lenAccessGuard is not None:
358                self.cgen.endIf()
359
360        self.endFilterGuard(vulkanType)
361
362    def onString(self, vulkanType):
363        access = self.exprAccessor(vulkanType)
364        self.genCount("sizeof(uint32_t) + (%s ? strlen(%s) : 0)" % (access, access))
365
366    def onStringArray(self, vulkanType):
367        access = self.exprAccessor(vulkanType)
368        lenAccess = self.lenAccessor(vulkanType)
369        lenAccessGuard = self.lenAccessorGuard(vulkanType)
370
371        self.genCount("sizeof(uint32_t)")
372        if lenAccessGuard is not None:
373            self.cgen.beginIf(lenAccessGuard)
374        self.cgen.beginFor("uint32_t i = 0", "i < %s" % lenAccess, "++i")
375        self.cgen.stmt("size_t l = %s[i] ? strlen(%s[i]) : 0" % (access, access))
376        self.genCount("sizeof(uint32_t) + (%s[i] ? strlen(%s[i]) : 0)" % (access, access))
377        self.cgen.endFor()
378        if lenAccessGuard is not None:
379            self.cgen.endIf()
380
381    def onStaticArr(self, vulkanType):
382        access = self.exprValueAccessor(vulkanType)
383        lenAccess = self.lenAccessor(vulkanType)
384        lenAccessGuard = self.lenAccessorGuard(vulkanType)
385
386        if lenAccessGuard is not None:
387            self.cgen.beginIf(lenAccessGuard)
388        finalLenExpr = "%s * %s" % (lenAccess, self.cgen.sizeofExpr(vulkanType))
389        if lenAccessGuard is not None:
390            self.cgen.endIf()
391        self.genCount(finalLenExpr)
392
393    def onStructExtension(self, vulkanType):
394        sTypeParam = copy(vulkanType)
395        sTypeParam.paramName = "sType"
396
397        access = self.exprAccessor(vulkanType)
398        sizeVar = "%s_size" % vulkanType.paramName
399
400        castedAccessExpr = access
401
402        sTypeAccess = self.exprAccessor(sTypeParam)
403        self.cgen.beginIf("%s == VK_STRUCTURE_TYPE_MAX_ENUM" %
404                          self.rootTypeVar)
405        self.cgen.stmt("%s = %s" % (self.rootTypeVar, sTypeAccess))
406        self.cgen.endIf()
407
408        self.cgen.funcCall(None, self.prefix + "extension_struct",
409                           [self.featureBitsVar, self.rootTypeVar, castedAccessExpr, self.countVar])
410
411
412    def onPointer(self, vulkanType):
413        access = self.exprAccessor(vulkanType)
414
415        lenAccess = self.lenAccessor(vulkanType)
416        lenAccessGuard = self.lenAccessorGuard(vulkanType)
417
418        self.beginFilterGuard(vulkanType)
419        self.doAllocSpace(vulkanType)
420
421        if vulkanType.isHandleType() and self.mapHandles:
422            self.genHandleMappingCall(vulkanType, access, lenAccess)
423        else:
424            if self.typeInfo.isNonAbiPortableType(vulkanType.typeName):
425                if lenAccess is not None:
426                    if lenAccessGuard is not None:
427                        self.cgen.beginIf(lenAccessGuard)
428                    self.cgen.beginFor("uint32_t i = 0", "i < (uint32_t)%s" % lenAccess, "++i")
429                    self.genPrimitiveStreamCall(vulkanType.getForValueAccess())
430                    self.cgen.endFor()
431                    if lenAccessGuard is not None:
432                        self.cgen.endIf()
433                else:
434                    self.genPrimitiveStreamCall(vulkanType.getForValueAccess())
435            else:
436                if lenAccess is not None:
437                    needLenAccessGuard = True
438                    finalLenExpr = "%s * %s" % (
439                        lenAccess, self.cgen.sizeofExpr(vulkanType.getForValueAccess()))
440                else:
441                    needLenAccessGuard = False
442                    finalLenExpr = "%s" % (
443                        self.cgen.sizeofExpr(vulkanType.getForValueAccess()))
444                if needLenAccessGuard and lenAccessGuard is not None:
445                    self.cgen.beginIf(lenAccessGuard)
446                self.genCount(finalLenExpr)
447                if needLenAccessGuard and lenAccessGuard is not None:
448                    self.cgen.endIf()
449
450        self.endFilterGuard(vulkanType)
451
452    def onValue(self, vulkanType):
453        self.beginFilterGuard(vulkanType)
454
455        if vulkanType.isHandleType() and self.mapHandles:
456            access = self.exprAccessor(vulkanType)
457            self.genHandleMappingCall(
458                vulkanType.getForAddressAccess(), access, "1")
459        elif self.typeInfo.isNonAbiPortableType(vulkanType.typeName):
460            access = self.exprPrimitiveValueAccessor(vulkanType)
461            self.genPrimitiveStreamCall(vulkanType)
462        else:
463            access = self.exprAccessor(vulkanType)
464            self.genCount(self.cgen.sizeofExpr(vulkanType))
465
466        self.endFilterGuard(vulkanType)
467
468    def streamLetParameter(self, structInfo, letParamInfo):
469        filterFeature = "%s & VULKAN_STREAM_FEATURE_IGNORED_HANDLES_BIT" % (self.featureBitsVar)
470        self.cgen.stmt("%s %s = 1" % (letParamInfo.typeName, letParamInfo.paramName))
471
472        self.cgen.beginIf(filterFeature)
473
474        bodyExpr = self.currentStructInfo.environment[letParamInfo.paramName]["body"]
475        self.cgen.stmt("%s = %s" % (letParamInfo.paramName, self.genFilterFunc(bodyExpr, self.currentStructInfo.environment)))
476
477        self.genPrimitiveStreamCall(letParamInfo)
478
479        self.cgen.endIf()
480
481class VulkanCounting(VulkanWrapperGenerator):
482
483    def __init__(self, module, typeInfo):
484        VulkanWrapperGenerator.__init__(self, module, typeInfo)
485
486        self.codegen = CodeGen()
487
488        self.featureBitsVar = "featureBits"
489        self.featureBitsVarType = makeVulkanTypeSimple(False, "uint32_t", 0, self.featureBitsVar)
490        self.countingPrefix = "count_"
491        self.countVars = ["toCount", "count"]
492        self.countVarType = makeVulkanTypeSimple(False, "size_t", 1, self.countVars[1])
493        self.voidType = makeVulkanTypeSimple(False, "void", 0)
494        self.rootTypeVar = ROOT_TYPE_VAR_NAME
495
496        self.countingCodegen = \
497            VulkanCountingCodegen(
498                self.codegen,
499                self.featureBitsVar,
500                self.countVars[0],
501                self.countVars[1],
502                self.rootTypeVar,
503                self.countingPrefix)
504
505        self.knownDefs = {}
506
507        self.extensionCountingPrototype = \
508            VulkanAPI(self.countingPrefix + "extension_struct",
509                      self.voidType,
510                      [self.featureBitsVarType,
511                       ROOT_TYPE_PARAM,
512                       STRUCT_EXTENSION_PARAM,
513                       self.countVarType])
514
515    def onBegin(self,):
516        VulkanWrapperGenerator.onBegin(self)
517        self.module.appendImpl(self.codegen.makeFuncDecl(
518            self.extensionCountingPrototype))
519
520    def onGenType(self, typeXml, name, alias):
521        VulkanWrapperGenerator.onGenType(self, typeXml, name, alias)
522
523        if name in self.knownDefs:
524            return
525
526        category = self.typeInfo.categoryOf(name)
527
528        if category in ["struct", "union"] and alias:
529            # TODO(liyl): might not work if freeParams != []
530            self.module.appendHeader(
531                self.codegen.makeFuncAlias(self.countingPrefix + name,
532                                           self.countingPrefix + alias))
533
534        if category in ["struct", "union"] and not alias:
535
536            structInfo = self.typeInfo.structs[name]
537
538            freeParams = []
539            letParams = []
540
541            for (envname, bindingInfo) in list(sorted(structInfo.environment.items(), key = lambda kv: kv[0])):
542                if None == bindingInfo["binding"]:
543                    freeParams.append(makeVulkanTypeSimple(True, bindingInfo["type"], 0, envname))
544                else:
545                    if not bindingInfo["structmember"]:
546                        letParams.append(makeVulkanTypeSimple(True, bindingInfo["type"], 0, envname))
547
548            typeFromName = \
549                lambda varname: \
550                    makeVulkanTypeSimple(True, name, 1, varname)
551
552            countingParams = \
553                [makeVulkanTypeSimple(False, "uint32_t", 0, self.featureBitsVar),
554                 ROOT_TYPE_PARAM,
555                 typeFromName(self.countVars[0]),
556                 makeVulkanTypeSimple(False, "size_t", 1, self.countVars[1])]
557
558            countingPrototype = \
559                VulkanAPI(self.countingPrefix + name,
560                          self.voidType,
561                          countingParams + freeParams)
562
563            countingPrototypeNoFilter = \
564                VulkanAPI(self.countingPrefix + name,
565                          self.voidType,
566                          countingParams)
567
568            def structCountingDef(cgen):
569                self.countingCodegen.cgen = cgen
570                self.countingCodegen.currentStructInfo = structInfo
571                cgen.stmt("(void)%s" % self.featureBitsVar);
572                cgen.stmt("(void)%s" % self.rootTypeVar);
573                cgen.stmt("(void)%s" % self.countVars[0]);
574                cgen.stmt("(void)%s" % self.countVars[1]);
575
576                if category == "struct":
577                    # marshal 'let' parameters first
578                    for letp in letParams:
579                        self.countingCodegen.streamLetParameter(self.typeInfo, letp)
580
581                    for member in structInfo.members:
582                        iterateVulkanType(self.typeInfo, member, self.countingCodegen)
583                if category == "union":
584                    iterateVulkanType(self.typeInfo, structInfo.members[0], self.countingCodegen)
585
586            def structCountingDefNoFilter(cgen):
587                self.countingCodegen.cgen = cgen
588                self.countingCodegen.currentStructInfo = structInfo
589                self.countingCodegen.doFiltering = False
590                cgen.stmt("(void)%s" % self.featureBitsVar);
591                cgen.stmt("(void)%s" % self.rootTypeVar);
592                cgen.stmt("(void)%s" % self.countVars[0]);
593                cgen.stmt("(void)%s" % self.countVars[1]);
594
595                if category == "struct":
596                    # marshal 'let' parameters first
597                    for letp in letParams:
598                        self.countingCodegen.streamLetParameter(self.typeInfo, letp)
599
600                    for member in structInfo.members:
601                        iterateVulkanType(self.typeInfo, member, self.countingCodegen)
602                if category == "union":
603                    iterateVulkanType(self.typeInfo, structInfo.members[0], self.countingCodegen)
604
605                self.countingCodegen.doFiltering = True
606
607            self.module.appendHeader(
608                self.codegen.makeFuncDecl(countingPrototype))
609            self.module.appendImpl(
610                self.codegen.makeFuncImpl(countingPrototype, structCountingDef))
611
612            if freeParams != []:
613                self.module.appendHeader(
614                    self.cgenHeader.makeFuncDecl(countingPrototypeNoFilter))
615                self.module.appendImpl(
616                    self.cgenImpl.makeFuncImpl(
617                        countingPrototypeNoFilter, structCountingDefNoFilter))
618
619    def onGenCmd(self, cmdinfo, name, alias):
620        VulkanWrapperGenerator.onGenCmd(self, cmdinfo, name, alias)
621
622    def doExtensionStructCountCodegen(self, cgen, extParam, forEach, funcproto):
623        accessVar = "structAccess"
624        sizeVar = "currExtSize"
625        cgen.stmt("VkInstanceCreateInfo* %s = (VkInstanceCreateInfo*)(%s)" % (accessVar, extParam.paramName))
626        cgen.stmt("size_t %s = %s(%s, %s, %s)" % (sizeVar, EXTENSION_SIZE_WITH_STREAM_FEATURES_API_NAME,
627                                                  self.featureBitsVar, ROOT_TYPE_VAR_NAME, extParam.paramName))
628
629        cgen.beginIf("!%s && %s" % (sizeVar, extParam.paramName))
630
631        cgen.line("// unknown struct extension; skip and call on its pNext field");
632        cgen.funcCall(None, funcproto.name, [
633                      self.featureBitsVar, ROOT_TYPE_VAR_NAME, "(void*)%s->pNext" % accessVar, self.countVars[1]])
634        cgen.stmt("return")
635
636        cgen.endIf()
637        cgen.beginElse()
638
639        cgen.line("// known or null extension struct")
640
641        cgen.stmt("*%s += sizeof(uint32_t)" % self.countVars[1])
642
643        cgen.beginIf("!%s" % (sizeVar))
644        cgen.line("// exit if this was a null extension struct (size == 0 in this branch)")
645        cgen.stmt("return")
646        cgen.endIf()
647
648        cgen.endIf()
649
650        cgen.stmt("*%s += sizeof(VkStructureType)" % self.countVars[1])
651
652        def fatalDefault(cgen):
653            cgen.line("// fatal; the switch is only taken if the extension struct is known");
654            cgen.stmt("abort()")
655            pass
656
657        self.emitForEachStructExtension(
658            cgen,
659            makeVulkanTypeSimple(False, "void", 0, "void"),
660            extParam,
661            forEach,
662            defaultEmit=fatalDefault,
663            rootTypeVar=ROOT_TYPE_PARAM)
664
665    def onEnd(self,):
666        VulkanWrapperGenerator.onEnd(self)
667
668        def forEachExtensionCounting(ext, castedAccess, cgen):
669            cgen.funcCall(None, self.countingPrefix + ext.name,
670                          [self.featureBitsVar, self.rootTypeVar, castedAccess, self.countVars[1]])
671
672        self.module.appendImpl(
673            self.codegen.makeFuncImpl(
674                self.extensionCountingPrototype,
675                lambda cgen: self.doExtensionStructCountCodegen(
676                    cgen,
677                    STRUCT_EXTENSION_PARAM,
678                    forEachExtensionCounting,
679                    self.extensionCountingPrototype)))
680