# xref: /aosp_15_r20/external/pytorch/test/onnx/verify.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
import difflib
import io

import numpy as np

import onnx
import onnx.helper
import onnx.numpy_helper

import torch
import torch.jit
import torch.onnx


def colonize(msg, sep=": "):
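    """Prefix-format a message: append ``sep`` when ``msg`` is non-empty.

    >>> colonize("In output 0")
    'In output 0: '
    >>> colonize("")
    ''
    """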
    if not msg:
        return ""
    else:
        return msg + sep


class Errors:
    """
    An error-collecting object which supports error recovery.

    It is intended to be used like a context manager:

    >>> with Errors("Top-level error message") as errs:
    >>>     ...
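    >>>     errs.checkAlmostEqual(x, y, "In output 0")
    >>>     errs.failIfErrs()

    Here ``x`` and ``y`` stand for numpy arrays to compare.  "check*" methods
    record errors and continue, "require*" methods abort immediately, and
    "failIfErrs" short-circuits if any errors have been recorded, which
    prevents error cascades.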
29    """
30
31    def __init__(self, msg, rtol=1e-3, atol=1e-5):
32        self.msg = msg
33        self.errors = []
34        self.context = []
35        self.rtol = rtol
36        self.atol = atol
37
        # Allocated upon instance creation so that multiple Errors
        # objects can be in flight at once without sharing a
        # short-circuit exception class
        class ShortCircuit(Exception):
            pass

        self.exc_class = ShortCircuit

    def requireAlmostEqual(self, x, y, msg=None):
        """
        Test that x and y are nearly equal (within self.rtol relative and
        self.atol absolute tolerance); aborts execution if they are not.
        """
        self.almostEqualAndThen(x, y, msg, self.failWith)

    def checkAlmostEqual(self, x, y, msg=None):
        """
        Test that x and y are nearly equal (within self.rtol relative and
        self.atol absolute tolerance), but continue execution even if they
        are not equal.

        To prevent error cascades, you should remember to call "failIfErrs"
        at some later point in time.
        """
        self.almostEqualAndThen(x, y, msg, self.addErr)

    def almostEqualAndThen(self, x, y, msg, k):
        """
        Helper for implementing "requireAlmostEqual" and "checkAlmostEqual".
        Upon failure, invokes continuation "k" with the error message.

        At the moment, only tests on "numpy.ndarray" are supported.
        """
        if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
            try:
                np.testing.assert_allclose(
                    x, y, rtol=self.rtol, atol=self.atol, equal_nan=True, verbose=True
                )
            except AssertionError as e:
                # Hand the failure to the continuation instead of raising, so
                # that "check" callers can collect the error and keep going
                k(f"{colonize(msg)}{str(e).lstrip()}")
        else:
            raise RuntimeError("Unsupported almost equal test")

    def requireEqual(self, x, y, msg=None):
        """
        Test that x and y are equal; aborts execution if they are not.
        """
        self.equalAndThen(x, y, msg, self.failWith)

    def checkEqual(self, x, y, msg=None):
        """
        Test that x and y are equal, but continue execution even if they are not equal.

        To prevent error cascades, you should remember to call "failIfErrs"
        at some later point in time.
        """
        self.equalAndThen(x, y, msg, self.addErr)

    # Bit-for-bit accuracy test
    def equalAndThen(self, x, y, msg, k):
        """
        Helper for implementing "requireEqual" and "checkEqual".  Upon failure,
        invokes continuation "k" with the error message.
        """
        if isinstance(x, onnx.TensorProto) and isinstance(y, onnx.TensorProto):
            self.equalAndThen(x.name, y.name, msg, k)
            # Use numpy for the comparison
            t1 = onnx.numpy_helper.to_array(x)
            t2 = onnx.numpy_helper.to_array(y)
            new_msg = f"{colonize(msg)}In embedded parameter '{x.name}'"
            self.equalAndThen(t1, t2, new_msg, k)
        elif isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
            try:
                np.testing.assert_equal(x, y)
            except AssertionError as e:
                # As above, report through the continuation rather than raising
                k(f"{colonize(msg)}{str(e).lstrip()}")
        else:
            if x != y:
                # TODO: Better algorithm for lists
                sx = str(x)
                sy = str(y)
                if len(sx) > 40 or len(sy) > 40 or "\n" in sx or "\n" in sy:
                    # long form
                    l = "=" * 50
                    k(
                        "\n{}The value\n{}\n{}\n{}\n\ndoes not equal\n\n{}\n{}\n{}".format(
                            colonize(msg, ":\n"), l, sx, l, l, sy, l
                        )
                    )
                else:
                    k(f"{colonize(msg)}{sx} != {sy}")

    def requireMultiLineEqual(self, x, y, msg=None):
        """
        Test that long, multi-line strings x and y are equal;
        aborts execution if they are not.
        """
        self.multiLineEqualAndThen(x, y, msg, self.failWith)

    def multiLineEqualAndThen(self, x, y, msg, k):
        """
        Helper for implementing "requireMultiLineEqual".  Upon failure,
        invokes continuation "k" with the error message.
        """
        if msg is None:
            msg = "Strings are not equal"
        if x != y:
            diff = difflib.ndiff(x.splitlines(True), y.splitlines(True))
            k("{}{}".format(colonize(msg, ":\n\n"), "".join(diff)))

    def addErr(self, msg):
        """
        Add an error to the error context, but continue executing.
        """
        # TODO: instead of immediately concatenating the context in the msg,
        # attach it as metadata and make a decision how to format it later.
        msg_w_ctx = msg
        for c in reversed(self.context):
            msg_w_ctx += "\n\n  * " + "\n    ".join(c.splitlines())
        self.errors.append(msg_w_ctx)

    def fail(self):
        """
        Immediately fail and short-circuit to the next recovery context.

        NB: It is an error to "fail" without having added any errors to
        the error context.
        """
        raise self.exc_class

    def failWith(self, msg):
        """
        Add an error to the error context, and then short-circuit.
        """
        self.addErr(msg)
        self.fail()

    def failIfErrs(self):
        """
        If there are any errors in the error context, short-circuit.

        This is used to prevent error cascades.
        """
        if self.errors:
            self.fail()

    def recover(self):
        """
        Returns a context manager which can be used to recover in case of
        an error.  Example usage:

        >>> with errs.recover():
        >>>     ...
        """
        parent_self = self

        class Recover:
            def __enter__(self):
                pass

            def __exit__(self, exc_type, exc_value, traceback):
                if exc_type is parent_self.exc_class:
                    return True

        return Recover()

    def addErrCtxt(self, msg):
        """
        Returns a context manager which encloses a fragment of code with
        an extra contextual message, e.g., where an error occurred, or a hint
        applicable to all errors in the area.  Example usage:

        >>> with errs.addErrCtxt("Some text"):
        >>>     ...
        """
        parent_self = self

        class AddContext:
            def __enter__(self):
                parent_self.context.append(msg)

            def __exit__(self, exc_type, exc_value, traceback):
                parent_self.context.pop()

        return AddContext()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.errors:
            errors_msg = "\n\n".join("ERROR: " + x for x in self.errors)
            final_msg = "{}\n{}\n{}".format(self.msg, "-" * 70, errors_msg)
            raise AssertionError(final_msg)
        if exc_type is self.exc_class:
            raise RuntimeError("ShortCircuit was raised, but no errors were recorded")


def verify(
    model,
    args,
    backend,
    verbose=False,
    training=torch.onnx.TrainingMode.EVAL,
    rtol=1e-3,
    atol=1e-7,
    test_args=2,
    do_constant_folding=True,
    opset_version=None,
    keep_initializers_as_inputs=True,
    add_node_names=False,
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
    input_names=None,
    dynamic_axes=None,
    remained_onnx_input_idx=None,
):
246    """
247    Export a model into ONNX, import it into a specified ONNX backend, and then
248    on a few random inputs verify that PyTorch and the backend produced the same
249    results.  Requires onnx to be installed.
250
251    This function may spuriously fail: some operators are implemented with
252    different numerical precision in an ONNX backend, in which case an unstable
253    network (e.g., Inception) may blow up these numerical instabilities.  This
254    situation is less likely to happen if your model has been trained.  However,
255    if this is not the case, you may have found a bug!  Please report it to the
256    PyTorch developers.  You can also debug the issue yourself by removing
257    suffixes of operators from your model until verification passes.
258
259    For reproducibility, we recommend explicitly setting PyTorch's seed before
260    invoking this function.
261
262    Args:
263        model (torch.nn.Module): the model to be exported and verified
264        args (tuple of arguments): the inputs to
265            the model, e.g., such that ``model(*args)`` is a valid
266            invocation of the model.  Any non-Variable arguments will
267            be hard-coded into the exported model; any Variable arguments
268            will become inputs of the exported model, in the order they
269            occur in args.  If args is a Variable, this is equivalent
270            to having called it with a 1-ary tuple of that Variable.
271            (Note: passing keyword arguments to the model is not currently
272            supported.  Give us a shout if you need it.)
273        backend (onnx.backend module): ONNX backend to verify with
274        verbose (bool, default False): if specified, we will print out a debug
275            description of the trace being exported.
276        training (bool, default False): export the model in training mode.  At
277            the moment, ONNX is oriented towards exporting models for inference
278            only, so you will generally not need to set this to True.
279        rtol (float, default 1e-3): relative precision required
280        test_args (int or iterable of args, default 2):
281            either an integer specifying the number
282            of random arguments to generate, or an iterable producing arguments
283            to test under.
284        opset_version (int, default None): the opset version of the model to
285            export. If not specified, the default value in symboli_helper will
286            be used in utils._export().
287        operator_export_type (enum, default OperatorExportTypes.ONNX): the operator
288            export type to use when exporting the model. The default value converts
289            all operators to ONNX ops.
290        input_names (list of string): list of input names.
291        dynamic_axes (dict of (string, list)): dynamic_axes.
292        remained_onnx_input_idx (list of int, default None): The remained ONNX input index.
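
    Example (a sketch, assuming the Caffe2 ONNX backend is installed; any
    ``onnx.backend``-style module exposing ``prepare()`` should work the
    same way):

    >>> import caffe2.python.onnx.backend as backend
    >>> model = torch.nn.Linear(3, 4)
    >>> args = torch.randn(2, 3)
    >>> verify(model, args, backend, test_args=3)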
293    """
294
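    # Builds a mapper: applies "fn" to every object satisfying "condition"
    # inside an arbitrarily nested structure of lists/tuples, preserving
    # that structure.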
    def _nested_map(condition, fn, condition_msg=None):
        def _map(obj):
            if condition(obj):
                return fn(obj)
            elif obj is None:
                return None
            elif isinstance(obj, (list, tuple)):
                return type(obj)(_map(x) for x in obj)
            else:
                raise ValueError(
                    "Auto nesting doesn't know how to process "
                    "an input object of type "
                    + torch.typename(obj)
                    + (
                        ". Accepted types: "
                        + condition_msg
                        + ", or lists/tuples of them"
                        if condition_msg
                        else ""
                    )
                )

        return _map

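    # Builds a generator: yields every object satisfying "condition" from an
    # arbitrarily nested structure of lists/tuples, in depth-first order.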
    def _iter_filter(condition, allow_unknown=False, condition_msg=None):
        def _iter(obj):
            if condition(obj):
                yield obj
            elif obj is None:
                return
            elif isinstance(obj, (list, tuple)):
                for o in obj:
                    yield from _iter(o)
            elif allow_unknown:
                yield obj
            else:
                raise ValueError(
                    "Auto nesting doesn't know how to process "
                    "an input object of type "
                    + torch.typename(obj)
                    + (
                        ". Accepted types: "
                        + condition_msg
                        + ", or lists/tuples of them"
                        if condition_msg
                        else ""
                    )
                )

        return _iter

    def is_tensor(o):
        return isinstance(o, torch.Tensor)

    _iter_tensors = _iter_filter(is_tensor, condition_msg="Tensors")
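    # For example, list(_iter_tensors((t1, [t2, None]))) yields t1 then t2:
    # nesting is flattened and None entries are skipped (t1, t2 being tensors).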

    def randomize_arg(arg):
        new_data = arg.data.clone()
        # For now, don't try randomizing non-float tensors; these
        # are likely to be things like indices, where just randomly
        # spattering some longs is unlikely to work.  One way we could
        # make this work is to apply a random permutation or something.
        if arg.is_floating_point():
            new_data.uniform_()
        return torch.autograd.Variable(new_data, requires_grad=arg.requires_grad)

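    # Applies randomize_arg to every tensor in a (possibly nested) tuple/list
    # of inputs, preserving the nesting structure.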
    randomize_args = _nested_map(is_tensor, randomize_arg)

    def backend_args(args):
        # TODO: onnx should accept iterables
        return tuple(v.data.cpu().numpy() for v in _iter_tensors(args))

    def load_bytes(b):
        b.seek(0)
        x = onnx.load(b)
        # doc_string has stack traces - let's remove them to make comparison
        # sane
        onnx.helper.strip_doc_string(x)
        return x

    # Special case for common case of passing a single Tensor
    if isinstance(args, torch.Tensor):
        args = (args,)

    with torch.onnx.select_model_mode_for_export(model, training):
        proto_bytes = io.BytesIO()
        torch_out = torch.onnx.utils._export(
            model,
            args,
            proto_bytes,
            verbose=verbose,
            do_constant_folding=do_constant_folding,
            opset_version=opset_version,
            keep_initializers_as_inputs=keep_initializers_as_inputs,
            add_node_names=add_node_names,
            operator_export_type=operator_export_type,
            input_names=input_names,
            dynamic_axes=dynamic_axes,
        )
        if isinstance(model, torch.jit.ScriptModule):
            torch_out = model(*args)
        proto = load_bytes(proto_bytes)
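        # backend.prepare compiles the ONNX model into a form the backend can
        # execute repeatedly via prepared.run() (the standard onnx.backend API)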
        prepared = backend.prepare(proto)

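        # Re-exports the model on a fresh set of inputs, checks that the
        # exported graph is unchanged, then compares backend and PyTorch
        # outputs on those inputs.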
        def run(args, remained_onnx_input_idx):
            alt_proto_bytes = io.BytesIO()
            torch_out = torch.onnx.utils._export(
                model,
                args,
                alt_proto_bytes,
                verbose=verbose,
                do_constant_folding=do_constant_folding,
                opset_version=opset_version,
                keep_initializers_as_inputs=keep_initializers_as_inputs,
                add_node_names=add_node_names,
                operator_export_type=operator_export_type,
                input_names=input_names,
                dynamic_axes=dynamic_axes,
            )
            if isinstance(model, torch.jit.ScriptModule):
                torch_out = model(*args)
            alt_proto = load_bytes(alt_proto_bytes)
            if proto.SerializeToString() != alt_proto.SerializeToString():
                # OK, let's try to figure out what happened.
                msg = "When I exported your model with different inputs, the result was different."
                if not verbose:
                    msg += "\n(To get more information, run torch.onnx.verify(..., verbose=True))"
                with Errors(msg, rtol=rtol, atol=atol) as errs:
                    # First, check if we have the same number of parameters, and
                    # that they're in the same order.  If they aren't, something
                    # has *really* gone wrong.
                    initializer_order_hint = (
                        "This is really strange! The second time I exported your model,\n"
                        "it had a different set of parameters.  Are you assigning Parameters\n"
                        "in the forward() of your model definition?"
                    )
                    with errs.addErrCtxt(initializer_order_hint):
                        errs.requireEqual(
                            [x.name for x in proto.graph.initializer],
                            [x.name for x in alt_proto.graph.initializer],
                            msg="Parameters list differs",
                        )

                    # Now check if the embedded parameters are actually the same
                    initializer_hint = (
                        "A difference in embedded parameters usually means that\n"
                        "your model is updating parameters/buffers even in inference\n"
                        "mode.  Look for a buggy nn.Module which isn't respecting train().\n"
                    )
                    with errs.recover(), errs.addErrCtxt(initializer_hint):
                        for x, y in zip(
                            proto.graph.initializer, alt_proto.graph.initializer
                        ):
                            errs.checkEqual(x, y)

                    # Next, check if the model structure lines up.
                    structure_hint = (
                        "A difference in model structure usually means that\n"
                        "your model has dynamic control flow.  These models are not\n"
                        "currently supported by the exporter."
                    )
                    with errs.recover(), errs.addErrCtxt(structure_hint):
                        # Delete initializers since we already tested them
                        stripped_proto = onnx.ModelProto()
                        stripped_proto.CopyFrom(proto)
                        del stripped_proto.graph.initializer[:]

                        stripped_alt_proto = onnx.ModelProto()
                        stripped_alt_proto.CopyFrom(alt_proto)
                        del stripped_alt_proto.graph.initializer[:]

                        # Compare the printable graph representations first
                        errs.requireMultiLineEqual(
                            onnx.helper.printable_graph(stripped_proto.graph),
                            onnx.helper.printable_graph(stripped_alt_proto.graph),
                        )

                        # Compare the actual protobuf text formats now (not
                        # very user-friendly!)
                        errs.requireMultiLineEqual(
                            str(stripped_proto), str(stripped_alt_proto)
                        )

                        # One last-ditch effort, using built-in equality on
                        # protobufs
                        errs.requireEqual(stripped_proto, stripped_alt_proto)

                    errs.failIfErrs()

                    # At this point, we should have figured out why the binary
                    # protobufs differed, and short-circuited out of this code
                    # with a helpful error message.  But what if we didn't?
                    # We had better still try to give a good error message in
                    # this case.  We EXPECT these requires to fail.  If they
                    # don't, that is a bug in verify.
                    errs.requireEqual(proto, alt_proto)
                    errs.requireEqual(
                        proto_bytes.getvalue(), alt_proto_bytes.getvalue()
                    )
                    raise AssertionError(
                        "Binary protobufs differed, but verify() could not "
                        "pinpoint the difference; this is a bug in verify"
                    )

            # TODO: test that the traced model also returns the same thing...
            run_helper(torch_out, args, remained_onnx_input_idx)

        # Factored out so we can avoid one run of the model
        def run_helper(torch_out, args, remained_onnx_input_idx):
            onnx_input = backend_args(args)
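            # Not every PyTorch input necessarily survives as an ONNX graph
            # input (e.g. an input may get folded into the graph at export
            # time); keep only the flattened inputs that remain.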
            if remained_onnx_input_idx is not None:
                onnx_input = tuple(
                    onnx_input[idx] for idx in remained_onnx_input_idx
                )
            backend_out = prepared.run(onnx_input)
            if isinstance(torch_out, torch.Tensor):
                torch_out = (torch_out,)
            torch_out, _ = torch.jit._flatten(torch_out)
            # NB: the ONNX backend never returns a bare numpy array, so
            # backend_out is always a sequence we can zip against
            msg = "ONNX backend returned different results from PyTorch"
            result_hint = (
                "If you are not using trained parameters, a difference in results\n"
                "could mean that your network is numerically unstable.  Otherwise\n"
                "it indicates a bug in PyTorch/ONNX; please file a bug report."
            )
            with Errors(msg, rtol=rtol, atol=atol) as errs, errs.addErrCtxt(
                result_hint
            ):
                for i, (x, y) in enumerate(zip(torch_out, backend_out)):
                    errs.checkAlmostEqual(x.data.cpu().numpy(), y, f"In output {i}")

        run_helper(torch_out, args, remained_onnx_input_idx)

        if isinstance(test_args, int):
            for _ in range(test_args):
                run(randomize_args(args), remained_onnx_input_idx)
        else:
            for test_arg in test_args:
                run(test_arg, remained_onnx_input_idx)
531