xref: /aosp_15_r20/external/pytorch/torch/fx/experimental/validator.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import functools
3import logging
4import math
5import operator
6import sympy
7import builtins
8
9from dataclasses import dataclass
10from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
11
12import torch
13import torch.fx
14import torch.fx.traceback as fx_traceback
15
16from torch._dynamo.exc import TorchDynamoException
17from torch.fx.node import Argument, Target
18from torch.utils._sympy.interp import sympy_interp
19from torch._dynamo.utils import dynamo_timed
20
# Module logger: the translation validator and 'bisect' below emit their
# debug/info/warning messages through it.
log = logging.getLogger(__name__)
22
23try:
24    import z3  # type: ignore[import]
25
26    # Translation Validation for Dynamo guards
27    # ========================================
28    #
29    # Checks whether optimizations applied to the collected guards are
30    # valid. In other words, whether the guard function we actually run
31    # does not have false positives (unsound).
32    #
    # In order to do so, we build the guards using two different pieces of
    # information attached to each 'SymNode':
35    #   1. SymPy expressions
36    #   2. FX nodes
37    #
    # SymPy expressions have implicit optimizations baked into them,
39    # which may have a few bugs. On the other hand, we build the FX graph
40    # manually, with no optimizations enabled. This gives us access to
41    # the "ground truth".
42    #
43    # We then convert into Z3 expressions both the SymPy expressions
44    # (see [Note: SympyToZ3]) that reach 'ShapeEnv.produce_guards' function
45    # and the FX nodes (see [Note: PopulateValidator]) that go through
46    # 'ShapeEnv.evaluate_expr' function. Finally, we run the validation.
47    # (see [Note: TranslationValidator])
48
49    # Better Z3 to string implementation (for a small fraction of Z3).
50    #
51    # Here are the things we clean before showing the Z3 expression:
52    #   - Rename a few ops (e.g. "Distinct" ==> "!=")
53    #
54    #   - Ignore ToInt and ToReal operations:
55    #     usually they don't really matter
56    #
57    #   - Transform (ToInt (/ ...)) into (idiv ...):
58    #     this is the pattern for floor division
59    #
60    #   - Collect a chain of the same operations into one
61    def z3str(e: z3.ExprRef) -> str:
62        assert z3.is_expr(e), f"unsupported expression type: {e}"
63
64        def get_args_str(e: z3.ExprRef) -> List[str]:
65            return [z3str(e.arg(i)) for i in range(e.num_args())]
66
67        # First, we simplify the given expression.
68        # This is done using rewriting rules, so shouldn't take long.
69        e = z3.simplify(e)
70
71
72        # Only support function applications.
73        # Even Z3 "variables" are, in fact, function applications.
74        if not z3.is_app(e):
75            raise ValueError(f"can't print Z3 expression: {e}")
76
77        if z3.is_int_value(e) or z3.is_rational_value(e):
78            return e.as_string()  # type: ignore[attr-defined]
79
80        decl = e.decl()
81        kind = decl.kind()
82        op = str(decl)
83        args = get_args_str(e)
84
85        if kind == z3.Z3_OP_POWER:
86            op = "pow"
87
88        elif kind in (z3.Z3_OP_ADD, z3.Z3_OP_MUL):
89            # Collect the arguments of chains of ADD and MUL.
90            # This is safe, since they are associative.
91
92            def collect_str_args(e):
93                if not (z3.is_app(e) and e.decl().kind() == kind):
94                    return [z3str(e)]
95                else:
96                    return [
97                        x
98                        for i in range(e.num_args())
99                        for x in collect_str_args(e.arg(i))
100                    ]
101
102            args = collect_str_args(e)
103
104        elif kind == z3.Z3_OP_NOT:
105            # Revert some conversions that z3.simplify applies:
106            #   - a != b ==> (Not (== a b)) ==> (!= a b)
107            #   - a < b ==> (Not (<= b a)) ==> (> b a)
108            #   - a > b ==> (Not (<= a b)) ==> (> a b)
109
110            assert e.num_args() == 1
111            arg = e.arg(0)
112
113            assert z3.is_app(arg)
114            argkind = arg.decl().kind()
115
116            logic_inverse = {
117                z3.Z3_OP_EQ: "!=",
118                z3.Z3_OP_LE: ">",
119                z3.Z3_OP_GE: "<",
120            }
121
122            if argkind in logic_inverse:
123                op = logic_inverse[argkind]
124                args = get_args_str(arg)
125
126        elif kind in (z3.Z3_OP_TO_INT, z3.Z3_OP_TO_REAL):
127            assert e.num_args() == 1
128            argstr = z3str(e.arg(0))
129
130            # Check if it's the floor division pattern.
131            if argstr.startswith("(/"):
132                return "(idiv" + argstr[2:]
133
134            # Otherwise, just ignore it.
135            return argstr
136
137        elif kind == z3.Z3_OP_UNINTERPRETED:
138            assert e.num_args() == 0
139            return str(decl)
140
141        string = op + " " + " ".join(args)
142        return f"({string.rstrip()})"
143
144    # Implementation of Python semantics as Z3 expressions.
145    #
    # Z3 Real-Int theory has operators with semantics that differ from those of
