# Copyright 2018 Google LLC
# SPDX-License-Identifier: MIT

"""Generator for struct equality-testing functions.

For every Vulkan struct/union seen by the generator, a comparison function
is emitted that walks both instances member by member and invokes an
"on fail" callback with a message whenever a mismatch is found.
"""

from copy import copy

from .common.codegen import CodeGen
from .common.vulkantypes import \
    VulkanAPI, makeVulkanTypeSimple, iterateVulkanType, VulkanTypeIterator

from .wrapperdefs import VulkanWrapperGenerator
from .wrapperdefs import EQUALITY_VAR_NAMES
from .wrapperdefs import EQUALITY_ON_FAIL_VAR
from .wrapperdefs import EQUALITY_ON_FAIL_VAR_TYPE
from .wrapperdefs import EQUALITY_RET_TYPE
from .wrapperdefs import API_PREFIX_EQUALITY
from .wrapperdefs import STRUCT_EXTENSION_PARAM, STRUCT_EXTENSION_PARAM2


class VulkanEqualityCodegen(VulkanTypeIterator):
    """Type iterator that emits comparison code for one struct member.

    Each ``on*`` callback writes, through ``self.cgen``, the statements that
    compare the member of the first input variable against the same member of
    the second input variable. A failed comparison calls the on-fail callback
    with a message naming the member.
    """

    def __init__(self, cgen, inputVars, onFailCompareVar, prefix):
        """Set up accessors for both compared variables.

        Args:
            cgen: code generator the statements are written to. May be None
                at construction time as long as it is assigned before any
                callback runs (the accessors look it up lazily).
            inputVars: sequence whose first two entries are the names of the
                left-hand and right-hand variables being compared.
            onFailCompareVar: name of the callback invoked on a mismatch.
            prefix: prefix of the generated per-type comparison functions.
        """
        self.cgen = cgen
        self.inputVars = inputVars
        self.onFailCompareVar = onFailCompareVar
        self.prefix = prefix

        # The lambdas read self.cgen at call time, not at creation time, so
        # replacing self.cgen later is picked up by all accessors.
        def makeAccess(varName, asPtr=True):
            return lambda t: self.cgen.generalAccess(
                t, parentVarName=varName, asPtr=asPtr)

        def makeLengthAccess(varName):
            return lambda t: self.cgen.generalLengthAccess(
                t, parentVarName=varName)

        def makeLengthAccessGuard(varName):
            return lambda t: self.cgen.generalLengthAccessGuard(
                t, parentVarName=varName)

        lhsName = self.inputVars[0]
        rhsName = self.inputVars[1]

        self.exprAccessorLhs = makeAccess(lhsName)
        self.exprAccessorRhs = makeAccess(rhsName)

        self.exprAccessorValueLhs = makeAccess(lhsName, asPtr=False)
        self.exprAccessorValueRhs = makeAccess(rhsName, asPtr=False)

        self.lenAccessorLhs = makeLengthAccess(lhsName)
        self.lenAccessorRhs = makeLengthAccess(rhsName)

        self.lenAccessGuardLhs = makeLengthAccessGuard(lhsName)
        self.lenAccessGuardRhs = makeLengthAccessGuard(rhsName)

        # True while an `if` block opened by onCheck is waiting to be closed
        # by endCheck.
        self.checked = False

    def getTypeForCompare(self, vulkanType):
        """Return a copy of `vulkanType` adjusted for address access.

        getForAddressAccess() is applied once when the type is not
        accessible as a pointer, and once more when it has a static array
        expression.
        """
        res = copy(vulkanType)

        if not vulkanType.accessibleAsPointer():
            res = res.getForAddressAccess()

        if vulkanType.staticArrExpr:
            res = res.getForAddressAccess()

        return res

    def makeCastExpr(self, vulkanType):
        """Return a parenthesized cast to the declared type of `vulkanType`."""
        return "(%s)" % (
            self.cgen.makeCTypeDecl(vulkanType, useParamName=False))

    def makeEqualExpr(self, lhs, rhs):
        """Return an expression testing `lhs` and `rhs` with ``==``."""
        return "(%s) == (%s)" % (lhs, rhs)

    def makeEqualBufExpr(self, lhs, rhs, size):
        """Return an expression comparing `size` bytes with ``memcmp``."""
        return "(memcmp(%s, %s, %s) == 0)" % (lhs, rhs, size)

    def makeEqualStringExpr(self, lhs, rhs):
        """Return an expression comparing two strings with ``strcmp``."""
        return "(strcmp(%s, %s) == 0)" % (lhs, rhs)

    def makeBothNotNullExpr(self, lhs, rhs):
        """Return an expression that holds when both operands are non-null."""
        return "(%s) && (%s)" % (lhs, rhs)

    def makeBothNullExpr(self, lhs, rhs):
        """Return an expression that holds when both operands are null."""
        return "!(%s) && !(%s)" % (lhs, rhs)

    def _makeNullMatchExpr(self, lhs, rhs):
        """Return an expression that holds when both operands agree on
        nullness (both null, or both non-null)."""
        return "(%s) || (%s)" % (
            self.makeBothNullExpr(lhs, rhs),
            self.makeBothNotNullExpr(lhs, rhs))

    def compareWithConsequence(self, compareExpr, vulkanType, errMsg=""):
        """Emit a statement calling the on-fail callback if `compareExpr`
        is false.

        The message passed to the callback is the left-hand value access
        expression of `vulkanType` followed by `errMsg`.
        """
        self.cgen.stmt("if (!(%s)) { %s(\"%s (Error: %s)\"); }" %
                       (compareExpr, self.onFailCompareVar,
                        self.exprAccessorValueLhs(vulkanType), errMsg))

    def onCheck(self, vulkanType):
        """Emit the nullness check for an optional member and, unless the
        type is ``void``, open an `if` block entered only when both sides
        are non-null."""
        accessLhs = self.exprAccessorLhs(vulkanType)
        accessRhs = self.exprAccessorRhs(vulkanType)

        self.compareWithConsequence(
            self._makeNullMatchExpr(accessLhs, accessRhs),
            vulkanType,
            "Mismatch in optional field")

        # For void types nothing further is compared, so no block is opened.
        if vulkanType.typeName == "void":
            return

        self.cgen.beginIf("%s && %s" % (accessLhs, accessRhs))
        # Record the open block only once it really has been opened, so the
        # flag never goes stale on the void path.
        self.checked = True

    def endCheck(self, vulkanType):
        """Close the `if` block opened by onCheck, if there is one."""
        if vulkanType.typeName == "void":
            return

        if self.checked:
            self.cgen.endIf()
            self.checked = False

    def onCompoundType(self, vulkanType):
        """Emit a call to the per-type comparison function of a struct/union
        member; for members with a length, compare lengths and loop over
        the elements."""
        accessLhs = self.exprAccessorLhs(vulkanType)
        accessRhs = self.exprAccessorRhs(vulkanType)

        lenAccessLhs = self.lenAccessorLhs(vulkanType)
        lenAccessRhs = self.lenAccessorRhs(vulkanType)

        lenAccessGuardLhs = self.lenAccessGuardLhs(vulkanType)

        needNullCheck = vulkanType.pointerIndirectionLevels > 0
        hasLength = lenAccessLhs is not None

        if needNullCheck:
            self.cgen.beginIf(self.makeBothNotNullExpr(accessLhs, accessRhs))

        if hasLength:
            equalLenExpr = self.makeEqualExpr(lenAccessLhs, lenAccessRhs)

            self.compareWithConsequence(
                equalLenExpr,
                vulkanType, "Lengths not equal")

            loopVar = "i"
            accessLhs = "%s + %s" % (accessLhs, loopVar)
            accessRhs = "%s + %s" % (accessRhs, loopVar)
            forInit = "uint32_t %s = 0" % loopVar
            forCond = "%s < (uint32_t)%s" % (loopVar, lenAccessLhs)
            forIncr = "++%s" % loopVar

            # For pointer members, only walk the elements when both sides
            # report the same length.
            if needNullCheck:
                self.cgen.beginIf(equalLenExpr)

            if lenAccessGuardLhs is not None:
                self.cgen.beginIf(lenAccessGuardLhs)

            self.cgen.beginFor(forInit, forCond, forIncr)

        self.cgen.funcCall(None, self.prefix + vulkanType.typeName,
                           [accessLhs, accessRhs, self.onFailCompareVar])

        # Close the blocks in the reverse order they were opened.
        if hasLength:
            self.cgen.endFor()
            if lenAccessGuardLhs is not None:
                self.cgen.endIf()
            if needNullCheck:
                self.cgen.endIf()

        if needNullCheck:
            self.cgen.endIf()

    def onString(self, vulkanType):
        """Emit a nullness check and, when both strings are non-null, a
        ``strcmp`` comparison."""
        accessLhs = self.exprAccessorLhs(vulkanType)
        accessRhs = self.exprAccessorRhs(vulkanType)

        self.compareWithConsequence(
            self._makeNullMatchExpr(accessLhs, accessRhs),
            vulkanType,
            "Mismatch in string pointer nullness")

        self.cgen.beginIf(self.makeBothNotNullExpr(accessLhs, accessRhs))

        self.compareWithConsequence(
            self.makeEqualStringExpr(accessLhs, accessRhs),
            vulkanType, "Unequal strings")

        self.cgen.endIf()

    def onStringArray(self, vulkanType):
        """Emit nullness and length checks for an array of strings and,
        when both pass, an element-wise ``strcmp`` loop."""
        accessLhs = self.exprAccessorLhs(vulkanType)
        accessRhs = self.exprAccessorRhs(vulkanType)

        lenAccessLhs = self.lenAccessorLhs(vulkanType)
        lenAccessRhs = self.lenAccessorRhs(vulkanType)

        lenAccessGuardLhs = self.lenAccessGuardLhs(vulkanType)

        bothNotNullExpr = self.makeBothNotNullExpr(accessLhs, accessRhs)

        self.compareWithConsequence(
            self._makeNullMatchExpr(accessLhs, accessRhs),
            vulkanType,
            "Mismatch in string array pointer nullness")

        equalLenExpr = self.makeEqualExpr(lenAccessLhs, lenAccessRhs)

        # Emitted exactly once; a second identical check would only make the
        # generated code report the same mismatch twice.
        self.compareWithConsequence(
            equalLenExpr,
            vulkanType, "Lengths not equal in string array")

        self.cgen.beginIf("%s && %s" % (equalLenExpr, bothNotNullExpr))

        loopVar = "i"
        elemLhs = "*(%s + %s)" % (accessLhs, loopVar)
        elemRhs = "*(%s + %s)" % (accessRhs, loopVar)
        forInit = "uint32_t %s = 0" % loopVar
        forCond = "%s < (uint32_t)%s" % (loopVar, lenAccessLhs)
        forIncr = "++%s" % loopVar

        if lenAccessGuardLhs is not None:
            self.cgen.beginIf(lenAccessGuardLhs)

        self.cgen.beginFor(forInit, forCond, forIncr)

        self.compareWithConsequence(
            self.makeEqualStringExpr(elemLhs, elemRhs),
            vulkanType, "Unequal string in string array")

        self.cgen.endFor()

        if lenAccessGuardLhs is not None:
            self.cgen.endIf()

        self.cgen.endIf()

    def onStaticArr(self, vulkanType):
        """Emit a ``memcmp`` over a statically sized array member."""
        accessLhs = self.exprAccessorLhs(vulkanType)
        accessRhs = self.exprAccessorRhs(vulkanType)

        lenAccessLhs = self.lenAccessorLhs(vulkanType)

        finalLenExpr = "%s * %s" % (lenAccessLhs,
                                    self.cgen.sizeofExpr(vulkanType))

        self.compareWithConsequence(
            self.makeEqualBufExpr(accessLhs, accessRhs, finalLenExpr),
            vulkanType, "Unequal static array")

    def onStructExtension(self, vulkanType):
        """Emit a call to the extension-struct comparison function, guarded
        by the left-hand extension pointer being non-null."""
        lhs = self.exprAccessorLhs(vulkanType)
        rhs = self.exprAccessorRhs(vulkanType)

        self.cgen.beginIf(lhs)
        self.cgen.funcCall(None, self.prefix + "extension_struct",
                           [lhs, rhs, self.onFailCompareVar])
        self.cgen.endIf()

    def onPointer(self, vulkanType):
        """Emit a ``memcmp`` over the data behind a pointer member.

        Nothing is emitted for ``void`` pointers. When the member has a
        length, a length equality check is emitted first and the compared
        size is length times element size; otherwise a single element is
        compared.
        """
        accessLhs = self.exprAccessorLhs(vulkanType)
        accessRhs = self.exprAccessorRhs(vulkanType)

        if vulkanType.typeName == "void":
            return

        lenAccessLhs = self.lenAccessorLhs(vulkanType)
        lenAccessRhs = self.lenAccessorRhs(vulkanType)

        if lenAccessLhs is not None:
            self.compareWithConsequence(
                self.makeEqualExpr(lenAccessLhs, lenAccessRhs),
                vulkanType, "Lengths not equal")

            finalLenExpr = "%s * %s" % (lenAccessLhs,
                                        self.cgen.sizeofExpr(
                                            vulkanType.getForValueAccess()))
        else:
            finalLenExpr = self.cgen.sizeofExpr(vulkanType.getForValueAccess())

        self.compareWithConsequence(
            self.makeEqualBufExpr(accessLhs, accessRhs, finalLenExpr),
            vulkanType, "Unequal dyn array")

    def onValue(self, vulkanType):
        """Emit an ``==`` comparison of a plain value member."""
        accessLhs = self.exprAccessorValueLhs(vulkanType)
        accessRhs = self.exprAccessorValueRhs(vulkanType)
        self.compareWithConsequence(
            self.makeEqualExpr(accessLhs, accessRhs), vulkanType,
            "Value not equal")


