xref: /aosp_15_r20/external/mesa3d/src/gfxstream/codegen/scripts/cereal/marshaling.py (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1# Copyright 2018 Google LLC
2# SPDX-License-Identifier: MIT
3
4from copy import copy
5import hashlib, sys
6
7from .common.codegen import CodeGen, VulkanAPIWrapper
8from .common.vulkantypes import \
9        VulkanAPI, makeVulkanTypeSimple, iterateVulkanType, VulkanTypeIterator, Atom, FuncExpr, FuncExprVal, FuncLambda
10
11from .wrapperdefs import VulkanWrapperGenerator
12from .wrapperdefs import VULKAN_STREAM_VAR_NAME
13from .wrapperdefs import ROOT_TYPE_VAR_NAME, ROOT_TYPE_PARAM
14from .wrapperdefs import STREAM_RET_TYPE
15from .wrapperdefs import MARSHAL_INPUT_VAR_NAME
16from .wrapperdefs import UNMARSHAL_INPUT_VAR_NAME
17from .wrapperdefs import PARAMETERS_MARSHALING
18from .wrapperdefs import PARAMETERS_MARSHALING_GUEST
19from .wrapperdefs import STYPE_OVERRIDE
20from .wrapperdefs import STRUCT_EXTENSION_PARAM, STRUCT_EXTENSION_PARAM_FOR_WRITE, EXTENSION_SIZE_WITH_STREAM_FEATURES_API_NAME
21from .wrapperdefs import API_PREFIX_MARSHAL
22from .wrapperdefs import API_PREFIX_UNMARSHAL
23
24from .marshalingdefs import KNOWN_FUNCTION_OPCODES, CUSTOM_MARSHAL_TYPES
25
class VulkanMarshalingCodegen(VulkanTypeIterator):
    """Emit C++ statements that write Vulkan values to, or read them from, a stream.

    An instance is passed as the visitor to ``iterateVulkanType``; the ``on*``
    methods below are its callbacks and append code to ``self.cgen``.  With
    ``direction == "write"`` the generated code calls ``<stream>->write(...)``;
    with any other direction it calls ``<stream>->read(...)``.
    """

    def __init__(self,
                 cgen,
                 streamVarName,
                 rootTypeVarName,
                 inputVarName,
                 marshalPrefix,
                 direction = "write",
                 forApiOutput = False,
                 dynAlloc = False,
                 mapHandles = True,
                 handleMapOverwrites = False,
                 doFiltering = True):
        """
        Args:
            cgen: CodeGen receiving the emitted statements. It may be None here
                and assigned before use.
            streamVarName: C++ name of the stream object (accessed via ``->``).
            rootTypeVarName: C++ name of the root VkStructureType variable that
                is forwarded to nested marshaling calls.
            inputVarName: C++ name of the struct being (un)marshaled. Member
                accesses are generated relative to it.
            marshalPrefix: prefix of the generated per-type functions called
                for compound members and extension structs.
            direction: "write" emits write calls; anything else emits reads.
            forApiOutput: when True, ``onCheck`` emits nothing.
            dynAlloc: when reading, allocate pointee storage with
                ``<stream>->alloc`` before reading into it.
            mapHandles: route handle-typed values through ``handleMapping()``.
            handleMapOverwrites: when writing handles, map them in place
                instead of through a temporary uint64_t buffer.
            doFiltering: when False, filter guards are not emitted.
        """
        self.cgen = cgen
        self.direction = direction
        self.processSimple = "write" if self.direction == "write" else "read"
        self.forApiOutput = forApiOutput

        # Set by onCheck, cleared by endCheck: tracks whether an `if` is open.
        self.checked = False

        self.streamVarName = streamVarName
        self.rootTypeVarName = rootTypeVarName
        self.inputVarName = inputVarName
        self.marshalPrefix = marshalPrefix

        self.exprAccessor = lambda t: self.cgen.generalAccess(t, parentVarName = self.inputVarName, asPtr = True)
        self.exprValueAccessor = lambda t: self.cgen.generalAccess(t, parentVarName = self.inputVarName, asPtr = False)
        self.exprPrimitiveValueAccessor = lambda t: self.cgen.generalAccess(t, parentVarName = self.inputVarName, asPtr = False)
        self.lenAccessor = lambda t: self.cgen.generalLengthAccess(t, parentVarName = self.inputVarName)
        self.lenAccessorGuard = lambda t: self.cgen.generalLengthAccessGuard(
            t, parentVarName=self.inputVarName)
        self.filterVarAccessor = lambda t: self.cgen.filterVarAccess(t, parentVarName = self.inputVarName)

        self.dynAlloc = dynAlloc
        self.mapHandles = mapHandles
        self.handleMapOverwrites = handleMapOverwrites
        self.doFiltering = doFiltering

    def getTypeForStreaming(self, vulkanType):
        """Return a copy of ``vulkanType`` suitable for the stream call's cast.

        Values that are not pointer-accessible, and static arrays, are turned
        into their address-access form.  When reading, the result is
        additionally made non-const.
        """
        res = copy(vulkanType)

        if not vulkanType.accessibleAsPointer():
            res = res.getForAddressAccess()

        if vulkanType.staticArrExpr:
            res = res.getForAddressAccess()

        if self.direction == "write":
            return res
        return res.getForNonConstAccess()

    def makeCastExpr(self, vulkanType):
        """Return a C-style cast prefix such as ``(const VkFoo*)``."""
        return "(%s)" % (
            self.cgen.makeCTypeDecl(vulkanType, useParamName=False))

    def genStreamCall(self, vulkanType, toStreamExpr, sizeExpr):
        """Emit ``<stream>->write|read((cast)expr, size)``."""
        varname = self.streamVarName
        func = self.processSimple
        cast = self.makeCastExpr(self.getTypeForStreaming(vulkanType))

        self.cgen.stmt(
            "%s->%s(%s%s, %s)" % (varname, func, cast, toStreamExpr, sizeExpr))

    def genPrimitiveStreamCall(self, vulkanType, access):
        """Delegate streaming of a primitive value to ``cgen.streamPrimitive``."""
        varname = self.streamVarName

        self.cgen.streamPrimitive(
            self.typeInfo,
            varname,
            access,
            vulkanType,
            direction=self.direction)

    def genHandleMappingCall(self, vulkanType, access, lenAccess):
        """Emit code that maps handle(s) and streams them as 64-bit values.

        Args:
            vulkanType: handle type being streamed.
            access: C++ expression for the handle (or handle array).
            lenAccess: C++ element-count expression, or None for a single
                handle.  The literal "1" is also treated as a single handle.

        For arrays, a temporary uint64_t buffer is allocated from the stream
        and everything is wrapped in ``if (<len>)``.
        """
        if lenAccess is None:
            lenAccess = "1"
            handle64Bytes = "8"
        else:
            handle64Bytes = "%s * 8" % lenAccess

        isArray = lenAccess != "1"
        handle64Var = self.cgen.var()
        if isArray:
            self.cgen.beginIf(lenAccess)
            self.cgen.stmt("uint64_t* %s" % handle64Var)
            self.cgen.stmt(
                "%s->alloc((void**)&%s, %s * 8)" % \
                (self.streamVarName, handle64Var, lenAccess))
            handle64VarAccess = handle64Var
            handle64VarType = \
                makeVulkanTypeSimple(False, "uint64_t", 1, paramName=handle64Var)
        else:
            self.cgen.stmt("uint64_t %s" % handle64Var)
            handle64VarAccess = "&%s" % handle64Var
            handle64VarType = \
                makeVulkanTypeSimple(False, "uint64_t", 0, paramName=handle64Var)

        if self.direction == "write":
            if self.handleMapOverwrites:
                # Handles are rewritten in place, so they must already be
                # exactly 64 bits wide.
                self.cgen.stmt(
                    "static_assert(8 == sizeof(%s), \"handle map overwrite requires %s to be 8 bytes long\")" % \
                            (vulkanType.typeName, vulkanType.typeName))
                self.cgen.stmt(
                    "%s->handleMapping()->mapHandles_%s((%s*)%s, %s)" %
                    (self.streamVarName, vulkanType.typeName, vulkanType.typeName,
                    access, lenAccess))
                self.genStreamCall(vulkanType, access, "8 * %s" % lenAccess)
            else:
                self.cgen.stmt(
                    "%s->handleMapping()->mapHandles_%s_u64(%s, %s, %s)" %
                    (self.streamVarName, vulkanType.typeName,
                    access,
                    handle64VarAccess, lenAccess))
                self.genStreamCall(handle64VarType, handle64VarAccess, handle64Bytes)
        else:
            # Read the 64-bit values first, then map them back into handles.
            self.genStreamCall(handle64VarType, handle64VarAccess, handle64Bytes)
            self.cgen.stmt(
                "%s->handleMapping()->mapHandles_u64_%s(%s, %s%s, %s)" %
                (self.streamVarName, vulkanType.typeName,
                handle64VarAccess,
                self.makeCastExpr(vulkanType.getForNonConstAccess()), access,
                lenAccess))

        if isArray:
            self.cgen.endIf()

    def doAllocSpace(self, vulkanType):
        """When reading with dynAlloc, emit ``<stream>->alloc`` for the pointee."""
        if self.dynAlloc and self.direction == "read":
            access = self.exprAccessor(vulkanType)
            lenAccess = self.lenAccessor(vulkanType)
            sizeof = self.cgen.sizeofExpr( \
                         vulkanType.getForValueAccess())
            if lenAccess:
                bytesExpr = "%s * %s" % (lenAccess, sizeof)
            else:
                bytesExpr = sizeof

            self.cgen.stmt( \
                "%s->alloc((void**)&%s, %s)" %
                    (self.streamVarName,
                     access, bytesExpr))

    def getOptionalStringFeatureExpr(self, vulkanType):
        """Return the C++ feature-bit test protecting ``vulkanType``, or None."""
        streamFeature = vulkanType.getProtectStreamFeature()
        if streamFeature is None:
            return None
        return "%s->getFeatureBits() & %s" % (self.streamVarName, streamFeature)

    def onCheck(self, vulkanType):
        """Stream an optional pointer's value, then open ``if (<pointer>)``.

        The opened ``if`` is closed by ``endCheck``.  When reading without
        dynAlloc, the value is read into a local ``check_<name>``.  In that
        case, unless a stream feature protects the value, an error is printed
        if the local pointer is set while the streamed one is null.
        """
        if self.forApiOutput:
            return

        featureExpr = self.getOptionalStringFeatureExpr(vulkanType)

        self.checked = True

        access = self.exprAccessor(vulkanType)

        needConsistencyCheck = False

        self.cgen.line("// WARNING PTR CHECK")
        if (self.dynAlloc and self.direction == "read") or self.direction == "write":
            checkAccess = access
        else:
            checkName = "check_%s" % vulkanType.paramName
            self.cgen.stmt("%s %s" % (
                self.cgen.makeCTypeDecl(vulkanType, useParamName = False), checkName))
            checkAccess = checkName
            needConsistencyCheck = True

        if featureExpr is not None:
            self.cgen.beginIf(featureExpr)

        self.genPrimitiveStreamCall(
            vulkanType,
            checkAccess)

        if featureExpr is not None:
            self.cgen.endIf()

        if featureExpr is not None:
            self.cgen.beginIf("(!(%s) || %s)" % (featureExpr, access))
        else:
            self.cgen.beginIf(access)

        if needConsistencyCheck and featureExpr is None:
            self.cgen.beginIf("!(%s)" % checkName)
            self.cgen.stmt(
                "fprintf(stderr, \"fatal: %s inconsistent between guest and host\\n\")" % (access))
            self.cgen.endIf()

    def onCheckWithNullOptionalStringFeature(self, vulkanType):
        """Like onCheck, but nested in an ``if`` on the null-optional-strings feature bit."""
        self.cgen.beginIf("%s->getFeatureBits() & VULKAN_STREAM_FEATURE_NULL_OPTIONAL_STRINGS_BIT" % self.streamVarName)
        self.onCheck(vulkanType)

    def endCheckWithNullOptionalStringFeature(self, vulkanType):
        """Close the check and feature ``if``s, then open the ``else`` branch."""
        self.endCheck(vulkanType)
        self.cgen.endIf()
        self.cgen.beginElse()

    def finalCheckWithNullOptionalStringFeature(self, vulkanType):
        """Close the ``else`` branch opened by endCheckWithNullOptionalStringFeature."""
        self.cgen.endElse()

    def endCheck(self, vulkanType):
        """Close the ``if`` opened by onCheck, if one is open."""
        if self.checked:
            self.cgen.endIf()
            self.checked = False

    def genFilterFunc(self, filterfunc, env):
        """Translate a filter expression tree into a C++ expression string.

        Nodes are FuncExpr (function application), FuncLambda (rendered as a
        C++ lambda) and FuncExprVal (an Atom or a literal).  Atoms bound in
        ``env`` are replaced by their access expression unless shadowed by a
        parameter of an enclosing lambda.  The recognised function names are:
        not, eq, and, or, bitwise_and, getfield and if.  Any other name is
        emitted as a plain C++ call.

        Raises:
            TypeError: if a node of an unsupported type is encountered.
        """
        binaryOps = {
            "eq": "==",
            "and": "&&",
            "or": "||",
            "bitwise_and": "&",
        }

        def get_ptrlevels(atomname, lambdaEnv):
            # Lambda parameters are treated as values (no dereference).
            if atomname in lambdaEnv:
                return 0
            if env.get(atomname) is not None:
                return self.getPointerIndirectionLevels(atomname)
            return 0

        def do_expratom(atomname, lambdaEnv):
            if atomname in lambdaEnv:
                return atomname
            if env.get(atomname) is not None:
                return self.getEnvAccessExpr(atomname)
            return atomname

        def do_func(expr, lambdaEnv):
            fnamestr = expr.name.name
            args = expr.args
            if fnamestr == "not":
                return "!(%s)" % loop(args[0], lambdaEnv)
            if fnamestr in binaryOps:
                return "(%s %s %s)" % (loop(args[0], lambdaEnv),
                                       binaryOps[fnamestr],
                                       loop(args[1], lambdaEnv))
            if fnamestr == "getfield":
                ptrlevels = get_ptrlevels(args[0].val.name, lambdaEnv)
                if ptrlevels == 0:
                    return "%s.%s" % (loop(args[0], lambdaEnv), args[1].val)
                return "(%s(%s)).%s" % ("*" * ptrlevels, loop(args[0], lambdaEnv), args[1].val)
            if fnamestr == "if":
                return "((%s) ? (%s) : (%s))" % (loop(args[0], lambdaEnv),
                                                 loop(args[1], lambdaEnv),
                                                 loop(args[2], lambdaEnv))
            return "%s(%s)" % (fnamestr, ", ".join(loop(e, lambdaEnv) for e in args))

        def do_exprval(expr, lambdaEnv):
            expratom = expr.val
            if type(expratom) is Atom:
                return do_expratom(expratom.name, lambdaEnv)
            return "%s" % expratom

        def do_lambda(expr, lambdaEnv):
            # Parameters shadow outer bindings inside the body; outer lambda
            # parameters remain visible.
            newEnv = dict(lambdaEnv)
            for p in expr.vs:
                newEnv[p.name] = p.typ
            paramDecls = ", ".join("%s %s" % (p.typ, p.name) for p in expr.vs)
            return "[](%s) { return %s; }" % (paramDecls, loop(expr.body, newEnv))

        def loop(expr, lambdaEnv):
            exprType = type(expr)
            if exprType is FuncExpr:
                return do_func(expr, lambdaEnv)
            if exprType is FuncLambda:
                return do_lambda(expr, lambdaEnv)
            if exprType is FuncExprVal:
                return do_exprval(expr, lambdaEnv)
            # Previously this returned None, which was pasted into the
            # generated C++ as the text "None".
            raise TypeError("unsupported filter expression node: %r" % (expr,))

        return loop(filterfunc, {})

    def beginFilterGuard(self, vulkanType):
        """Open an ``if`` that guards streaming of a filtered member.

        The guard is skipped when the member has no filterVar or filtering is
        disabled.  With neither filterVals nor filterFunc, the filter variable
        is used directly as the condition.  Otherwise the condition is
        ``!(<ignored-handles feature>) || (<vals> [|| <func>])``.  Must be
        paired with ``endFilterGuard``.
        """
        if vulkanType.filterVar is None or not self.doFiltering:
            return

        filterVarAccess = self.getEnvAccessExpr(vulkanType.filterVar)

        filterValsExpr = None
        filterFuncExpr = None

        filterFeature = "%s->getFeatureBits() & VULKAN_STREAM_FEATURE_IGNORED_HANDLES_BIT" % self.streamVarName

        if vulkanType.filterVals is not None:
            filterValsExpr = " || ".join(
                "(%s == %s)" % (filterval, filterVarAccess)
                for filterval in vulkanType.filterVals)

        if vulkanType.filterFunc is not None:
            filterFuncExpr = self.genFilterFunc(vulkanType.filterFunc, self.currentStructInfo.environment)

        if filterValsExpr is not None and filterFuncExpr is not None:
            # Previously this case built the expression but never opened the
            # `if`, leaving endFilterGuard to close a block that was never opened.
            filterExpr = "%s || %s" % (filterValsExpr, filterFuncExpr)
        elif filterValsExpr is not None:
            filterExpr = filterValsExpr
        elif filterFuncExpr is not None:
            filterExpr = filterFuncExpr
        else:
            # Assume is bool
            self.cgen.beginIf(filterVarAccess)
            return

        self.cgen.beginIf("(!(%s) || (%s))" % (filterFeature, filterExpr))

    def endFilterGuard(self, vulkanType, cleanupExpr=None):
        """Close the ``if`` opened by beginFilterGuard.

        If ``cleanupExpr`` is given, it is emitted in an ``else`` branch.
        """
        if vulkanType.filterVar is None or not self.doFiltering:
            return

        self.cgen.endIf()
        if cleanupExpr is not None:
            self.cgen.beginElse()
            self.cgen.stmt(cleanupExpr)
            self.cgen.endElse()

    def _findStructMember(self, varName):
        """Return the first member of the current struct named ``varName``.

        Raises IndexError if there is no such member.
        """
        return [member for member in self.currentStructInfo.members
                if member.paramName == varName][0]

    def getEnvAccessExpr(self, varName):
        """Return the C++ expression for a bound name, or None if unbound.

        Struct members are accessed through the input variable; other
        bindings are referenced by their bare name.
        """
        parentEnvEntry = self.currentStructInfo.environment.get(varName, None)
        if parentEnvEntry is None:
            return None

        if parentEnvEntry["structmember"]:
            return self.exprValueAccessor(self._findStructMember(varName))
        return varName

    def getPointerIndirectionLevels(self, varName):
        """Return the pointer depth of a struct-member binding, else 0."""
        parentEnvEntry = self.currentStructInfo.environment.get(varName, None)
        if parentEnvEntry is not None and parentEnvEntry["structmember"]:
            return self._findStructMember(varName).pointerIndirectionLevels
        return 0

    def onCompoundType(self, vulkanType):
        """Emit a call to the per-type marshaling function.

        For arrays the call is made once per element in a ``for`` loop.
        Values of bound environment names are appended as extra arguments.
        """
        access = self.exprAccessor(vulkanType)
        lenAccess = self.lenAccessor(vulkanType)
        lenAccessGuard = self.lenAccessorGuard(vulkanType)

        self.beginFilterGuard(vulkanType)

        if vulkanType.pointerIndirectionLevels > 0:
            self.doAllocSpace(vulkanType)

        if lenAccess is not None:
            if lenAccessGuard is not None:
                self.cgen.beginIf(lenAccessGuard)
            loopVar = "i"
            access = "%s + %s" % (access, loopVar)
            forInit = "uint32_t %s = 0" % loopVar
            forCond = "%s < (uint32_t)%s" % (loopVar, lenAccess)
            forIncr = "++%s" % loopVar
            self.cgen.beginFor(forInit, forCond, forIncr)

        accessWithCast = "%s(%s)" % (self.makeCastExpr(
            self.getTypeForStreaming(vulkanType)), access)

        callParams = [self.streamVarName, self.rootTypeVarName, accessWithCast]

        for localName in vulkanType.binds.values():
            callParams.append(self.getEnvAccessExpr(localName))

        self.cgen.funcCall(None, self.marshalPrefix + vulkanType.typeName,
                           callParams)

        if lenAccess is not None:
            self.cgen.endFor()
            if lenAccessGuard is not None:
                self.cgen.endIf()

        if self.direction == "read":
            # `access` may carry "+ i" here; use the plain member access.
            self.endFilterGuard(vulkanType, "%s = 0" % self.exprAccessor(vulkanType))
        else:
            self.endFilterGuard(vulkanType)

    def onString(self, vulkanType):
        """Emit putString (write) or loadStringInPlace (read)."""
        access = self.exprAccessor(vulkanType)

        if self.direction == "write":
            self.cgen.stmt("%s->putString(%s)" % (self.streamVarName, access))
        else:
            castExpr = \
                self.makeCastExpr( \
                    self.getTypeForStreaming( \
                        vulkanType.getForAddressAccess()))

            self.cgen.stmt( \
                "%s->loadStringInPlace(%s&%s)" % (self.streamVarName, castExpr, access))

    def onStringArray(self, vulkanType):
        """Emit saveStringArray (write) or loadStringArrayInPlace (read)."""
        access = self.exprAccessor(vulkanType)
        lenAccess = self.lenAccessor(vulkanType)

        if self.direction == "write":
            self.cgen.stmt("saveStringArray(%s, %s, %s)" % (self.streamVarName,
                                                            access, lenAccess))
        else:
            castExpr = \
                self.makeCastExpr( \
                    self.getTypeForStreaming( \
                        vulkanType.getForAddressAccess()))

            self.cgen.stmt("%s->loadStringArrayInPlace(%s&%s)" % (self.streamVarName, castExpr, access))

    def onStaticArr(self, vulkanType):
        """Stream a fixed-size array as one block of ``len * sizeof`` bytes."""
        access = self.exprValueAccessor(vulkanType)
        lenAccess = self.lenAccessor(vulkanType)
        finalLenExpr = "%s * %s" % (lenAccess, self.cgen.sizeofExpr(vulkanType))
        self.genStreamCall(vulkanType, access, finalLenExpr)

    # Old version VkEncoder may have some sType values conflict with VkDecoder
    # of new versions. For host decoder, it should not carry the incorrect old
    # sType values to the |forUnmarshaling| struct. Instead it should overwrite
    # the sType value.
    def overwriteSType(self, vulkanType):
        """When reading, force sType for parent types listed in STYPE_OVERRIDE."""
        if self.direction == "read":
            sTypeParam = copy(vulkanType)
            sTypeParam.paramName = "sType"
            sTypeAccess = self.exprAccessor(sTypeParam)

            typeName = vulkanType.parent.typeName
            if typeName in STYPE_OVERRIDE:
                self.cgen.stmt("%s = %s" %
                               (sTypeAccess, STYPE_OVERRIDE[typeName]))

    def onStructExtension(self, vulkanType):
        """Emit (un)marshaling of a pNext-style extension chain member.

        The root type is initialised from sType if it is still
        VK_STRUCTURE_TYPE_MAX_ENUM.  When reading with dynAlloc, a 32-bit
        size is read first.  If that size is nonzero, the extension's
        VkStructureType is read and the struct is allocated at its full size
        before the generic ``extension_struct`` function is called.
        """
        self.overwriteSType(vulkanType)

        sTypeParam = copy(vulkanType)
        sTypeParam.paramName = "sType"

        access = self.exprAccessor(vulkanType)
        sizeVar = "%s_size" % vulkanType.paramName

        if self.direction == "read":
            castedAccessExpr = "(%s)(%s)" % ("void*", access)
        else:
            castedAccessExpr = access

        sTypeAccess = self.exprAccessor(sTypeParam)
        self.cgen.beginIf("%s == VK_STRUCTURE_TYPE_MAX_ENUM" %
                          self.rootTypeVarName)
        self.cgen.stmt("%s = %s" % (self.rootTypeVarName, sTypeAccess))
        self.cgen.endIf()

        if self.direction == "read" and self.dynAlloc:
            self.cgen.stmt("size_t %s" % sizeVar)
            self.cgen.stmt("%s = %s->getBe32()" % \
                (sizeVar, self.streamVarName))
            self.cgen.stmt("%s = nullptr" % access)
            self.cgen.beginIf(sizeVar)
            # Read just the sType first to learn how much space to allocate.
            self.cgen.stmt( \
                    "%s->alloc((void**)&%s, sizeof(VkStructureType))" %
                    (self.streamVarName, access))

            self.genStreamCall(vulkanType, access, "sizeof(VkStructureType)")
            self.cgen.stmt("VkStructureType extType = *(VkStructureType*)(%s)" % access)
            self.cgen.stmt( \
                "%s->alloc((void**)&%s, %s(%s->getFeatureBits(), %s, %s))" %
                (self.streamVarName, access, EXTENSION_SIZE_WITH_STREAM_FEATURES_API_NAME, self.streamVarName, self.rootTypeVarName, access))
            self.cgen.stmt("*(VkStructureType*)%s = extType" % access)

            self.cgen.funcCall(None, self.marshalPrefix + "extension_struct",
                               [self.streamVarName, self.rootTypeVarName, castedAccessExpr])
            self.cgen.endIf()
        else:
            self.cgen.funcCall(None, self.marshalPrefix + "extension_struct",
                               [self.streamVarName, self.rootTypeVarName, castedAccessExpr])

    def onPointer(self, vulkanType):
        """Stream the pointee of a pointer member.

        Handles are mapped through handleMapping().  Non-ABI-portable types
        are streamed element by element.  Everything else is streamed as one
        raw byte block.
        """
        access = self.exprAccessor(vulkanType)

        lenAccess = self.lenAccessor(vulkanType)
        lenAccessGuard = self.lenAccessorGuard(vulkanType)

        self.beginFilterGuard(vulkanType)
        self.doAllocSpace(vulkanType)

        if vulkanType.isHandleType() and self.mapHandles:
            self.genHandleMappingCall(vulkanType, access, lenAccess)
        elif self.typeInfo.isNonAbiPortableType(vulkanType.typeName):
            if lenAccess is not None:
                if lenAccessGuard is not None:
                    self.cgen.beginIf(lenAccessGuard)
                self.cgen.beginFor("uint32_t i = 0", "i < (uint32_t)%s" % lenAccess, "++i")
                self.genPrimitiveStreamCall(vulkanType.getForValueAccess(), "%s[i]" % access)
                self.cgen.endFor()
                if lenAccessGuard is not None:
                    self.cgen.endIf()
            else:
                self.genPrimitiveStreamCall(vulkanType.getForValueAccess(), "(*%s)" % access)
        else:
            if lenAccess is not None:
                finalLenExpr = "%s * %s" % (
                    lenAccess, self.cgen.sizeofExpr(vulkanType.getForValueAccess()))
            else:
                finalLenExpr = "%s" % (
                    self.cgen.sizeofExpr(vulkanType.getForValueAccess()))
            self.genStreamCall(vulkanType, access, finalLenExpr)

        if self.direction == "read":
            self.endFilterGuard(vulkanType, "%s = 0" % access)
        else:
            self.endFilterGuard(vulkanType)

    def onValue(self, vulkanType):
        """Stream a plain (non-pointer) member value."""
        self.beginFilterGuard(vulkanType)

        if vulkanType.isHandleType() and self.mapHandles:
            access = self.exprAccessor(vulkanType)
            self.genHandleMappingCall(
                vulkanType.getForAddressAccess(), access, "1")
        elif self.typeInfo.isNonAbiPortableType(vulkanType.typeName):
            access = self.exprPrimitiveValueAccessor(vulkanType)
            self.genPrimitiveStreamCall(vulkanType, access)
        else:
            access = self.exprAccessor(vulkanType)
            self.genStreamCall(vulkanType, access, self.cgen.sizeofExpr(vulkanType))

        self.endFilterGuard(vulkanType)

    def streamLetParameter(self, structInfo, letParamInfo):
        """Declare a 'let' binding (initialised to 1) and stream it.

        The binding is only streamed when the ignored-handles feature bit is
        set.  When writing, it is first assigned the translated value of its
        ``body`` expression.  ``structInfo`` is not used by this method.
        """
        filterFeature = "%s->getFeatureBits() & VULKAN_STREAM_FEATURE_IGNORED_HANDLES_BIT" % self.streamVarName
        self.cgen.stmt("%s %s = 1" % (letParamInfo.typeName, letParamInfo.paramName))

        self.cgen.beginIf(filterFeature)

        if self.direction == "write":
            bodyExpr = self.currentStructInfo.environment[letParamInfo.paramName]["body"]
            self.cgen.stmt("%s = %s" % (letParamInfo.paramName, self.genFilterFunc(bodyExpr, self.currentStructInfo.environment)))

        self.genPrimitiveStreamCall(letParamInfo, letParamInfo.paramName)

        self.cgen.endIf()
599        self.cgen.endIf()
600
601
602class VulkanMarshaling(VulkanWrapperGenerator):
603
604    def __init__(self, module, typeInfo, variant="host"):
605        VulkanWrapperGenerator.__init__(self, module, typeInfo)
606
607        self.cgenHeader = CodeGen()
608        self.cgenImpl = CodeGen()
609
610        self.variant = variant
611
612        self.currentFeature = None
613        self.apiOpcodes = {}
614        self.dynAlloc = self.variant != "guest"
615
616        if self.variant == "guest":
617            self.marshalingParams = PARAMETERS_MARSHALING_GUEST
618        else:
619            self.marshalingParams = PARAMETERS_MARSHALING
620
621        self.writeCodegen = \
622            VulkanMarshalingCodegen(
623                None,
624                VULKAN_STREAM_VAR_NAME,
625                ROOT_TYPE_VAR_NAME,
626                MARSHAL_INPUT_VAR_NAME,
627                API_PREFIX_MARSHAL,
628                direction = "write")
629
630        self.readCodegen = \
631            VulkanMarshalingCodegen(
632                None,
633                VULKAN_STREAM_VAR_NAME,
634                ROOT_TYPE_VAR_NAME,
635                UNMARSHAL_INPUT_VAR_NAME,
636                API_PREFIX_UNMARSHAL,
637                direction = "read",
638                dynAlloc=self.dynAlloc)
639
640        self.knownDefs = {}
641
642        # Begin Vulkan API opcodes from something high
643        # that is not going to interfere with renderControl
644        # opcodes
645        self.beginOpcodeOld = 20000
646        self.endOpcodeOld = 30000
647
648        self.beginOpcode = 200000000
649        self.endOpcode = 300000000
650        self.knownOpcodes = set()
651
652        self.extensionMarshalPrototype = \
653            VulkanAPI(API_PREFIX_MARSHAL + "extension_struct",
654                      STREAM_RET_TYPE,
655                      self.marshalingParams +
656                      [STRUCT_EXTENSION_PARAM])
657
658        self.extensionUnmarshalPrototype = \
659            VulkanAPI(API_PREFIX_UNMARSHAL + "extension_struct",
660                      STREAM_RET_TYPE,
661                      self.marshalingParams +
662                      [STRUCT_EXTENSION_PARAM_FOR_WRITE])
663
664    def onBegin(self,):
665        VulkanWrapperGenerator.onBegin(self)
666        self.module.appendImpl(self.cgenImpl.makeFuncDecl(self.extensionMarshalPrototype))
667        self.module.appendImpl(self.cgenImpl.makeFuncDecl(self.extensionUnmarshalPrototype))
668
669    def onBeginFeature(self, featureName, featureType):
670        VulkanWrapperGenerator.onBeginFeature(self, featureName, featureType)
671        self.currentFeature = featureName
672
673    def onGenType(self, typeXml, name, alias):
674        VulkanWrapperGenerator.onGenType(self, typeXml, name, alias)
675
676        if name in self.knownDefs:
677            return
678
679        category = self.typeInfo.categoryOf(name)
680
681        if category in ["struct", "union"] and alias:
682            self.module.appendHeader(
683                self.cgenHeader.makeFuncAlias(API_PREFIX_MARSHAL + name,
684                                              API_PREFIX_MARSHAL + alias))
685            self.module.appendHeader(
686                self.cgenHeader.makeFuncAlias(API_PREFIX_UNMARSHAL + name,
687                                              API_PREFIX_UNMARSHAL + alias))
688
689        if category in ["struct", "union"] and not alias:
690
691            structInfo = self.typeInfo.structs[name]
692
693            marshalParams = self.marshalingParams + \
694                [makeVulkanTypeSimple(True, name, 1, MARSHAL_INPUT_VAR_NAME)]
695
696            freeParams = []
697            letParams = []
698
699            for (envname, bindingInfo) in list(sorted(structInfo.environment.items(), key = lambda kv: kv[0])):
700                if None == bindingInfo["binding"]:
701                    freeParams.append(makeVulkanTypeSimple(True, bindingInfo["type"], 0, envname))
702                else:
703                    if not bindingInfo["structmember"]:
704                        letParams.append(makeVulkanTypeSimple(True, bindingInfo["type"], 0, envname))
705
706            marshalPrototype = \
707                VulkanAPI(API_PREFIX_MARSHAL + name,
708                          STREAM_RET_TYPE,
709                          marshalParams + freeParams)
710
711            marshalPrototypeNoFilter = \
712                VulkanAPI(API_PREFIX_MARSHAL + name,
713                          STREAM_RET_TYPE,
714                          marshalParams)
715
716            def structMarshalingCustom(cgen):
717                self.writeCodegen.cgen = cgen
718                self.writeCodegen.currentStructInfo = structInfo
719                self.writeCodegen.cgen.stmt("(void)%s" % ROOT_TYPE_VAR_NAME)
720
721                marshalingCode = \
722                    CUSTOM_MARSHAL_TYPES[name]["common"] + \
723                    CUSTOM_MARSHAL_TYPES[name]["marshaling"].format(
724                        streamVarName=self.writeCodegen.streamVarName,
725                        rootTypeVarName=self.writeCodegen.rootTypeVarName,
726                        inputVarName=self.writeCodegen.inputVarName,
727                        newInputVarName=self.writeCodegen.inputVarName + "_new")
728                for line in marshalingCode.split('\n'):
729                    cgen.line(line)
730
731            def structMarshalingDef(cgen):
732                self.writeCodegen.cgen = cgen
733                self.writeCodegen.currentStructInfo = structInfo
734                self.writeCodegen.cgen.stmt("(void)%s" % ROOT_TYPE_VAR_NAME)
735
736                if category == "struct":
737                    # marshal 'let' parameters first
738                    for letp in letParams:
739                        self.writeCodegen.streamLetParameter(self.typeInfo, letp)
740
741                    for member in structInfo.members:
742                        iterateVulkanType(self.typeInfo, member, self.writeCodegen)
743                if category == "union":
744                    iterateVulkanType(self.typeInfo, structInfo.members[0], self.writeCodegen)
745
746            def structMarshalingDefNoFilter(cgen):
747                self.writeCodegen.cgen = cgen
748                self.writeCodegen.currentStructInfo = structInfo
749                self.writeCodegen.doFiltering = False
750                self.writeCodegen.cgen.stmt("(void)%s" % ROOT_TYPE_VAR_NAME)
751
752                if category == "struct":
753                    # marshal 'let' parameters first
754                    for letp in letParams:
755                        self.writeCodegen.streamLetParameter(self.typeInfo, letp)
756
757                    for member in structInfo.members:
758                        iterateVulkanType(self.typeInfo, member, self.writeCodegen)
759                if category == "union":
760                    iterateVulkanType(self.typeInfo, structInfo.members[0], self.writeCodegen)
761                self.writeCodegen.doFiltering = True
762
763            self.module.appendHeader(
764                self.cgenHeader.makeFuncDecl(marshalPrototype))
765
766            if name in CUSTOM_MARSHAL_TYPES and CUSTOM_MARSHAL_TYPES[name].get("marshaling"):
767                self.module.appendImpl(
768                    self.cgenImpl.makeFuncImpl(
769                        marshalPrototype, structMarshalingCustom))
770            else:
771                self.module.appendImpl(
772                    self.cgenImpl.makeFuncImpl(
773                        marshalPrototype, structMarshalingDef))
774
775            if freeParams != []:
776                self.module.appendHeader(
777                    self.cgenHeader.makeFuncDecl(marshalPrototypeNoFilter))
778                self.module.appendImpl(
779                    self.cgenImpl.makeFuncImpl(
780                        marshalPrototypeNoFilter, structMarshalingDefNoFilter))
781
782            unmarshalPrototype = \
783                VulkanAPI(API_PREFIX_UNMARSHAL + name,
784                          STREAM_RET_TYPE,
785                          self.marshalingParams + [makeVulkanTypeSimple(False, name, 1, UNMARSHAL_INPUT_VAR_NAME)] + freeParams)
786
787            unmarshalPrototypeNoFilter = \
788                VulkanAPI(API_PREFIX_UNMARSHAL + name,
789                          STREAM_RET_TYPE,
790                          self.marshalingParams + [makeVulkanTypeSimple(False, name, 1, UNMARSHAL_INPUT_VAR_NAME)])
791
792            def structUnmarshalingCustom(cgen):
793                self.readCodegen.cgen = cgen
794                self.readCodegen.currentStructInfo = structInfo
795                self.writeCodegen.cgen.stmt("(void)%s" % ROOT_TYPE_VAR_NAME)
796
797                unmarshalingCode = \
798                    CUSTOM_MARSHAL_TYPES[name]["common"] + \
799                    CUSTOM_MARSHAL_TYPES[name]["unmarshaling"].format(
800                        streamVarName=self.readCodegen.streamVarName,
801                        rootTypeVarName=self.readCodegen.rootTypeVarName,
802                        inputVarName=self.readCodegen.inputVarName,
803                        newInputVarName=self.readCodegen.inputVarName + "_new")
804                for line in unmarshalingCode.split('\n'):
805                    cgen.line(line)
806
807            def structUnmarshalingDef(cgen):
808                self.readCodegen.cgen = cgen
809                self.readCodegen.currentStructInfo = structInfo
810                self.writeCodegen.cgen.stmt("(void)%s" % ROOT_TYPE_VAR_NAME)
811
812                if category == "struct":
813                    # unmarshal 'let' parameters first
814                    for letp in letParams:
815                        self.readCodegen.streamLetParameter(self.typeInfo, letp)
816
817                    for member in structInfo.members:
818                        iterateVulkanType(self.typeInfo, member, self.readCodegen)
819                if category == "union":
820                    iterateVulkanType(self.typeInfo, structInfo.members[0], self.readCodegen)
821
822            def structUnmarshalingDefNoFilter(cgen):
823                self.readCodegen.cgen = cgen
824                self.readCodegen.currentStructInfo = structInfo
825                self.readCodegen.doFiltering = False
826                self.writeCodegen.cgen.stmt("(void)%s" % ROOT_TYPE_VAR_NAME)
827
828                if category == "struct":
829                    # unmarshal 'let' parameters first
830                    for letp in letParams:
831                        iterateVulkanType(self.typeInfo, letp, self.readCodegen)
832                    for member in structInfo.members:
833                        iterateVulkanType(self.typeInfo, member, self.readCodegen)
834                if category == "union":
835                    iterateVulkanType(self.typeInfo, structInfo.members[0], self.readCodegen)
836                self.readCodegen.doFiltering = True
837
838            self.module.appendHeader(
839                self.cgenHeader.makeFuncDecl(unmarshalPrototype))
840
841            if name in CUSTOM_MARSHAL_TYPES and CUSTOM_MARSHAL_TYPES[name].get("unmarshaling"):
842                self.module.appendImpl(
843                    self.cgenImpl.makeFuncImpl(
844                        unmarshalPrototype, structUnmarshalingCustom))
845            else:
846                self.module.appendImpl(
847                    self.cgenImpl.makeFuncImpl(
848                        unmarshalPrototype, structUnmarshalingDef))
849
850            if freeParams != []:
851                self.module.appendHeader(
852                    self.cgenHeader.makeFuncDecl(unmarshalPrototypeNoFilter))
853                self.module.appendImpl(
854                    self.cgenImpl.makeFuncImpl(
855                        unmarshalPrototypeNoFilter, structUnmarshalingDefNoFilter))
856
857    def onGenCmd(self, cmdinfo, name, alias):
858        VulkanWrapperGenerator.onGenCmd(self, cmdinfo, name, alias)
859        if name in KNOWN_FUNCTION_OPCODES:
860            opcode = KNOWN_FUNCTION_OPCODES[name]
861        else:
862            hashCode = hashlib.sha256(name.encode()).hexdigest()[:8]
863            hashInt = int(hashCode, 16)
864            opcode = self.beginOpcode + hashInt % (self.endOpcode - self.beginOpcode)
865            hasHashCollision = False
866            while opcode in self.knownOpcodes:
867                hasHashCollision = True
868                opcode += 1
869            if hasHashCollision:
870                print("Hash collision occurred on function '{}'. "
871                      "Please add the following line to marshalingdefs.py:".format(name), file=sys.stderr)
872                print("----------------------", file=sys.stderr)
873                print("    \"{}\": {},".format(name, opcode), file=sys.stderr)
874                print("----------------------", file=sys.stderr)
875
876        self.module.appendHeader(
877            "#define OP_%s %d\n" % (name, opcode))
878        self.apiOpcodes[name] = (opcode, self.currentFeature)
879        self.knownOpcodes.add(opcode)
880
881    def doExtensionStructMarshalingCodegen(self, cgen, retType, extParam, forEach, funcproto, direction):
882        accessVar = "structAccess"
883        sizeVar = "currExtSize"
884        cgen.stmt("VkInstanceCreateInfo* %s = (VkInstanceCreateInfo*)(%s)" % (accessVar, extParam.paramName))
885        cgen.stmt("size_t %s = %s(%s->getFeatureBits(), %s, %s)" % (sizeVar,
886                                                                    EXTENSION_SIZE_WITH_STREAM_FEATURES_API_NAME, VULKAN_STREAM_VAR_NAME, ROOT_TYPE_VAR_NAME, extParam.paramName))
887
888        cgen.beginIf("!%s && %s" % (sizeVar, extParam.paramName))
889
890        cgen.line("// unknown struct extension; skip and call on its pNext field");
891        cgen.funcCall(None, funcproto.name, [
892                      "vkStream", ROOT_TYPE_VAR_NAME, "(void*)%s->pNext" % accessVar])
893        cgen.stmt("return")
894
895        cgen.endIf()
896        cgen.beginElse()
897
898        cgen.line("// known or null extension struct")
899
900        if direction == "write":
901            cgen.stmt("vkStream->putBe32(%s)" % sizeVar)
902        elif not self.dynAlloc:
903            cgen.stmt("vkStream->getBe32()");
904
905        cgen.beginIf("!%s" % (sizeVar))
906        cgen.line("// exit if this was a null extension struct (size == 0 in this branch)")
907        cgen.stmt("return")
908        cgen.endIf()
909
910        cgen.endIf()
911
912        # Now we can do stream stuff
913        if direction == "write":
914            cgen.stmt("vkStream->write(%s, sizeof(VkStructureType))" % extParam.paramName)
915        elif not self.dynAlloc:
916            cgen.stmt("uint64_t pNext_placeholder")
917            placeholderAccess = "(&pNext_placeholder)"
918            cgen.stmt("vkStream->read((void*)(&pNext_placeholder), sizeof(VkStructureType))")
919            cgen.stmt("(void)pNext_placeholder")
920
921        def fatalDefault(cgen):
922            cgen.line("// fatal; the switch is only taken if the extension struct is known")
923            if self.variant != "guest":
924                cgen.stmt("fprintf(stderr, \" %s, Unhandled Vulkan structure type %s [%d], aborting.\\n\", __func__, string_VkStructureType(VkStructureType(structType)), structType)")
925            cgen.stmt("abort()")
926            pass
927
928        self.emitForEachStructExtension(
929            cgen,
930            retType,
931            extParam,
932            forEach,
933            defaultEmit=fatalDefault,
934            rootTypeVar=ROOT_TYPE_PARAM)
935
936    def onEnd(self,):
937        VulkanWrapperGenerator.onEnd(self)
938
939        def forEachExtensionMarshal(ext, castedAccess, cgen):
940            cgen.funcCall(None, API_PREFIX_MARSHAL + ext.name,
941                          [VULKAN_STREAM_VAR_NAME, ROOT_TYPE_VAR_NAME, castedAccess])
942
943        def forEachExtensionUnmarshal(ext, castedAccess, cgen):
944            cgen.funcCall(None, API_PREFIX_UNMARSHAL + ext.name,
945                          [VULKAN_STREAM_VAR_NAME, ROOT_TYPE_VAR_NAME, castedAccess])
946
947        self.module.appendImpl(
948            self.cgenImpl.makeFuncImpl(
949                self.extensionMarshalPrototype,
950                lambda cgen: self.doExtensionStructMarshalingCodegen(
951                    cgen,
952                    STREAM_RET_TYPE,
953                    STRUCT_EXTENSION_PARAM,
954                    forEachExtensionMarshal,
955                    self.extensionMarshalPrototype,
956                    "write")))
957
958        self.module.appendImpl(
959            self.cgenImpl.makeFuncImpl(
960                self.extensionUnmarshalPrototype,
961                lambda cgen: self.doExtensionStructMarshalingCodegen(
962                    cgen,
963                    STREAM_RET_TYPE,
964                    STRUCT_EXTENSION_PARAM_FOR_WRITE,
965                    forEachExtensionUnmarshal,
966                    self.extensionUnmarshalPrototype,
967                    "read")))
968
969        opcode2stringPrototype = \
970            VulkanAPI("api_opcode_to_string",
971                          makeVulkanTypeSimple(True, "char", 1, "none"),
972                          [ makeVulkanTypeSimple(True, "uint32_t", 0, "opcode") ])
973
974        self.module.appendHeader(
975            self.cgenHeader.makeFuncDecl(opcode2stringPrototype))
976
977        def emitOpcode2StringImpl(apiOpcodes, cgen):
978            cgen.line("switch(opcode)")
979            cgen.beginBlock()
980
981            currFeature = None
982
983            for (name, (opcodeNum, feature)) in sorted(apiOpcodes.items(), key = lambda x : x[1][0]):
984                if not currFeature:
985                    cgen.leftline("#ifdef %s" % feature)
986                    currFeature = feature
987
988                if currFeature and feature != currFeature:
989                    cgen.leftline("#endif")
990                    cgen.leftline("#ifdef %s" % feature)
991                    currFeature = feature
992
993                cgen.line("case OP_%s:" % name)
994                cgen.beginBlock()
995                cgen.stmt("return \"OP_%s\"" % name)
996                cgen.endBlock()
997
998            if currFeature:
999                cgen.leftline("#endif")
1000
1001            cgen.line("default:")
1002            cgen.beginBlock()
1003            cgen.stmt("return \"OP_UNKNOWN_API_CALL\"")
1004            cgen.endBlock()
1005
1006            cgen.endBlock()
1007
1008        self.module.appendImpl(
1009            self.cgenImpl.makeFuncImpl(
1010                opcode2stringPrototype,
1011                lambda cgen: emitOpcode2StringImpl(self.apiOpcodes, cgen)))
1012
1013        self.module.appendHeader(
1014            "#define OP_vkFirst_old %d\n" % (self.beginOpcodeOld))
1015        self.module.appendHeader(
1016            "#define OP_vkLast_old %d\n" % (self.endOpcodeOld))
1017        self.module.appendHeader(
1018            "#define OP_vkFirst %d\n" % (self.beginOpcode))
1019        self.module.appendHeader(
1020            "#define OP_vkLast %d\n" % (self.endOpcode))
1021