1# mypy: allow-untyped-defs 2import functools 3import logging 4import math 5import operator 6import sympy 7import builtins 8 9from dataclasses import dataclass 10from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union 11 12import torch 13import torch.fx 14import torch.fx.traceback as fx_traceback 15 16from torch._dynamo.exc import TorchDynamoException 17from torch.fx.node import Argument, Target 18from torch.utils._sympy.interp import sympy_interp 19from torch._dynamo.utils import dynamo_timed 20 21log = logging.getLogger(__name__) 22 23try: 24 import z3 # type: ignore[import] 25 26 # Translation Validation for Dynamo guards 27 # ======================================== 28 # 29 # Checks whether optimizations applied to the collected guards are 30 # valid. In other words, whether the guard function we actually run 31 # does not have false positives (unsound). 32 # 33 # In order to do so, we build the guards using 2 different information 34 # attached to each 'SymNode': 35 # 1. SymPy expressions 36 # 2. FX nodes 37 # 38 # SymPy expressions have implicit optimizations baked within itself, 39 # which may have a few bugs. On the other hand, we build the FX graph 40 # manually, with no optimizations enabled. This gives us access to 41 # the "ground truth". 42 # 43 # We then convert into Z3 expressions both the SymPy expressions 44 # (see [Note: SympyToZ3]) that reach 'ShapeEnv.produce_guards' function 45 # and the FX nodes (see [Note: PopulateValidator]) that go through 46 # 'ShapeEnv.evaluate_expr' function. Finally, we run the validation. 47 # (see [Note: TranslationValidator]) 48 49 # Better Z3 to string implementation (for a small fraction of Z3). 50 # 51 # Here are the things we clean before showing the Z3 expression: 52 # - Rename a few ops (e.g. 
"Distinct" ==> "!=") 53 # 54 # - Ignore ToInt and ToReal operations: 55 # usually they don't really matter 56 # 57 # - Transform (ToInt (/ ...)) into (idiv ...): 58 # this is the pattern for floor division 59 # 60 # - Collect a chain of the same operations into one 61 def z3str(e: z3.ExprRef) -> str: 62 assert z3.is_expr(e), f"unsupported expression type: {e}" 63 64 def get_args_str(e: z3.ExprRef) -> List[str]: 65 return [z3str(e.arg(i)) for i in range(e.num_args())] 66 67 # First, we simplify the given expression. 68 # This is done using rewriting rules, so shouldn't take long. 69 e = z3.simplify(e) 70 71 72 # Only support function applications. 73 # Even Z3 "variables" are, in fact, function applications. 74 if not z3.is_app(e): 75 raise ValueError(f"can't print Z3 expression: {e}") 76 77 if z3.is_int_value(e) or z3.is_rational_value(e): 78 return e.as_string() # type: ignore[attr-defined] 79 80 decl = e.decl() 81 kind = decl.kind() 82 op = str(decl) 83 args = get_args_str(e) 84 85 if kind == z3.Z3_OP_POWER: 86 op = "pow" 87 88 elif kind in (z3.Z3_OP_ADD, z3.Z3_OP_MUL): 89 # Collect the arguments of chains of ADD and MUL. 90 # This is safe, since they are associative. 
91 92 def collect_str_args(e): 93 if not (z3.is_app(e) and e.decl().kind() == kind): 94 return [z3str(e)] 95 else: 96 return [ 97 x 98 for i in range(e.num_args()) 99 for x in collect_str_args(e.arg(i)) 100 ] 101 102 args = collect_str_args(e) 103 104 elif kind == z3.Z3_OP_NOT: 105 # Revert some conversions that z3.simplify applies: 106 # - a != b ==> (Not (== a b)) ==> (!= a b) 107 # - a < b ==> (Not (<= b a)) ==> (> b a) 108 # - a > b ==> (Not (<= a b)) ==> (> a b) 109 110 assert e.num_args() == 1 111 arg = e.arg(0) 112 113 assert z3.is_app(arg) 114 argkind = arg.decl().kind() 115 116 logic_inverse = { 117 z3.Z3_OP_EQ: "!=", 118 z3.Z3_OP_LE: ">", 119 z3.Z3_OP_GE: "<", 120 } 121 122 if argkind in logic_inverse: 123 op = logic_inverse[argkind] 124 args = get_args_str(arg) 125 126 elif kind in (z3.Z3_OP_TO_INT, z3.Z3_OP_TO_REAL): 127 assert e.num_args() == 1 128 argstr = z3str(e.arg(0)) 129 130 # Check if it's the floor division pattern. 131 if argstr.startswith("(/"): 132 return "(idiv" + argstr[2:] 133 134 # Otherwise, just ignore it. 135 return argstr 136 137 elif kind == z3.Z3_OP_UNINTERPRETED: 138 assert e.num_args() == 0 139 return str(decl) 140 141 string = op + " " + " ".join(args) 142 return f"({string.rstrip()})" 143 144 # Implementation of Python semantics as Z3 expressions. 145 # 146 # Z3 Real-Int theory has operators with semantics that differ that of 147 # Python. Therefore, in order to get it right, we need to implement 148 # the (Python) semantics we are relying on in Z3. 149 @dataclass 150 class _Z3Ops: 151 # Validator used for adding assertions as needed. 152 # e.g. div(a, b) requires b != 0. 153 validator: "TranslationValidator" 154 155 # The 2 functions below are used for conditionally casting between 156 # integer and reals. 157 # 158 # Returns a real expression from 'x'. 159 @staticmethod 160 def to_real(x: z3.ArithRef) -> z3.ArithRef: 161 return x if x.is_real() else z3.ToReal(x) 162 163 # Returns an integer expression from 'x'. 
164 @staticmethod 165 def to_int(x: z3.ArithRef) -> z3.ArithRef: 166 return x if x.is_int() else z3.ToInt(x) 167 168 # Implements Python division semantics. 169 def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: 170 self.validator.add_assertion(denominator != 0) # type: ignore[arg-type] 171 return _Z3Ops.to_real(numerator) / _Z3Ops.to_real(denominator) 172 173 def floor(self, number: z3.ArithRef) -> z3.ArithRef: 174 # Z3 ToInt function rounds a real number towards negative infinity. 175 return _Z3Ops.to_int(number) 176 177 # Python semantics for 'FloorDiv' states that before applying the floor 178 # function, the operands are converted to their common type. 179 def floordiv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: 180 cast_result_to_real = numerator.is_real() or denominator.is_real() 181 result = _Z3Ops.to_int(self.div(numerator, denominator)) 182 # Since the 'result' is already an integer, we just have to check 183 # whether we should cast it to real. 184 return _Z3Ops.to_real(result) if cast_result_to_real else result 185 186 def ceil(self, number: z3.ArithRef) -> z3.ArithRef: 187 return z3.If( 188 self.floor(number) < number, 189 self.floor(number + 1), 190 number 191 ) # type: ignore[return-value] 192 193 def max(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef: 194 return z3.If(a > b, a, b) # type: ignore[return-value] 195 196 def min(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef: 197 return z3.If(a < b, a, b) # type: ignore[return-value] 198 199 # Python semantics for 'Mod' is defined as: p % q = p - floordiv(p, q) * q 200 # It should work with both integer and reals. 201 def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef: 202 return p - self.floordiv(p, q) * q 203 204 def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: 205 # Z3 can't handle complex numbers very well. 
206 self.validator.add_assertion(z3.Or(base != 0, exp > 0)) # type: ignore[arg-type] 207 return base ** exp 208 209 def sqrt(self, number: z3.ArithRef) -> z3.ArithRef: 210 # Square-root: 211 # 1. Only work with reals 212 number = _Z3Ops.to_real(number) 213 # 2. The number should be positive or zero. 214 # Otherwise, Z3 returns 'unknown'. 215 self.validator.add_assertion(number >= 0) 216 return number ** 0.5 217 218 def abs(self, number: z3.ArithRef) -> z3.ArithRef: 219 return z3.Abs(number) 220 221 def round_to_int(self, number: z3.ArithRef) -> z3.ArithRef: 222 # Pythons builtin 'round' implements the 'round half to even' strategy 223 # See https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even 224 # z3 has an equivalent z3.fpRoundToIntegral(z3.RoundNearestTiesToEven(), ...), but this only applies to 225 # floating point numbers, which is different from real numbers that we are dealing with here. 226 # Instead, we implement 'round half to even' in terms of 'round half up' (floor(x + 0.5)) and 227 # 'round half down' (ceil(x - 0.5)). 228 # Assuming 'round half up' is the default case, we need to correct ..., -3.5, -1.5, 0.5, 2.5, 4.5, ... 229 # to round down, i.e. use the 'round half down' strategy 230 return z3.If( 231 self.mod(number, z3.IntVal(2)) == 0.5, 232 self.ceil(number - 0.5), 233 self.floor(number + 0.5), 234 ) 235 236 # Lifts a callable to be used in Z3. 237 # 238 # This function replaces the given 'op' by a function that: 239 # 240 # 1. Lifts the arguments into Z3 (i.e. make them inhabitants of Z3) 241 # 242 # 2. Calls an operation that corresponds to 'op', but works with Z3 243 # inhabitants (left as is if it works as is) 244 def z3op(op: Callable, validator: "TranslationValidator") -> Callable: 245 # Operations that have booleans as their argument. 246 # This is needed because the argument of some FX nodes were 247 # literal integers, instead of booleans. So, whenever this flag 248 # is set, we also convert ints to booleans. 
249 boolean_ops = {operator.not_, operator.and_, operator.or_} 250 as_bool = op in boolean_ops 251 252 # Lifts the function into 'z3.ExprRef' domain. 253 def lift(func): 254 def wrap(a) -> z3.ExprRef: 255 if isinstance(a, (z3.ArithRef, z3.BoolRef)): 256 return a 257 # Convert it into a Z3 value, if it is some of the supported 258 # types below. 259 if isinstance(a, bool) or (as_bool and isinstance(a, int)): 260 return z3.BoolVal(bool(a)) 261 if isinstance(a, (int, sympy.Integer)): 262 return z3.IntVal(int(a)) 263 if isinstance(a, (float, sympy.Float)): 264 return z3.RealVal(float(a)) 265 raise ValueError(f"can't lift type: {type(a)}") 266 267 @functools.wraps(func) 268 def wrapper(*args): 269 # Lifts the arguments into a list of Z3 inhabitants. 270 wrapped_args = (wrap(a) for a in args) 271 # Run the function on the Z3 expressions. 272 return func(*wrapped_args) 273 274 return wrapper 275 276 ops = _Z3Ops(validator) 277 replacement_map = { 278 # Operator module. 279 operator.not_: lift(z3.Not), 280 operator.and_: lift(z3.And), 281 operator.or_: lift(z3.Or), 282 operator.floordiv: lift(ops.floordiv), 283 operator.truediv: lift(ops.div), 284 operator.mod: lift(ops.mod), 285 operator.abs: lift(ops.abs), 286 builtins.round: lift(ops.round_to_int), 287 288 # Math module. 289 math.ceil: lift(ops.ceil), 290 math.floor: lift(ops.floor), 291 292 # Torch module. 293 torch.sym_float: lift(ops.to_real), 294 torch.sym_max: lift(ops.max), 295 torch.sym_min: lift(ops.min), 296 torch.sym_ite: lift(lambda b, t, f: t if b else f), 297 torch._sym_sqrt: lift(ops.sqrt), # type: ignore[attr-defined] 298 # Not lifted because we only use this function as a 299 # marker for adding the expression as validator input. 300 torch._assert: torch._assert, 301 } 302 return replacement_map[op] if op in replacement_map else lift(op) 303 304 # Processes an FX graph, populating the given validator. 
305 # 306 # [Note: PopulateValidator] 307 # This class walks through each node in the FX graph, translating 308 # them into the Z3 world. 309 # 310 # Then, whenever it finds an 'torch._assert' call_function operation, 311 # it adds the Z3 expression corresponding to the argument as validator 312 # input. 313 class PopulateValidator(torch.fx.Interpreter): 314 def __init__(self, graph: torch.fx.Graph, validator: "TranslationValidator"): 315 # Reference to the translation validator. 316 self.validator = validator 317 318 # Build the graph module and call `Interpreter` constructor. 319 module = torch.fx.GraphModule(root={}, graph=graph) 320 super().__init__(module, garbage_collect_values=True) 321 322 def placeholder(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: 323 symbol = fx_traceback.get_current_meta()["symbol"] 324 return self.validator.z3var(symbol) 325 326 def call_function(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: 327 if target != torch._assert: 328 # Lift and runs the node target function 329 return super().call_function(z3op(target, self.validator), args, kwargs) # type: ignore[arg-type] 330 # Adds the Z3 expression corresponding to the first argument 331 # as a validator input. 332 assert len(args) == 1, f"expected 1 argument on assertion. Got: {len(args)} " 333 self.validator.add_source_expr(args[0]) # type: ignore[arg-type] 334 335 # Translates SymPy expressions into Z3 expressions. 336 # 337 # [Note: SympyToZ3] 338 # At the time of the translation, all free variables present in the 339 # SymPy expression being translated must be already mapped to a Z3 340 # integer variable. 
341 class SympyToZ3: 342 OPERATOR_HANDLES = {"add", "mul", "eq", "ne", "lt", "gt", "le", "ge"} 343 344 def __init__( 345 self, 346 validator: "TranslationValidator", 347 ) -> None: 348 self._validator = validator 349 self._ops = _Z3Ops(self._validator) 350 351 def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef: 352 # TODO: Probably OK to relax this and allow lower precision 353 if dtype is torch.int64: 354 return z3.IntVal(int(value)) 355 if dtype is torch.double: 356 return z3.RealVal(float(value)) 357 if dtype is torch.bool: 358 return z3.BoolVal(bool(value)) 359 raise ValueError(f"unsupported dtype (SympyToZ3): {dtype}") 360 361 def to_dtype(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: 362 if dtype == torch.float64: 363 return z3.ToReal(x) 364 raise NotImplementedError(f"to_dtype {dtype} NYI") 365 366 def trunc_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: 367 return z3.ToInt(x) 368 369 def round_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: 370 return self._ops.round_to_int(x) 371 372 def int_truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: 373 return self._ops.div(numerator, denominator) 374 375 def truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: 376 return self._ops.div(numerator, denominator) 377 378 def floordiv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: 379 return self._ops.floordiv(numerator, denominator) 380 381 def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: 382 return self._ops.floordiv(numerator, denominator) 383 384 def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: 385 return self._ops.pow(base, exp) 386 387 def pow_by_natural(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: 388 return self._ops.pow(base, exp) 389 390 def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef: 391 return self._ops.mod(p, q) 392 393 def ceil_to_int(self, 
x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: 394 return self._ops.ceil(x) 395 396 def floor_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: 397 return self._ops.floor(x) 398 399 def __getattr__(self, name: str) -> Any: 400 REPLACEMENT = { 401 "and_": z3.And, 402 "or_": z3.Or, 403 "not_": z3.Not, 404 "floor": self._ops.floor, 405 "ceil": self._ops.ceil, 406 "minimum": self._ops.min, 407 "maximum": self._ops.max, 408 } 409 410 if name in REPLACEMENT: 411 return REPLACEMENT[name] 412 if name in self.OPERATOR_HANDLES: 413 return getattr(operator, name) 414 raise AttributeError(f"unhandled operator: {name}") 415 416 def run(self, expr: sympy.Basic) -> z3.ExprRef: 417 return sympy_interp(self, self._validator.symbols, expr) # type: ignore[arg-type] 418 419 # Dynamo guards translation validator. 420 # 421 # [Note: TranslationValidator] 422 # Verifies whether the guards issued by 'ShapeEnv.produce_guards' are sound. 423 # That is: whether those (target) guards only yield TRUE whenever the original, 424 # unoptimized, (source) guards yield TRUE. 425 # 426 # More concretely, given 'source' and 'target' guard expressions, we wish to 427 # check whether the following expression holds: 428 # 429 # Not(And(source)) AND And(target) 430 # 431 # i.e. whether there is an assignment of the free variables where the opposite 432 # happens: target is TRUE, but source is FALSE. 433 class TranslationValidator: 434 def __init__(self) -> None: 435 log.debug("new instance") 436 437 # Mapping of SymPy symbols to Z3 variables. 438 self.symbols: Dict[sympy.Symbol, z3.ExprRef] = {} 439 440 # Set of source Z3 expressions. 441 # They represent the generated guards without any kind of 442 # simplification or transformation. 443 self._source_exprs: Set[z3.BoolRef] = set() 444 445 # Set of target Z3 expressions. 446 # They represent the actual checked guards at runtime. They might 447 # be simplified or transformed versions of the source guards. 
448 self._target_exprs: Set[z3.BoolRef] = set() 449 450 # Set of Z3 expressions representing assertions over both the 451 # source and target expressions. 452 self._assertions: Set[z3.BoolRef] = set() 453 454 # Retrieves the corresponding Z3 variable. 455 def z3var(self, symbol: sympy.Symbol) -> z3.ExprRef: 456 assert symbol in self.symbols, f"Z3 variable not found for: {symbol}" 457 return self.symbols[symbol] 458 459 # Create a variable in Z3 of 'type' for 'symbol', if it doesn't already exists. 460 def add_var(self, symbol: sympy.Symbol, type: Type) -> z3.ExprRef: 461 if symbol in self.symbols: 462 return self.symbols[symbol] 463 464 log.debug("new variable: %s (%s)", symbol.name, type.__name__) 465 466 if type is int: 467 var = z3.Int(symbol.name) 468 469 # If 'symbol' is positive (SymPy assumption), we have to 470 # convey it to Z3 as well. 471 if symbol.is_positive: # type: ignore[attr-defined] 472 self._target_exprs.add(var > 0) 473 elif type is float: 474 var = z3.Real(symbol.name) 475 elif type is bool: 476 var = z3.Bool(symbol.name) 477 else: 478 raise RuntimeError(f"unsupported type for Z3 variable: {type}") 479 480 self.symbols[symbol] = var 481 return var 482 483 # Checks whether all symbols were already added. 484 def _check_freesymbols(self, e: sympy.Basic) -> None: 485 for s in e.free_symbols: 486 assert isinstance(s, sympy.Symbol) 487 # Call 'z3var' just to check whether there's already a 488 # Z3 variable corresponding to 's'. 489 self.z3var(s) 490 491 492 def to_z3_boolean_expr(self, e: sympy.Basic) -> z3.BoolRef: 493 z3expr = SympyToZ3(self).run(e) 494 assert isinstance(z3expr, z3.BoolRef), f"expected boolean expression. 
Got: {z3expr}" 495 return z3expr 496 497 def add_source_expr(self, e: z3.BoolRef) -> None: 498 if e not in self._source_exprs: 499 log.debug("add source guard: %s", z3str(e)) 500 self._source_exprs.add(e) 501 502 def add_target_expr(self, e: sympy.Expr) -> None: 503 self._check_freesymbols(e) 504 z3expr = self.to_z3_boolean_expr(e) 505 if e not in self._target_exprs: 506 log.debug("add target guard: %s", z3str(z3expr)) 507 self._target_exprs.add(z3expr) 508 509 def add_assertion(self, e: Union[z3.BoolRef, sympy.Basic]) -> None: 510 if isinstance(e, sympy.Basic): 511 self._check_freesymbols(e) 512 ref = self.to_z3_boolean_expr(e) 513 else: 514 ref = e 515 assert isinstance(ref, z3.BoolRef) 516 if ref not in self._assertions: 517 log.debug("add assertion: %s", z3str(ref)) 518 self._assertions.add(ref) 519 520 def validate(self) -> None: 521 with dynamo_timed("TranslationValidator.validate"): 522 return self._validate() 523 524 def _validate(self) -> None: 525 if len(self._source_exprs) == 0 or len(self._target_exprs) == 0: 526 # If there are no source/target expressions, there's nothing we really 527 # wish to prove. So, we just return. 528 return None 529 530 # Here, we use "QF_NRA" logic for the solver: 531 # "Quantifier-free Non-linear Real Arithmetic". 532 # 533 # Most of the guards expressions have: 534 # 1. arithmetic between integer and reals 535 # 2. no quantifiers 536 # 3. potentially non-linear. 537 # 538 # Although there's also "QF_NIRA" (mixed integer-real arithmetic), 539 # "QF_NRA" seems to work better on 'dynamo/test_dynamic_shapes.py'. 540 solver = z3.SolverFor("QF_NRA") 541 # Set a timeout for finding a solution. 542 solver.set(timeout=translation_validation_timeout()) 543 544 # Add all the assertions to the solver. 545 for assertion in self._assertions: 546 solver.add(assertion) 547 548 # "Is there any case where it's TRUE for the target expressions, 549 # but FALSE for the source expressions?" 
550 solver.add(z3.Not(z3.And(*self._source_exprs))) 551 solver.add(*self._target_exprs) 552 553 log.debug("translation validation: start") 554 r = solver.check() 555 if r == z3.sat: 556 # Target expressions are unsound. 557 # Log the found model and the source expressions that failed. 558 model = solver.model() 559 raise ValidationException( 560 model, self._assertions, self._target_exprs, 561 failed_source_exprs=[ 562 inp for inp in self._source_exprs if not model.evaluate(inp) 563 ] 564 ) 565 else: 566 if r == z3.unknown: 567 # Could not find a solution. It didn't fail, but it also 568 # didn't succeed. Canceling the validation execution (keyboard 569 # interrupt) also gets to this branch. 570 log.warning("translation validation: could not validate: got z3.unknown") 571 else: 572 # Target expressions are sound. 573 assert r == z3.unsat 574 log.debug("translation validation: success") 575 576except ImportError: 577 _HAS_Z3 = False 578 579 __all__ = [ 580 "translation_validation_enabled", "translation_validation_timeout", 581 "ValidationException", "BisectValidationException", 582 ] 583 584else: 585 _HAS_Z3 = True 586 587 __all__ = [ 588 "z3str", "z3op", "PopulateValidator", "SympyToZ3", "TranslationValidator", 589 "translation_validation_enabled", "translation_validation_timeout", 590 "ValidationException", "BisectValidationException", 591 ] 592 593from torch.fx.experimental import _config as config 594 595def translation_validation_enabled() -> bool: 596 # Checks everytime this function is called, in case the Dynamo 597 # option is set, but Z3 is not installed. 598 _assert_z3_installed_if_tv_set() 599 return _HAS_Z3 and config.translation_validation 600 601 602def translation_validation_timeout() -> int: 603 return config.translation_validation_timeout 604 605 606def _assert_z3_installed_if_tv_set(): 607 assert _HAS_Z3 or not config.translation_validation, ( 608 "translation validation requires Z3 package. 
Please, either install " 609 "z3-solver or disable translation validation." 610 ) 611 612 613class ValidationException(TorchDynamoException): 614 def __init__(self, model, assertions, target_exprs, failed_source_exprs): 615 assert _HAS_Z3 616 617 def symbolstr(sym) -> str: 618 return f"{sym}: {model[sym]}" 619 620 def joinlines(xs) -> str: 621 return "\n".join(f" ==> {x}" for x in xs) 622 623 model_str = joinlines(sorted(map(symbolstr, model))) 624 assertions_str = joinlines(sorted(map(z3str, assertions))) 625 target_exprs_str = joinlines(sorted(map(z3str, target_exprs))) 626 failed_source_exprs_str = joinlines(sorted(map(z3str, failed_source_exprs))) 627 628 self.msg = "translation validation failed." 629 self.details = f"""\ 630Model: 631{model_str} 632 633Assertions: 634{assertions_str} 635 636Target Expressions: 637{target_exprs_str} 638 639Failed Source Expressions: 640{failed_source_exprs_str}""" 641 642 def __str__(self): 643 return f"{self.msg}\n\n{self.details}" 644 645 646class BisectValidationException(TorchDynamoException): 647 def __init__(self, validation_exc, expr, failed_action, traced_node): 648 self.msg = f"translation validation failed when {failed_action}: {expr}" 649 self.details = f"""\ 650Failure occurred while running node: 651 {traced_node.format_node()} 652 653{validation_exc.details}""" 654 655 def __str__(self): 656 return f"{self.msg}\n\n{self.details}" 657 658# Checks when this module is loaded. 659_assert_z3_installed_if_tv_set() 660 661# Translation validation bisection. 662# 663# Bisect into the torch._assert nodes recorded in the shape_env FX graph, and raise 664# the earliest ValidationException. 665# 666# As guards are added by ShapeEnv.evaluate_expr calls, some simplification errors 667# might be silently happening. This function tries to nail down exactly at which 668# point things went wrong from a validation perspective. 
def bisect(shape_env):
    """Find the earliest recorded assertion at which translation validation fails.

    Calls ``produce_guards`` on ``shape_env``. If that raises a
    ``ValidationException``, binary-searches over the ``torch._assert`` nodes of
    ``shape_env.graph``. Each probe replays ``shape_env.events`` up to that node's
    event and calls ``produce_guards`` again on the reconstructed ShapeEnv.

    Returns ``None`` (after logging) when ``produce_guards`` on ``shape_env``
    itself does not fail.

    Raises:
        ValidationException: when bisection is off (events were not recorded,
            or ``config.translation_validation_no_bisect`` is set), or when the
            failure cannot be attributed to any recorded assertion node.
        BisectValidationException: wraps the ``ValidationException`` of the
            earliest failing assertion node.
    """
    from torch.fx.experimental.symbolic_shapes import ShapeEnv, SHAPEENV_EVENT_KEY, CURRENT_NODE_KEY
    from torch.fx.experimental.recording import FakeTensorMeta, ShapeEnvEvent, replay_shape_env_events

    events = shape_env.events

    # Retrieves the ShapeEnvEvent associated with node.
    def get_node_event(node: torch.fx.Node) -> ShapeEnvEvent:
        assert SHAPEENV_EVENT_KEY in node.meta
        return events[node.meta[SHAPEENV_EVENT_KEY]]

    # Creates a new instance of fake, but updating every symbolic value's ShapeEnv
    # reference to the one given as argument.
    #
    # This is needed so as not to simplify a symbolic expression using a ShapeEnv
    # "from the future", where it may have a different set of replacements.
    def new_with_shape_env(shape_env: ShapeEnv, fake) -> Any:
        if isinstance(fake, int):
            return fake
        if isinstance(fake, torch.SymInt):
            return torch.SymInt(fake.node.with_shape_env(shape_env))
        assert isinstance(fake, FakeTensorMeta)
        return FakeTensorMeta(
            tuple(new_with_shape_env(shape_env, s) for s in fake.size()),
            tuple(new_with_shape_env(shape_env, s) for s in fake.stride()),
            new_with_shape_env(shape_env, fake.storage_offset()),
            fake.is_nested,
        )

    # Checks whether the given shape_env fails when produce_guards is called.
    def check_shapeenv_fails(shape_env: ShapeEnv, tracked_fakes: Optional[List[Any]]) -> Optional[ValidationException]:
        assert tracked_fakes is not None
        try:
            # This produce_guards call is a best-effort replication, since we
            # don't populate EqualityConstraint list. Reason: we would also have
            # to save OutputGraph.tracked_fakes_id_to_source.
            shape_env.produce_guards(
                [new_with_shape_env(shape_env, a.fake) for a in tracked_fakes],
                [a.source for a in tracked_fakes],
                input_contexts=[a.symbolic_context for a in tracked_fakes],
            )
            return None
        except ValidationException as e:
            return e

    # Checks whether the ShapeEnv reconstructed by replaying the events until
    # node is created fails when produce_guards is called.
    def check_node_fails(node: torch.fx.Node) -> Optional[ValidationException]:
        number = node.meta[SHAPEENV_EVENT_KEY]
        # Reconstruct shape_env until the event at event_number.
        shape_env = replay_shape_env_events(events[:number + 1])
        shape_env.graph.lint()
        return check_shapeenv_fails(shape_env, events[number].tracked_fakes)

    last_exception = check_shapeenv_fails(shape_env, shape_env._snapshot_tracked_fakes())

    if not last_exception:
        # We don't actually fail due to a produce_guards call.
        # Stop and don't bisect.
        log.info("translation validation succeeded: no errors found.")
        return

    if not shape_env.should_record_events or config.translation_validation_no_bisect:
        # Bisection is off.
        # Return the last ValidationException we got.
        raise last_exception

    # Cache the raised exception (if any) at each bisection point.
    exception: Dict[int, Optional[ValidationException]] = {}

    # Bisection happens on the assertion nodes of the recorded FX graph for
    # dynamic shapes.
    assert_nodes = [node for node in shape_env.graph.nodes if node.target == torch._assert]

    if not assert_nodes:
        # Nothing to bisect over: report the failure we already have.
        raise last_exception

    # Binary search for the first assertion node whose replayed ShapeEnv fails.
    left, right = 0, len(assert_nodes) - 1

    while left < right:
        mid = (left + right) // 2

        node = assert_nodes[mid]
        log.debug("bisecting at %s: %s", mid, get_node_event(node))

        # Check whether the new shape_env raises a ValidationException or not.
        exception[mid] = check_node_fails(node)

        if exception[mid]:
            right = mid
        else:
            left = mid + 1

    # The loop only probes midpoints, so the converged index may never have
    # been checked, e.g. a single assertion node, or convergence onto the last
    # node when every probed midpoint succeeded.
    if left not in exception:
        exception[left] = check_node_fails(assert_nodes[left])

    failing_exception = exception[left]
    if not isinstance(failing_exception, ValidationException):
        # Replaying up to the last assertion node doesn't reproduce the failure,
        # so it can't be attributed to a recorded assertion. Fall back to the
        # exception raised by the full ShapeEnv.
        raise last_exception

    node = assert_nodes[left]
    event = get_node_event(node)

    if event.is_evaluate_expr():
        failed_action = "evaluating"
    else:
        assert event.is_defer_runtime_assert(), f"unexpected event type: {event}"
        failed_action = "adding runtime assert"

    args = event.args
    assert args is not None
    assert len(args) >= 2, (
        f"bisecting expects {event.name} to have at least 2 positional arguments. "
        f"Got: {len(args)}"
    )
    assert isinstance(args[1], sympy.Basic), (
        f"bisecting expects {event.name} to have a SymPy expression as its second argument. "
        f"Got: {type(args[1])}"
    )

    raise BisectValidationException(
        failing_exception,
        expr=args[1],
        failed_action=failed_action,
        traced_node=node.meta[CURRENT_NODE_KEY],
    )