class VulkanTesting(VulkanWrapperGenerator):
    """Wrapper generator that emits an equality-testing function for every
    struct/union type, plus a dispatcher for extension structs."""

    def __init__(self, module, typeInfo):
        VulkanWrapperGenerator.__init__(self, module, typeInfo)

        self.codegen = CodeGen()

        # The code generator is supplied later, per function, by the
        # body-emitting callback in onGenType.
        self.equalityCodegen = \
            VulkanEqualityCodegen(
                None,
                EQUALITY_VAR_NAMES,
                EQUALITY_ON_FAIL_VAR,
                API_PREFIX_EQUALITY)

        # Type names to skip in onGenType.
        self.knownDefs = {}

        self.extensionTestingPrototype = \
            VulkanAPI(API_PREFIX_EQUALITY + "extension_struct",
                      EQUALITY_RET_TYPE,
                      [STRUCT_EXTENSION_PARAM,
                       STRUCT_EXTENSION_PARAM2,
                       EQUALITY_ON_FAIL_VAR_TYPE])

    def onBegin(self):
        """Emit the declaration of the extension-struct comparison function
        into the implementation, ahead of the per-type functions."""
        VulkanWrapperGenerator.onBegin(self)
        self.module.appendImpl(self.codegen.makeFuncDecl(
            self.extensionTestingPrototype))

    def onGenType(self, typeXml, name, alias):
        """Emit the comparison function (or an alias to one) for a
        struct/union type; other categories are ignored."""
        VulkanWrapperGenerator.onGenType(self, typeXml, name, alias)

        if name in self.knownDefs:
            return

        category = self.typeInfo.categoryOf(name)

        if category not in ["struct", "union"]:
            return

        if alias:
            # Aliased types reuse the comparison function of their target.
            self.module.appendHeader(
                self.codegen.makeFuncAlias(API_PREFIX_EQUALITY + name,
                                           API_PREFIX_EQUALITY + alias))
            return

        structInfo = self.typeInfo.structs[name]

        def typeFromName(varname):
            return makeVulkanTypeSimple(True, name, 1, varname)

        compareParams = \
            [typeFromName(varname) for varname in EQUALITY_VAR_NAMES] + \
            [EQUALITY_ON_FAIL_VAR_TYPE]

        comparePrototype = \
            VulkanAPI(API_PREFIX_EQUALITY + name,
                      EQUALITY_RET_TYPE,
                      compareParams)

        def structCompareDef(cgen):
            # Point the shared equality codegen at this function's generator
            # before walking the members.
            self.equalityCodegen.cgen = cgen
            for member in structInfo.members:
                iterateVulkanType(self.typeInfo, member,
                                  self.equalityCodegen)

        self.module.appendHeader(
            self.codegen.makeFuncDecl(comparePrototype))
        self.module.appendImpl(
            self.codegen.makeFuncImpl(comparePrototype, structCompareDef))

    def onGenCmd(self, cmdinfo, name, alias):
        """No testing code is generated for commands; defer to the base."""
        VulkanWrapperGenerator.onGenCmd(self, cmdinfo, name, alias)

    def onEnd(self):
        """Emit the body of the extension-struct comparison function, which
        forwards to the per-type comparison function of each extension."""
        VulkanWrapperGenerator.onEnd(self)

        def forEachExtensionCompare(ext, castedAccess, cgen):
            cgen.funcCall(None, API_PREFIX_EQUALITY + ext.name,
                          [castedAccess,
                           cgen.makeReinterpretCast(
                               STRUCT_EXTENSION_PARAM2.paramName, ext.name),
                           EQUALITY_ON_FAIL_VAR])

        self.module.appendImpl(
            self.codegen.makeFuncImpl(
                self.extensionTestingPrototype,
                lambda cgen: self.emitForEachStructExtension(
                    cgen,
                    EQUALITY_RET_TYPE,
                    STRUCT_EXTENSION_PARAM,
                    forEachExtensionCompare)))