xref: /aosp_15_r20/external/mesa3d/src/gfxstream/codegen/scripts/cereal/testing.py (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1# Copyright 2018 Google LLC
2# SPDX-License-Identifier: MIT
3
4from copy import copy
5
6from .common.codegen import CodeGen
7from .common.vulkantypes import \
8        VulkanAPI, makeVulkanTypeSimple, iterateVulkanType, VulkanTypeIterator
9
10from .wrapperdefs import VulkanWrapperGenerator
11from .wrapperdefs import EQUALITY_VAR_NAMES
12from .wrapperdefs import EQUALITY_ON_FAIL_VAR
13from .wrapperdefs import EQUALITY_ON_FAIL_VAR_TYPE
14from .wrapperdefs import EQUALITY_RET_TYPE
15from .wrapperdefs import API_PREFIX_EQUALITY
16from .wrapperdefs import STRUCT_EXTENSION_PARAM, STRUCT_EXTENSION_PARAM2
17
class VulkanEqualityCodegen(VulkanTypeIterator):
    """Visitor that emits C code comparing two instances of a Vulkan type.

    An instance is passed to ``iterateVulkanType`` as the per-member visitor;
    each ``on*`` callback writes comparison statements through ``self.cgen``.
    When a comparison fails, the generated code calls the function named by
    ``onFailCompareVar`` with a string literal of the form
    ``"<lhs field access> (Error: <message>)"``.

    Args:
        cgen: code generator used to emit statements. It may be ``None`` at
            construction time and assigned later; the accessor lambdas built
            in ``__init__`` read ``self.cgen`` only when they are called.
        inputVars: sequence whose first two entries are the variable names of
            the lhs and rhs objects being compared.
        onFailCompareVar: name of the failure callback in the generated code.
        prefix: string prepended to a type name to form the name of that
            type's generated comparison function.
    """

    def __init__(self, cgen, inputVars, onFailCompareVar, prefix):
        self.cgen = cgen
        self.inputVars = inputVars
        self.onFailCompareVar = onFailCompareVar
        self.prefix = prefix

        def makeAccess(varName, asPtr = True):
            return lambda t: self.cgen.generalAccess(t, parentVarName = varName, asPtr = asPtr)

        def makeLengthAccess(varName):
            return lambda t: self.cgen.generalLengthAccess(t, parentVarName = varName)

        def makeLengthAccessGuard(varName):
            return lambda t: self.cgen.generalLengthAccessGuard(t, parentVarName=varName)

        self.exprAccessorLhs = makeAccess(self.inputVars[0])
        self.exprAccessorRhs = makeAccess(self.inputVars[1])

        self.exprAccessorValueLhs = makeAccess(self.inputVars[0], asPtr = False)
        self.exprAccessorValueRhs = makeAccess(self.inputVars[1], asPtr = False)

        self.lenAccessorLhs = makeLengthAccess(self.inputVars[0])
        self.lenAccessorRhs = makeLengthAccess(self.inputVars[1])

        self.lenAccessGuardLhs = makeLengthAccessGuard(self.inputVars[0])
        self.lenAccessGuardRhs = makeLengthAccessGuard(self.inputVars[1])

        # True while an `if` opened by onCheck is waiting to be closed by
        # endCheck.
        self.checked = False

    def getTypeForCompare(self, vulkanType):
        """Return a copy of ``vulkanType`` adjusted for address access.

        ``getForAddressAccess()`` is applied once when the type is not
        accessible as a pointer, and once more when it is a static array.
        """
        res = copy(vulkanType)

        if not vulkanType.accessibleAsPointer():
            res = res.getForAddressAccess()

        if vulkanType.staticArrExpr:
            res = res.getForAddressAccess()

        return res

    def makeCastExpr(self, vulkanType):
        """Return a C cast prefix ``(<type>)`` for ``vulkanType``."""
        return "(%s)" % (
            self.cgen.makeCTypeDecl(vulkanType, useParamName=False))

    def makeEqualExpr(self, lhs, rhs):
        """Return the C expression ``(lhs) == (rhs)``."""
        return "(%s) == (%s)" % (lhs, rhs)

    def makeEqualBufExpr(self, lhs, rhs, size):
        """Return a C expression that is true when ``size`` bytes match."""
        return "(memcmp(%s, %s, %s) == 0)" % (lhs, rhs, size)

    def makeEqualStringExpr(self, lhs, rhs):
        """Return a C expression that is true when two C strings match."""
        return "(strcmp(%s, %s) == 0)" % (lhs, rhs)

    def makeBothNotNullExpr(self, lhs, rhs):
        """Return a C expression that is true when both operands are truthy."""
        return "(%s) && (%s)" % (lhs, rhs)

    def makeBothNullExpr(self, lhs, rhs):
        """Return a C expression that is true when both operands are falsy."""
        return "!(%s) && !(%s)" % (lhs, rhs)

    def compareWithConsequence(self, compareExpr, vulkanType, errMsg=""):
        """Emit ``if (!(compareExpr)) { onFail("<lhs> (Error: errMsg)"); }``.

        The lhs value-access expression is pasted verbatim into a C string
        literal.
        NOTE(review): this assumes the access expression never contains
        ``"`` or ``\\`` characters.
        """
        self.cgen.stmt("if (!(%s)) { %s(\"%s (Error: %s)\"); }" %
                       (compareExpr, self.onFailCompareVar,
                        self.exprAccessorValueLhs(vulkanType), errMsg))

    def onCheck(self, vulkanType):
        """Start a null-checked region for a field that needs a null check.

        Emits a check that both sides have the same nullness (the failure
        message calls this an optional field). For non-``void`` types it
        then opens an ``if`` so that the subsequent comparison only runs when
        both sides are non-null; ``endCheck`` closes it. ``void`` pointers
        get only the nullness check.
        """
        accessLhs = self.exprAccessorLhs(vulkanType)
        accessRhs = self.exprAccessorRhs(vulkanType)

        bothNull = self.makeBothNullExpr(accessLhs, accessRhs)
        bothNotNull = self.makeBothNotNullExpr(accessLhs, accessRhs)
        nullMatchExpr = "(%s) || (%s)" % (bothNull, bothNotNull)

        self.compareWithConsequence(
            nullMatchExpr,
            vulkanType,
            "Mismatch in optional field")

        if vulkanType.typeName == "void":
            return

        self.cgen.beginIf("%s && %s" % (accessLhs, accessRhs))
        # Mark the region open only once an `if` was actually emitted, so
        # endCheck never closes a block that onCheck did not open.
        self.checked = True

    def endCheck(self, vulkanType):
        """Close the ``if`` opened by ``onCheck``, if one is open."""
        skipStreamInternal = vulkanType.typeName == "void"
        if skipStreamInternal:
            return

        if self.checked:
            self.cgen.endIf()
            self.checked = False

    def onCompoundType(self, vulkanType):
        """Compare a struct/union field via its generated comparison function.

        Pointer fields are guarded by a both-non-null check. When the field
        has a length, the lengths are compared first, and the comparator
        (``prefix + typeName``) is called for each element in a loop bounded
        by the lhs length. For pointer fields, that loop only runs when the
        lengths are equal.
        """
        accessLhs = self.exprAccessorLhs(vulkanType)
        accessRhs = self.exprAccessorRhs(vulkanType)

        lenAccessLhs = self.lenAccessorLhs(vulkanType)
        lenAccessRhs = self.lenAccessorRhs(vulkanType)

        lenAccessGuardLhs = self.lenAccessGuardLhs(vulkanType)

        needNullCheck = vulkanType.pointerIndirectionLevels > 0

        if needNullCheck:
            bothNotNullExpr = self.makeBothNotNullExpr(accessLhs, accessRhs)
            self.cgen.beginIf(bothNotNullExpr)

        if lenAccessLhs is not None:
            equalLenExpr = self.makeEqualExpr(lenAccessLhs, lenAccessRhs)

            self.compareWithConsequence(
                equalLenExpr,
                vulkanType, "Lengths not equal")

            loopVar = "i"
            accessLhs = "%s + %s" % (accessLhs, loopVar)
            accessRhs = "%s + %s" % (accessRhs, loopVar)
            forInit = "uint32_t %s = 0" % loopVar
            forCond = "%s < (uint32_t)%s" % (loopVar, lenAccessLhs)
            forIncr = "++%s" % loopVar

            # Don't iterate over the lhs length when the rhs array is shorter.
            if needNullCheck:
                self.cgen.beginIf(equalLenExpr)

            if lenAccessGuardLhs is not None:
                self.cgen.beginIf(lenAccessGuardLhs)

            self.cgen.beginFor(forInit, forCond, forIncr)

        self.cgen.funcCall(None, self.prefix + vulkanType.typeName,
                           [accessLhs, accessRhs, self.onFailCompareVar])

        if lenAccessLhs is not None:
            self.cgen.endFor()
            if lenAccessGuardLhs is not None:
                self.cgen.endIf()
            if needNullCheck:
                self.cgen.endIf()

        if needNullCheck:
            self.cgen.endIf()

    def onString(self, vulkanType):
        """Compare a C-string field.

        Reports a nullness mismatch, then compares the strings with
        ``strcmp`` when both are non-null.
        """
        accessLhs = self.exprAccessorLhs(vulkanType)
        accessRhs = self.exprAccessorRhs(vulkanType)

        bothNullExpr = self.makeBothNullExpr(accessLhs, accessRhs)
        bothNotNullExpr = self.makeBothNotNullExpr(accessLhs, accessRhs)
        nullMatchExpr = "(%s) || (%s)" % (bothNullExpr, bothNotNullExpr)

        self.compareWithConsequence(
            nullMatchExpr,
            vulkanType,
            "Mismatch in string pointer nullness")

        self.cgen.beginIf(bothNotNullExpr)

        self.compareWithConsequence(
            self.makeEqualStringExpr(accessLhs, accessRhs),
            vulkanType, "Unequal strings")

        self.cgen.endIf()

    def onStringArray(self, vulkanType):
        """Compare an array-of-C-strings field.

        Reports nullness and length mismatches once each. When the lengths
        match and both arrays are non-null, it then compares each string
        with ``strcmp``.
        """
        accessLhs = self.exprAccessorLhs(vulkanType)
        accessRhs = self.exprAccessorRhs(vulkanType)

        lenAccessLhs = self.lenAccessorLhs(vulkanType)
        lenAccessRhs = self.lenAccessorRhs(vulkanType)

        lenAccessGuardLhs = self.lenAccessGuardLhs(vulkanType)

        bothNullExpr = self.makeBothNullExpr(accessLhs, accessRhs)
        bothNotNullExpr = self.makeBothNotNullExpr(accessLhs, accessRhs)
        nullMatchExpr = "(%s) || (%s)" % (bothNullExpr, bothNotNullExpr)

        self.compareWithConsequence(
            nullMatchExpr,
            vulkanType,
            "Mismatch in string array pointer nullness")

        equalLenExpr = self.makeEqualExpr(lenAccessLhs, lenAccessRhs)

        # Emitted once; previously this check was duplicated, so every length
        # mismatch invoked the failure callback twice.
        self.compareWithConsequence(
            equalLenExpr,
            vulkanType, "Lengths not equal in string array")

        self.cgen.beginIf("%s && %s" % (equalLenExpr, bothNotNullExpr))

        loopVar = "i"
        accessLhs = "*(%s + %s)" % (accessLhs, loopVar)
        accessRhs = "*(%s + %s)" % (accessRhs, loopVar)
        forInit = "uint32_t %s = 0" % loopVar
        forCond = "%s < (uint32_t)%s" % (loopVar, lenAccessLhs)
        forIncr = "++%s" % loopVar

        if lenAccessGuardLhs is not None:
            self.cgen.beginIf(lenAccessGuardLhs)

        self.cgen.beginFor(forInit, forCond, forIncr)

        self.compareWithConsequence(
            self.makeEqualStringExpr(accessLhs, accessRhs),
            vulkanType, "Unequal string in string array")

        self.cgen.endFor()

        if lenAccessGuardLhs is not None:
            self.cgen.endIf()

        self.cgen.endIf()

    def onStaticArr(self, vulkanType):
        """Compare a fixed-size array field with a single ``memcmp``.

        The byte count is the lhs length expression multiplied by
        ``cgen.sizeofExpr(vulkanType)``.
        """
        accessLhs = self.exprAccessorLhs(vulkanType)
        accessRhs = self.exprAccessorRhs(vulkanType)

        lenAccessLhs = self.lenAccessorLhs(vulkanType)

        finalLenExpr = "%s * %s" % (lenAccessLhs,
                                    self.cgen.sizeofExpr(vulkanType))

        self.compareWithConsequence(
            self.makeEqualBufExpr(accessLhs, accessRhs, finalLenExpr),
            vulkanType, "Unequal static array")

    def onStructExtension(self, vulkanType):
        """Compare a struct-extension field via ``<prefix>extension_struct``.

        The extension comparator is only invoked when both sides are
        non-null. Before this fix, the guard checked the lhs alone and could
        pass a null rhs through.
        NOTE(review): nullness mismatches for this field are presumably
        reported by onCheck when the iterator flags the field as needing a
        check; confirm against iterateVulkanType.
        """
        lhs = self.exprAccessorLhs(vulkanType)
        rhs = self.exprAccessorRhs(vulkanType)

        self.cgen.beginIf(self.makeBothNotNullExpr(lhs, rhs))
        self.cgen.funcCall(None, self.prefix + "extension_struct",
                           [lhs, rhs, self.onFailCompareVar])
        self.cgen.endIf()

    def onPointer(self, vulkanType):
        """Compare the memory behind a non-compound pointer field.

        ``void`` pointers are skipped. With a length, the lengths are
        compared first. ``length * sizeof(element)`` bytes are then compared,
        but only when the lengths are equal, because a ``memcmp`` over the
        lhs length could read past a shorter rhs buffer. Without a length, a
        single element is compared.
        """
        accessLhs = self.exprAccessorLhs(vulkanType)
        accessRhs = self.exprAccessorRhs(vulkanType)

        skipStreamInternal = vulkanType.typeName == "void"
        if skipStreamInternal:
            return

        lenAccessLhs = self.lenAccessorLhs(vulkanType)
        lenAccessRhs = self.lenAccessorRhs(vulkanType)

        if lenAccessLhs is None:
            finalLenExpr = self.cgen.sizeofExpr(vulkanType.getForValueAccess())
            self.compareWithConsequence(
                self.makeEqualBufExpr(accessLhs, accessRhs, finalLenExpr),
                vulkanType, "Unequal dyn array")
            return

        equalLenExpr = self.makeEqualExpr(lenAccessLhs, lenAccessRhs)
        self.compareWithConsequence(
            equalLenExpr,
            vulkanType, "Lengths not equal")

        finalLenExpr = "%s * %s" % (lenAccessLhs,
                                    self.cgen.sizeofExpr(
                                        vulkanType.getForValueAccess()))

        self.cgen.beginIf(equalLenExpr)
        self.compareWithConsequence(
            self.makeEqualBufExpr(accessLhs, accessRhs, finalLenExpr),
            vulkanType, "Unequal dyn array")
        self.cgen.endIf()

    def onValue(self, vulkanType):
        """Compare a plain value field with ``==``."""
        accessLhs = self.exprAccessorValueLhs(vulkanType)
        accessRhs = self.exprAccessorValueRhs(vulkanType)
        self.compareWithConsequence(
            self.makeEqualExpr(accessLhs, accessRhs), vulkanType,
            "Value not equal")
297
298
class VulkanTesting(VulkanWrapperGenerator):
    """Wrapper generator that emits equality-comparison functions.

    For every non-aliased struct/union it declares, in the header, a function
    named ``API_PREFIX_EQUALITY + <type name>``. The implementation is built
    by running ``VulkanEqualityCodegen`` over each member. Aliased structs
    and unions get a function alias instead. A shared
    ``<API_PREFIX_EQUALITY>extension_struct`` function is declared at the
    start of the implementation and defined at the end, once all extension
    structs are known.
    """

    def __init__(self, module, typeInfo):
        VulkanWrapperGenerator.__init__(self, module, typeInfo)

        self.codegen = CodeGen()

        # The visitor gets its cgen assigned per struct inside onGenType, so
        # it is constructed without one here.
        self.equalityCodegen = VulkanEqualityCodegen(
            None,
            EQUALITY_VAR_NAMES,
            EQUALITY_ON_FAIL_VAR,
            API_PREFIX_EQUALITY)

        # Consulted by onGenType; nothing in this class adds entries to it.
        self.knownDefs = {}

        extensionParams = [
            STRUCT_EXTENSION_PARAM,
            STRUCT_EXTENSION_PARAM2,
            EQUALITY_ON_FAIL_VAR_TYPE,
        ]
        self.extensionTestingPrototype = VulkanAPI(
            API_PREFIX_EQUALITY + "extension_struct",
            EQUALITY_RET_TYPE,
            extensionParams)

    def onBegin(self):
        """Forward to the base class, then forward-declare the extension comparator."""
        VulkanWrapperGenerator.onBegin(self)
        extensionDecl = self.codegen.makeFuncDecl(self.extensionTestingPrototype)
        self.module.appendImpl(extensionDecl)

    def onGenType(self, typeXml, name, alias):
        """Emit the comparison function (or alias) for a struct/union type."""
        VulkanWrapperGenerator.onGenType(self, typeXml, name, alias)

        if name in self.knownDefs:
            return

        category = self.typeInfo.categoryOf(name)
        if category not in ["struct", "union"]:
            return

        equalityName = API_PREFIX_EQUALITY + name

        if alias:
            # An aliased type reuses the comparator of the type it aliases.
            aliasDecl = self.codegen.makeFuncAlias(
                equalityName, API_PREFIX_EQUALITY + alias)
            self.module.appendHeader(aliasDecl)
            return

        structInfo = self.typeInfo.structs[name]

        # One pointer-to-`name` parameter per compared variable, followed by
        # the failure-callback parameter.
        compareParams = [
            makeVulkanTypeSimple(True, name, 1, varName)
            for varName in EQUALITY_VAR_NAMES
        ]
        compareParams.append(EQUALITY_ON_FAIL_VAR_TYPE)

        comparePrototype = VulkanAPI(
            equalityName, EQUALITY_RET_TYPE, compareParams)

        def structCompareDef(cgen):
            visitor = self.equalityCodegen
            visitor.cgen = cgen
            for member in structInfo.members:
                iterateVulkanType(self.typeInfo, member, visitor)

        headerDecl = self.codegen.makeFuncDecl(comparePrototype)
        self.module.appendHeader(headerDecl)
        implDef = self.codegen.makeFuncImpl(comparePrototype, structCompareDef)
        self.module.appendImpl(implDef)

    def onGenCmd(self, cmdinfo, name, alias):
        """Commands produce no comparison code; defer to the base class."""
        VulkanWrapperGenerator.onGenCmd(self, cmdinfo, name, alias)

    def onEnd(self):
        """Forward to the base class, then define the extension comparator."""
        VulkanWrapperGenerator.onEnd(self)

        def forEachExtensionCompare(ext, castedAccess, cgen):
            calleeName = API_PREFIX_EQUALITY + ext.name
            rhsCast = cgen.makeReinterpretCast(
                STRUCT_EXTENSION_PARAM2.paramName, ext.name)
            cgen.funcCall(None, calleeName,
                          [castedAccess, rhsCast, EQUALITY_ON_FAIL_VAR])

        def extensionCompareBody(cgen):
            return self.emitForEachStructExtension(
                cgen,
                EQUALITY_RET_TYPE,
                STRUCT_EXTENSION_PARAM,
                forEachExtensionCompare)

        extensionImpl = self.codegen.makeFuncImpl(
            self.extensionTestingPrototype, extensionCompareBody)
        self.module.appendImpl(extensionImpl)
388