147    # Python. Therefore, in order to get it right, we need to implement
148    # the (Python) semantics we are relying on in Z3.
149    @dataclass
150    class _Z3Ops:
151        # Validator used for adding assertions as needed.
152        # e.g. div(a, b) requires b != 0.
153        validator: "TranslationValidator"
154
155        # The 2 functions below are used for conditionally casting between
156        # integer and reals.
157        #
158        # Returns a real expression from 'x'.
159        @staticmethod
160        def to_real(x: z3.ArithRef) -> z3.ArithRef:
161            return x if x.is_real() else z3.ToReal(x)
162
163        # Returns an integer expression from 'x'.
164        @staticmethod
165        def to_int(x: z3.ArithRef) -> z3.ArithRef:
166            return x if x.is_int() else z3.ToInt(x)
167
168        # Implements Python division semantics.
169        def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
170            self.validator.add_assertion(denominator != 0)  # type: ignore[arg-type]
171            return _Z3Ops.to_real(numerator) / _Z3Ops.to_real(denominator)
172
173        def floor(self, number: z3.ArithRef) -> z3.ArithRef:
174            # Z3 ToInt function rounds a real number towards negative infinity.
175            return _Z3Ops.to_int(number)
176
177        # Python semantics for 'FloorDiv' states that before applying the floor
178        # function, the operands are converted to their common type.
179        def floordiv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
180            cast_result_to_real = numerator.is_real() or denominator.is_real()
181            result = _Z3Ops.to_int(self.div(numerator, denominator))
182            # Since the 'result' is already an integer, we just have to check
183            # whether we should cast it to real.
184            return _Z3Ops.to_real(result) if cast_result_to_real else result
185
186        def ceil(self, number: z3.ArithRef) -> z3.ArithRef:
187            return z3.If(
188                self.floor(number) < number,
189                self.floor(number + 1),
190                number
191            )  # type: ignore[return-value]
192
193        def max(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef:
194            return z3.If(a > b, a, b)  # type: ignore[return-value]
195
196        def min(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef:
197            return z3.If(a < b, a, b)  # type: ignore[return-value]
198
199        # Python semantics for 'Mod' is defined as: p % q = p - floordiv(p, q) * q
200        # It should work with both integer and reals.
201        def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef:
202            return p - self.floordiv(p, q) * q
203
204        def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef:
205            # Z3 can't handle complex numbers very well.
206            self.validator.add_assertion(z3.Or(base != 0, exp > 0))  # type: ignore[arg-type]
207            return base ** exp
208
209        def sqrt(self, number: z3.ArithRef) -> z3.ArithRef:
210            # Square-root:
211            # 1. Only work with reals
212            number = _Z3Ops.to_real(number)
213            # 2. The number should be positive or zero.
214            #    Otherwise, Z3 returns 'unknown'.
215            self.validator.add_assertion(number >= 0)
216            return number ** 0.5
217
218        def abs(self, number: z3.ArithRef) -> z3.ArithRef:
219            return z3.Abs(number)
220
221        def round_to_int(self, number: z3.ArithRef) -> z3.ArithRef:
222            # Pythons builtin 'round' implements the 'round half to even' strategy
223            # See https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even
224            # z3 has an equivalent z3.fpRoundToIntegral(z3.RoundNearestTiesToEven(), ...), but this only applies to
225            # floating point numbers, which is different from real numbers that we are dealing with here.
226            # Instead, we implement 'round half to even' in terms of 'round half up' (floor(x + 0.5)) and
227            # 'round half down' (ceil(x - 0.5)).
228            # Assuming 'round half up' is the default case, we need to correct ..., -3.5, -1.5, 0.5, 2.5, 4.5, ...
229            # to round down, i.e. use the 'round half down' strategy
230            return z3.If(
231                self.mod(number, z3.IntVal(2)) == 0.5,
232                self.ceil(number - 0.5),
233                self.floor(number + 0.5),
234            )
235
236    # Lifts a callable to be used in Z3.
237    #
238    # This function replaces the given 'op' by a function that:
239    #
240    #   1. Lifts the arguments into Z3 (i.e. make them inhabitants of Z3)
241    #
242    #   2. Calls an operation that corresponds to 'op', but works with Z3
243    #      inhabitants (left as is if it works as is)
244    def z3op(op: Callable, validator: "TranslationValidator") -> Callable:
245        # Operations that have booleans as their argument.
246        # This is needed because the argument of some FX nodes were
247        # literal integers, instead of booleans. So, whenever this flag
248        # is set, we also convert ints to booleans.
249        boolean_ops = {operator.not_, operator.and_, operator.or_}
250        as_bool = op in boolean_ops
251
252        # Lifts the function into 'z3.ExprRef' domain.
253        def lift(func):
254            def wrap(a) -> z3.ExprRef:
255                if isinstance(a, (z3.ArithRef, z3.BoolRef)):
256                    return a
257                # Convert it into a Z3 value, if it is some of the supported
258                # types below.
259                if isinstance(a, bool) or (as_bool and isinstance(a, int)):
260                    return z3.BoolVal(bool(a))
261                if isinstance(a, (int, sympy.Integer)):
262                    return z3.IntVal(int(a))
263                if isinstance(a, (float, sympy.Float)):
264                    return z3.RealVal(float(a))
265                raise ValueError(f"can't lift type: {type(a)}")
266
267            @functools.wraps(func)
268            def wrapper(*args):
269                # Lifts the arguments into a list of Z3 inhabitants.
270                wrapped_args = (wrap(a) for a in args)
271                # Run the function on the Z3 expressions.
272                return func(*wrapped_args)
273
274            return wrapper
275
276        ops = _Z3Ops(validator)
277        replacement_map = {
278            # Operator module.
279            operator.not_: lift(z3.Not),
280            operator.and_: lift(z3.And),
281            operator.or_: lift(z3.Or),
282            operator.floordiv: lift(ops.floordiv),
283            operator.truediv: lift(ops.div),
284            operator.mod: lift(ops.mod),
285            operator.abs: lift(ops.abs),
286            builtins.round: lift(ops.round_to_int),
287
288            # Math module.
289            math.ceil: lift(ops.ceil),
290            math.floor: lift(ops.floor),
291
292            # Torch module.
293            torch.sym_float: lift(ops.to_real),
294            torch.sym_max: lift(ops.max),
295            torch.sym_min: lift(ops.min),
296            torch.sym_ite: lift(lambda b, t, f: t if b else f),
297            torch._sym_sqrt: lift(ops.sqrt),  # type: ignore[attr-defined]
298            # Not lifted because we only use this function as a
299            # marker for adding the expression as validator input.
300            torch._assert: torch._assert,
301        }
302        return replacement_map[op] if op in replacement_map else lift(op)
303
304    # Processes an FX graph, populating the given validator.
305    #
306    # [Note: PopulateValidator]
307    # This class walks through each node in the FX graph, translating
308    # them into the Z3 world.
309    #
    # Then, whenever it finds a 'torch._assert' call_function operation,
311    # it adds the Z3 expression corresponding to the argument as validator
312    # input.
313    class PopulateValidator(torch.fx.Interpreter):
314        def __init__(self, graph: torch.fx.Graph, validator: "TranslationValidator"):
315            # Reference to the translation validator.
316            self.validator = validator
317
318            # Build the graph module and call `Interpreter` constructor.
319            module = torch.fx.GraphModule(root={}, graph=graph)
320            super().__init__(module, garbage_collect_values=True)
321
322        def placeholder(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
323            symbol = fx_traceback.get_current_meta()["symbol"]
324            return self.validator.z3var(symbol)
325
326        def call_function(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
327            if target != torch._assert:
328                # Lift and runs the node target function
329                return super().call_function(z3op(target, self.validator), args, kwargs)  # type: ignore[arg-type]
330            # Adds the Z3 expression corresponding to the first argument
331            # as a validator input.
332            assert len(args) == 1, f"expected 1 argument on assertion. Got: {len(args)} "
333            self.validator.add_source_expr(args[0])  # type: ignore[arg-type]
334
335    # Translates SymPy expressions into Z3 expressions.
336    #
337    # [Note: SympyToZ3]
338    # At the time of the translation, all free variables present in the
339    # SymPy expression being translated must be already mapped to a Z3
340    # integer variable.
341    class SympyToZ3:
342        OPERATOR_HANDLES = {"add", "mul", "eq", "ne", "lt", "gt", "le", "ge"}
343
344        def __init__(
345                self,
346                validator: "TranslationValidator",
347        ) -> None:
348            self._validator = validator
349            self._ops = _Z3Ops(self._validator)
350
351        def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef:
352            # TODO: Probably OK to relax this and allow lower precision
353            if dtype is torch.int64:
354                return z3.IntVal(int(value))
355            if dtype is torch.double:
356                return z3.RealVal(float(value))
357            if dtype is torch.bool:
358                return z3.BoolVal(bool(value))
359            raise ValueError(f"unsupported dtype (SympyToZ3): {dtype}")
360
361        def to_dtype(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
362            if dtype == torch.float64:
363                return z3.ToReal(x)
364            raise NotImplementedError(f"to_dtype {dtype} NYI")
365
366        def trunc_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
367            return z3.ToInt(x)
368
369        def round_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
370            return self._ops.round_to_int(x)
371
372        def int_truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
373            return self._ops.div(numerator, denominator)
374
375        def truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
376            return self._ops.div(numerator, denominator)
377
378        def floordiv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
379            return self._ops.floordiv(numerator, denominator)
380
381        def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
382            return self._ops.floordiv(numerator, denominator)
383
384        def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef:
385            return self._ops.pow(base, exp)
386
387        def pow_by_natural(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef:
388            return self._ops.pow(base, exp)
389
390        def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef:
391            return self._ops.mod(p, q)
392
393        def ceil_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
394            return self._ops.ceil(x)
395
396        def floor_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
397            return self._ops.floor(x)
398
399        def __getattr__(self, name: str) -> Any:
400            REPLACEMENT = {
401                "and_": z3.And,
402                "or_": z3.Or,
403                "not_": z3.Not,
404                "floor": self._ops.floor,
405                "ceil": self._ops.ceil,
406                "minimum": self._ops.min,
407                "maximum": self._ops.max,
408            }
409
410            if name in REPLACEMENT:
411                return REPLACEMENT[name]
412            if name in self.OPERATOR_HANDLES:
413                return getattr(operator, name)
414            raise AttributeError(f"unhandled operator: {name}")
415
416        def run(self, expr: sympy.Basic) -> z3.ExprRef:
417            return sympy_interp(self, self._validator.symbols, expr)  # type: ignore[arg-type]
418
419    # Dynamo guards translation validator.
420    #
421    # [Note: TranslationValidator]
422    # Verifies whether the guards issued by 'ShapeEnv.produce_guards' are sound.
423    # That is: whether those (target) guards only yield TRUE whenever the original,
424    # unoptimized, (source) guards yield TRUE.
425    #
426    # More concretely, given 'source' and 'target' guard expressions, we wish to
    # check whether the following expression is satisfiable:
428    #
429    # Not(And(source)) AND And(target)
430    #
431    # i.e. whether there is an assignment of the free variables where the opposite
432    # happens: target is TRUE, but source is FALSE.
433    class TranslationValidator:
434        def __init__(self) -> None:
435            log.debug("new instance")
436
437            # Mapping of SymPy symbols to Z3 variables.
438            self.symbols: Dict[sympy.Symbol, z3.ExprRef] = {}
439
440            # Set of source Z3 expressions.
441            # They represent the generated guards without any kind of
442            # simplification or transformation.
443            self._source_exprs: Set[z3.BoolRef] = set()
444
445            # Set of target Z3 expressions.
446            # They represent the actual checked guards at runtime. They might
447            # be simplified or transformed versions of the source guards.
448            self._target_exprs: Set[z3.BoolRef] = set()
449
450            # Set of Z3 expressions representing assertions over both the
451            # source and target expressions.
452            self._assertions: Set[z3.BoolRef] = set()
453
454        # Retrieves the corresponding Z3 variable.
455        def z3var(self, symbol: sympy.Symbol) -> z3.ExprRef:
456            assert symbol in self.symbols, f"Z3 variable not found for: {symbol}"
457            return self.symbols[symbol]
458
459        # Create a variable in Z3 of 'type' for 'symbol', if it doesn't already exists.
460        def add_var(self, symbol: sympy.Symbol, type: Type) -> z3.ExprRef:
461            if symbol in self.symbols:
462                return self.symbols[symbol]
463
464            log.debug("new variable: %s (%s)", symbol.name, type.__name__)
465
466            if type is int:
467                var = z3.Int(symbol.name)
468
469                # If 'symbol' is positive (SymPy assumption), we have to
470                # convey it to Z3 as well.
471                if symbol.is_positive:  # type: ignore[attr-defined]
472                    self._target_exprs.add(var > 0)
473            elif type is float:
474                var = z3.Real(symbol.name)
475            elif type is bool:
476                var = z3.Bool(symbol.name)
477            else:
478                raise RuntimeError(f"unsupported type for Z3 variable: {type}")
479
480            self.symbols[symbol] = var
481            return var
482
483        # Checks whether all symbols were already added.
484        def _check_freesymbols(self, e: sympy.Basic) -> None:
485            for s in e.free_symbols:
486                assert isinstance(s, sympy.Symbol)
487                # Call 'z3var' just to check whether there's already a
488                # Z3 variable corresponding to 's'.
489                self.z3var(s)
490
491
492        def to_z3_boolean_expr(self, e: sympy.Basic) -> z3.BoolRef:
493            z3expr = SympyToZ3(self).run(e)
494            assert isinstance(z3expr, z3.BoolRef), f"expected boolean expression. Got: {z3expr}"
495            return z3expr
496
497        def add_source_expr(self, e: z3.BoolRef) -> None:
498            if e not in self._source_exprs:
499                log.debug("add source guard: %s", z3str(e))
500            self._source_exprs.add(e)
501
502        def add_target_expr(self, e: sympy.Expr) -> None:
503            self._check_freesymbols(e)
504            z3expr = self.to_z3_boolean_expr(e)
505            if e not in self._target_exprs:
506                log.debug("add target guard: %s", z3str(z3expr))
507            self._target_exprs.add(z3expr)
508
509        def add_assertion(self, e: Union[z3.BoolRef, sympy.Basic]) -> None:
510            if isinstance(e, sympy.Basic):
511                self._check_freesymbols(e)
512                ref = self.to_z3_boolean_expr(e)
513            else:
514                ref = e
515            assert isinstance(ref, z3.BoolRef)
516            if ref not in self._assertions:
517                log.debug("add assertion: %s", z3str(ref))
518            self._assertions.add(ref)
519
520        def validate(self) -> None:
521            with dynamo_timed("TranslationValidator.validate"):
522                return self._validate()
523
524        def _validate(self) -> None:
525            if len(self._source_exprs) == 0 or len(self._target_exprs) == 0:
526                # If there are no source/target expressions, there's nothing we really
527                # wish to prove. So, we just return.
528                return None
529
530            # Here, we use "QF_NRA" logic for the solver:
531            #   "Quantifier-free Non-linear Real Arithmetic".
532            #
533            # Most of the guards expressions have:
534            #   1. arithmetic between integer and reals
535            #   2. no quantifiers
536            #   3. potentially non-linear.
537            #
538            # Although there's also "QF_NIRA" (mixed integer-real arithmetic),
539            # "QF_NRA" seems to work better on 'dynamo/test_dynamic_shapes.py'.
540            solver = z3.SolverFor("QF_NRA")
541            # Set a timeout for finding a solution.
542            solver.set(timeout=translation_validation_timeout())
543
544            # Add all the assertions to the solver.
545            for assertion in self._assertions:
546                solver.add(assertion)
547
548            # "Is there any case where it's TRUE for the target expressions,
549            #  but FALSE for the source expressions?"
550            solver.add(z3.Not(z3.And(*self._source_exprs)))
551            solver.add(*self._target_exprs)
552
553            log.debug("translation validation: start")
554            r = solver.check()
555            if r == z3.sat:
556                # Target expressions are unsound.
557                # Log the found model and the source expressions that failed.
558                model = solver.model()
559                raise ValidationException(
560                    model, self._assertions, self._target_exprs,
561                    failed_source_exprs=[
562                        inp for inp in self._source_exprs if not model.evaluate(inp)
563                    ]
564                )
565            else:
566                if r == z3.unknown:
567                    # Could not find a solution. It didn't fail, but it also
568                    # didn't succeed. Canceling the validation execution (keyboard
569                    # interrupt) also gets to this branch.
570                    log.warning("translation validation: could not validate: got z3.unknown")
571                else:
572                    # Target expressions are sound.
573                    assert r == z3.unsat
574                    log.debug("translation validation: success")
575
576except ImportError:
577    _HAS_Z3 = False
578
579    __all__ = [
580        "translation_validation_enabled", "translation_validation_timeout",
581        "ValidationException", "BisectValidationException",
582    ]
583
584else:
585    _HAS_Z3 = True
586
587    __all__ = [
588        "z3str", "z3op", "PopulateValidator", "SympyToZ3", "TranslationValidator",
589        "translation_validation_enabled", "translation_validation_timeout",
590        "ValidationException", "BisectValidationException",
591    ]
592
593from torch.fx.experimental import _config as config
594
def translation_validation_enabled() -> bool:
    """Return whether translation validation should run.

    True only when Z3 is installed and 'config.translation_validation' is set.
    """
    # Re-checked on every call, in case the option is set but Z3 is not
    # installed.
    _assert_z3_installed_if_tv_set()
    if not _HAS_Z3:
        return False
    return config.translation_validation
600
601
def translation_validation_timeout() -> int:
    """Return the solver timeout read from 'config.translation_validation_timeout'."""
    timeout = config.translation_validation_timeout
    return timeout
604
605
def _assert_z3_installed_if_tv_set():
    """Assert that Z3 is installed whenever translation validation is enabled."""
    requested_without_z3 = (not _HAS_Z3) and config.translation_validation
    assert not requested_without_z3, (
        "translation validation requires Z3 package. Please, either install "
        "z3-solver or disable translation validation."
    )
611
612
class ValidationException(TorchDynamoException):
    """Raised when Z3 finds a counterexample to the target guards' soundness.

    Its message lists the solver model, the assertions, the target expressions
    and the source expressions that evaluated to false under the model.
    """

    def __init__(self, model, assertions, target_exprs, failed_source_exprs):
        assert _HAS_Z3

        def render(items) -> str:
            # One sorted entry per line, each prefixed by an arrow.
            return "\n".join(f"  ==> {item}" for item in sorted(items))

        model_str = render(f"{sym}: {model[sym]}" for sym in model)
        assertions_str = render(z3str(a) for a in assertions)
        target_exprs_str = render(z3str(t) for t in target_exprs)
        failed_source_exprs_str = render(z3str(s) for s in failed_source_exprs)

        self.msg = "translation validation failed."
        self.details = f"""\
Model:
{model_str}

Assertions:
{assertions_str}

Target Expressions:
{target_exprs_str}

Failed Source Expressions:
{failed_source_exprs_str}"""

    def __str__(self):
        return f"{self.msg}\n\n{self.details}"
644
645
class BisectValidationException(TorchDynamoException):
    """Wraps a ValidationException with the recorded event that triggered it.

    'expr' and 'failed_action' describe the guard being processed, and
    'traced_node' is the node whose formatted form is shown in the details.
    """

    def __init__(self, validation_exc, expr, failed_action, traced_node):
        self.msg = f"translation validation failed when {failed_action}: {expr}"
        node_text = traced_node.format_node()
        self.details = f"""\
Failure occurred while running node:
    {node_text}

{validation_exc.details}"""

    def __str__(self):
        return f"{self.msg}\n\n{self.details}"
657
# Checks when this module is loaded, so that enabling translation validation
# without Z3 installed fails at import time as well as on every
# 'translation_validation_enabled' call.
_assert_z3_installed_if_tv_set()
660
661# Translation validation bisection.
662#
663# Bisect into the torch._assert nodes recorded in the shape_env FX graph, and raise
664# the earliest ValidationException.
665#
666# As guards are added by ShapeEnv.evaluate_expr calls, some simplification errors
667# might be silently happening. This function tries to nail down exactly at which
668# point things went wrong from a validation perspective.
def bisect(shape_env):
    """Find the earliest recorded guard that makes translation validation fail.

    Runs 'produce_guards' on 'shape_env'. If that does not raise a
    'ValidationException', logs success and returns None. Otherwise it
    re-raises that ValidationException in either of these cases:
      - event recording is off, or bisection is disabled by config;
      - the failure can't be attributed to any 'torch._assert' node.
    In all other cases it binary-searches over the 'torch._assert' nodes of
    'shape_env.graph', replaying the recorded events up to each candidate, and
    raises a 'BisectValidationException' for the earliest failing node.

    Assumes failures are monotone in the replayed prefix (once a prefix fails,
    every longer prefix fails too), which the binary search relies on.
    """
    from torch.fx.experimental.symbolic_shapes import ShapeEnv, SHAPEENV_EVENT_KEY, CURRENT_NODE_KEY
    from torch.fx.experimental.recording import FakeTensorMeta, ShapeEnvEvent, replay_shape_env_events

    events = shape_env.events

    # Retrieves the ShapeEnvEvent associated with node.
    def get_node_event(node: torch.fx.Node) -> ShapeEnvEvent:
        assert SHAPEENV_EVENT_KEY in node.meta
        return events[node.meta[SHAPEENV_EVENT_KEY]]

    # Creates a new instance of fake, but updating every symbolic value's ShapeEnv
    # reference to the one given as argument.
    #
    # This is needed so as not to simplify a symbolic expression using a ShapeEnv
    # "from the future", where it may have a different set of replacements.
    def new_with_shape_env(shape_env: ShapeEnv, fake) -> Any:
        if isinstance(fake, int):
            return fake
        if isinstance(fake, torch.SymInt):
            return torch.SymInt(fake.node.with_shape_env(shape_env))
        assert isinstance(fake, FakeTensorMeta)
        return FakeTensorMeta(
            tuple(new_with_shape_env(shape_env, s) for s in fake.size()),
            tuple(new_with_shape_env(shape_env, s) for s in fake.stride()),
            new_with_shape_env(shape_env, fake.storage_offset()),
            fake.is_nested,
        )

    # Checks whether the given shape_env fails when produce_guards is called.
    def check_shapeenv_fails(shape_env: ShapeEnv, tracked_fakes: Optional[List[Any]]) -> Optional[ValidationException]:
        assert tracked_fakes is not None
        try:
            # This produce_guards call is a best-effort replication, since we
            # don't populate EqualityConstraint list. Reason: we would also have
            # to save OutputGraph.tracked_fakes_id_to_source.
            shape_env.produce_guards(
                [new_with_shape_env(shape_env, a.fake) for a in tracked_fakes],
                [a.source for a in tracked_fakes],
                input_contexts=[a.symbolic_context for a in tracked_fakes],
            )
            return None
        except ValidationException as e:
            return e

    # Checks whether the ShapeEnv reconstructed by replaying the events until
    # node is created fails when produce_guards is called.
    def check_node_fails(node: torch.fx.Node) -> Optional[ValidationException]:
        number = node.meta[SHAPEENV_EVENT_KEY]
        # Reconstruct shape_env until the event at event_number.
        shape_env = replay_shape_env_events(events[:number + 1])
        shape_env.graph.lint()
        return check_shapeenv_fails(shape_env, events[number].tracked_fakes)

    last_exception = check_shapeenv_fails(shape_env, shape_env._snapshot_tracked_fakes())

    if last_exception is None:
        # We don't actually fail due to a produce_guards call.
        # Stop and don't bisect.
        log.info("translation validation succeeded: no errors found.")
        return

    if not shape_env.should_record_events or config.translation_validation_no_bisect:
        # Bisection is off.
        # Return the last ValidationException we got.
        raise last_exception

    # Cache the raised exception (if any) at each bisection point.
    exception: Dict[int, Optional[ValidationException]] = {}

    # Bisection happens on the assertion nodes of the recorded FX graph for
    # dynamic shapes.
    assert_nodes = [node for node in shape_env.graph.nodes if node.target == torch._assert]

    if not assert_nodes:
        # Nothing to bisect over: report the failure we already have.
        raise last_exception

    # Binary search invariants:
    #   - for all i < left, assert_nodes[i] doesn't fail;
    #   - for all i >= right, assert_nodes[i] fails;
    #   - 'right in exception' always holds;
    #   - 'left <= right' always holds.
    # The loop below never evaluates the initial 'right', so it is checked up
    # front. Otherwise a single assertion node, or a failure that first shows
    # up at the last node, would leave no cached result for the final 'left'.
    left, right = 0, len(assert_nodes) - 1
    log.debug("bisecting at %s: %s", right, get_node_event(assert_nodes[right]))
    exception[right] = check_node_fails(assert_nodes[right])

    if exception[right] is None:
        # Replaying up to the last assertion node doesn't fail, so the failure
        # can't be attributed to any of them.
        raise last_exception

    while left < right:
        mid = (left + right) // 2

        node = assert_nodes[mid]
        log.debug("bisecting at %s: %s", mid, get_node_event(node))

        # Check whether the new shape_env raises a ValidationException or not.
        exception[mid] = check_node_fails(node)

        if exception[mid] is not None:
            right = mid
        else:
            left = mid + 1

    assert left in exception and isinstance(exception[left], ValidationException)

    node = assert_nodes[left]
    event = get_node_event(node)

    if event.is_evaluate_expr():
        failed_action = "evaluating"
    else:
        assert event.is_defer_runtime_assert(), f"unexpected event type: {event}"
        failed_action = "adding runtime assert"

    args = event.args
    assert args is not None
    assert len(args) >= 2, (
        f"bisecting expects {event.name} to have at least 2 positional arguments. "
        f"Got: {len(args)}"
    )
    assert isinstance(args[1], sympy.Basic), (
        f"bisecting expects {event.name} to have a SymPy expression as its second argument. "
        f"Got: {type(args[1])}"
    )

    raise BisectValidationException(
        exception[left],
        expr=args[1],
        failed_action=failed_action,
        traced_node=node.meta[CURRENT_NODE_KEY],
    )
788