1import difflib 2import io 3 4import numpy as np 5 6import onnx 7import onnx.helper 8 9import torch 10import torch.jit 11import torch.onnx 12 13 14def colonize(msg, sep=": "): 15 if not msg: 16 return "" 17 else: 18 return msg + sep 19 20 21class Errors: 22 """ 23 An error-collecting object which supports error recovery. 24 25 It is intended to be used like a context manager: 26 27 >>> with Errors("Top-level error message") as errs: 28 >>> ... 29 """ 30 31 def __init__(self, msg, rtol=1e-3, atol=1e-5): 32 self.msg = msg 33 self.errors = [] 34 self.context = [] 35 self.rtol = rtol 36 self.atol = atol 37 38 # Allocated upon instance creation so that multiple Errors 39 # can be used 40 class ShortCircuit(Exception): 41 pass 42 43 self.exc_class = ShortCircuit 44 45 def requireAlmostEqual(self, x, y, msg=None): 46 """ 47 Test that x and y are nearly equal (equal within self.rtol 48 precision); aborts execution if they are not. 49 """ 50 self.almostEqualAndThen(x, y, msg, self.failWith) 51 52 def checkAlmostEqual(self, x, y, msg=None): 53 """ 54 Test that x and y are nearly equal (equal within self.rtol 55 precision), but continue execution even if they are not equal. 56 57 To prevent error cascades, you should remember to call "failIfErrs" 58 at some later point in time. 59 """ 60 self.almostEqualAndThen(x, y, msg, self.addErr) 61 62 def almostEqualAndThen(self, x, y, msg, k): 63 """ 64 Helper for implementing "requireAlmostEqual" and "checkAlmostEqual". 65 Upon failure, invokes continuation "k" with the error message. 66 67 At the moment, only tests on "numpy.ndarray" are supported. 68 """ 69 if isinstance(x, np.ndarray) and isinstance(y, np.ndarray): 70 np.testing.assert_allclose( 71 x, y, rtol=self.rtol, atol=self.atol, equal_nan=True, verbose=True 72 ) 73 else: 74 raise RuntimeError("Unsupported almost equal test") 75 76 def requireEqual(self, x, y, msg=None): 77 """ 78 Test that x and y are equal; aborts execution if they are not. 79 """ 80 self.equalAndThen(x, y, msg, self.failWith) 81 82 def checkEqual(self, x, y, msg=None): 83 """ 84 Test that x and y are equal, but continue execution even if they are not equal. 85 86 To prevent error cascades, you should remember to call "failIfErrs" 87 at some later point in time. 88 """ 89 self.equalAndThen(x, y, msg, self.addErr) 90 91 # Bit-for-bit accuracy test 92 def equalAndThen(self, x, y, msg, k): 93 """ 94 Helper for implementing "requireEqual" and "checkEqual". Upon failure, 95 invokes continuation "k" with the error message. 96 """ 97 if isinstance(x, onnx.TensorProto) and isinstance(y, onnx.TensorProto): 98 self.equalAndThen(x.name, y.name, msg, k) 99 # Use numpy for the comparison 100 t1 = onnx.numpy_helper.to_array(x) 101 t2 = onnx.numpy_helper.to_array(y) 102 new_msg = f"{colonize(msg)}In embedded parameter '{x.name}'" 103 self.equalAndThen(t1, t2, new_msg, k) 104 elif isinstance(x, np.ndarray) and isinstance(y, np.ndarray): 105 np.testing.assert_equal(x, y) 106 else: 107 if x != y: 108 # TODO: Better algorithm for lists 109 sx = str(x) 110 sy = str(y) 111 if len(sx) > 40 or len(sy) > 40 or "\n" in sx or "\n" in sy: 112 # long form 113 l = "=" * 50 114 k( 115 "\n{}The value\n{}\n{}\n{}\n\ndoes not equal\n\n{}\n{}\n{}".format( 116 colonize(msg, ":\n"), l, sx, l, l, sy, l 117 ) 118 ) 119 else: 120 k(f"{colonize(msg)}{sx} != {sy}") 121 122 def requireMultiLineEqual(self, x, y, msg=None): 123 """ 124 Test that long, multi-line strings x and y are equal; 125 aborts execution if they are not. 126 """ 127 self.multiLineEqualAndThen(x, y, msg, self.failWith) 128 129 def multiLineEqualAndThen(self, x, y, msg, k): 130 """ 131 Helper for implementing "requireMultiLineEqual". Upon failure, 132 invokes continuation "k" with the error message. 133 """ 134 if msg is None: 135 msg = "Strings are not equal" 136 if x != y: 137 diff = difflib.ndiff(x.splitlines(True), y.splitlines(True)) 138 k("{}{}".format(colonize(msg, ":\n\n"), "".join(diff))) 139 140 def addErr(self, msg): 141 """ 142 Add an error to the error context, but continue executing. 143 """ 144 # TODO: instead of immediately concatenating the context in the msg, 145 # attach it as metadata and make a decision how to format it later. 146 msg_w_ctx = msg 147 for c in reversed(self.context): 148 msg += "\n\n * " + "\n ".join(c.splitlines()) 149 self.errors.append(msg) 150 151 def fail(self): 152 """ 153 Immediately fail and short-circuit to the next recovery context. 154 155 NB: It is an error to "fail" without having added any errors to 156 the error context. 157 """ 158 raise self.exc_class 159 160 def failWith(self, msg): 161 """ 162 Add an error to the error context, and then short-circuit. 163 """ 164 self.addErr(msg) 165 self.fail() 166 167 def failIfErrs(self): 168 """ 169 If there are any errors in the error context, short-circuit. 170 171 This is used to prevent error cascades. 172 """ 173 if self.errors: 174 self.fail() 175 176 def recover(self): 177 """ 178 Returns a context manager which can be used to recover in case of 179 an error. Example usage: 180 181 >>> with errs.recover(): 182 >>> ... 183 """ 184 parent_self = self 185 186 class Recover: 187 def __enter__(self): 188 pass 189 190 def __exit__(self, exc_type, exc_value, traceback): 191 if exc_type == parent_self.exc_class: 192 return True 193 194 return Recover() 195 196 def addErrCtxt(self, msg): 197 """ 198 Returns a context manager which encloses a fragment of code with 199 an extra contextual message, e.g., where an error occurred, or a hint 200 applicable to all errors in the area. Example usage: 201 202 >>> with errs.addErrCtx("Some text"): 203 >>> ... 204 """ 205 parent_self = self 206 207 class AddContext: 208 def __enter__(self): 209 parent_self.context.append(msg) 210 211 def __exit__(self, exc_type, exc_value, traceback): 212 parent_self.context.pop() 213 214 return AddContext() 215 216 def __enter__(self): 217 return self 218 219 def __exit__(self, exc_type, exc_value, traceback): 220 if self.errors: 221 errors_msg = "\n\n".join("ERROR: " + x for x in self.errors) 222 final_msg = "{}\n{}\n{}".format(self.msg, "-" * 70, errors_msg) 223 raise AssertionError(final_msg) 224 if exc_type == self.exc_class: 225 raise RuntimeError("ShortCircuit was raised, but no errors were recorded") 226 227 228def verify( 229 model, 230 args, 231 backend, 232 verbose=False, 233 training=torch.onnx.TrainingMode.EVAL, 234 rtol=1e-3, 235 atol=1e-7, 236 test_args=2, 237 do_constant_folding=True, 238 opset_version=None, 239 keep_initializers_as_inputs=True, 240 add_node_names=False, 241 operator_export_type=torch.onnx.OperatorExportTypes.ONNX, 242 input_names=None, 243 dynamic_axes=None, 244 remained_onnx_input_idx=None, 245): 246 """ 247 Export a model into ONNX, import it into a specified ONNX backend, and then 248 on a few random inputs verify that PyTorch and the backend produced the same 249 results. Requires onnx to be installed. 250 251 This function may spuriously fail: some operators are implemented with 252 different numerical precision in an ONNX backend, in which case an unstable 253 network (e.g., Inception) may blow up these numerical instabilities. This 254 situation is less likely to happen if your model has been trained. However, 255 if this is not the case, you may have found a bug! Please report it to the 256 PyTorch developers. You can also debug the issue yourself by removing 257 suffixes of operators from your model until verification passes. 258 259 For reproducibility, we recommend explicitly setting PyTorch's seed before 260 invoking this function. 261 262 Args: 263 model (torch.nn.Module): the model to be exported and verified 264 args (tuple of arguments): the inputs to 265 the model, e.g., such that ``model(*args)`` is a valid 266 invocation of the model. Any non-Variable arguments will 267 be hard-coded into the exported model; any Variable arguments 268 will become inputs of the exported model, in the order they 269 occur in args. If args is a Variable, this is equivalent 270 to having called it with a 1-ary tuple of that Variable. 271 (Note: passing keyword arguments to the model is not currently 272 supported. Give us a shout if you need it.) 273 backend (onnx.backend module): ONNX backend to verify with 274 verbose (bool, default False): if specified, we will print out a debug 275 description of the trace being exported. 276 training (bool, default False): export the model in training mode. At 277 the moment, ONNX is oriented towards exporting models for inference 278 only, so you will generally not need to set this to True. 279 rtol (float, default 1e-3): relative precision required 280 test_args (int or iterable of args, default 2): 281 either an integer specifying the number 282 of random arguments to generate, or an iterable producing arguments 283 to test under. 284 opset_version (int, default None): the opset version of the model to 285 export. If not specified, the default value in symboli_helper will 286 be used in utils._export(). 287 operator_export_type (enum, default OperatorExportTypes.ONNX): the operator 288 export type to use when exporting the model. The default value converts 289 all operators to ONNX ops. 290 input_names (list of string): list of input names. 291 dynamic_axes (dict of (string, list)): dynamic_axes. 292 remained_onnx_input_idx (list of int, default None): The remained ONNX input index. 293 """ 294 295 def _nested_map(condition, fn, condition_msg=None): 296 def _map(obj): 297 if condition(obj): 298 return fn(obj) 299 elif obj is None: 300 return None 301 elif isinstance(obj, (list, tuple)): 302 return type(obj)(_map(x) for x in obj) 303 else: 304 raise ValueError( 305 "Auto nesting doesn't know how to process " 306 "an input object of type " 307 + torch.typename(obj) 308 + ( 309 ". Accepted types: " 310 + condition_msg 311 + ", or lists/tuples of them" 312 if condition_msg 313 else "" 314 ) 315 ) 316 317 return _map 318 319 def _iter_filter(condition, allow_unknown=False, condition_msg=None): 320 def _iter(obj): 321 if condition(obj): 322 yield obj 323 elif obj is None: 324 return 325 elif isinstance(obj, (list, tuple)): 326 for o in obj: 327 yield from _iter(o) 328 elif allow_unknown: 329 yield obj 330 else: 331 raise ValueError( 332 "Auto nesting doesn't know how to process " 333 "an input object of type " 334 + torch.typename(obj) 335 + ( 336 ". Accepted types: " 337 + condition_msg 338 + ", or lists/tuples of them" 339 if condition_msg 340 else "" 341 ) 342 ) 343 344 return _iter 345 346 def is_tensor(o): 347 return isinstance(o, torch.Tensor) 348 349 _iter_tensors = _iter_filter(is_tensor, condition_msg="Tensors") 350 351 def randomize_arg(arg): 352 new_data = arg.data.clone() 353 # For now, don't try randomizing non-float tensors; these 354 # are likely to be things like indices, where just randomly 355 # spattering some longs is unlikely to work. One way we could 356 # make this work is to apply a random permutation or something. 357 if arg.is_floating_point(): 358 new_data.uniform_() 359 return torch.autograd.Variable(new_data, requires_grad=arg.requires_grad) 360 361 randomize_args = _nested_map(is_tensor, randomize_arg) 362 363 def backend_args(args): 364 # TODO: onnx should accept iterables 365 return tuple(v.data.cpu().numpy() for v in _iter_tensors(args)) 366 367 def load_bytes(b): 368 b.seek(0) 369 x = onnx.load(b) 370 # doc_string has stack traces - let's remove them to make comparison 371 # sane 372 onnx.helper.strip_doc_string(x) 373 return x 374 375 # Special case for common case of passing a single Tensor 376 if isinstance(args, torch.Tensor): 377 args = (args,) 378 379 with torch.onnx.select_model_mode_for_export(model, training): 380 proto_bytes = io.BytesIO() 381 torch_out = torch.onnx.utils._export( 382 model, 383 args, 384 proto_bytes, 385 verbose=verbose, 386 do_constant_folding=do_constant_folding, 387 opset_version=opset_version, 388 keep_initializers_as_inputs=keep_initializers_as_inputs, 389 add_node_names=add_node_names, 390 operator_export_type=operator_export_type, 391 input_names=input_names, 392 dynamic_axes=dynamic_axes, 393 ) 394 if isinstance(model, torch.jit.ScriptModule): 395 torch_out = model(*args) 396 proto = load_bytes(proto_bytes) 397 prepared = backend.prepare(proto) 398 399 def run(args, remained_onnx_input_idx): 400 alt_proto_bytes = io.BytesIO() 401 torch_out = torch.onnx.utils._export( 402 model, 403 args, 404 alt_proto_bytes, 405 verbose=verbose, 406 do_constant_folding=do_constant_folding, 407 opset_version=opset_version, 408 keep_initializers_as_inputs=keep_initializers_as_inputs, 409 add_node_names=add_node_names, 410 operator_export_type=operator_export_type, 411 input_names=input_names, 412 dynamic_axes=dynamic_axes, 413 ) 414 if isinstance(model, torch.jit.ScriptModule): 415 torch_out = model(*args) 416 alt_proto = load_bytes(alt_proto_bytes) 417 if proto.SerializeToString() != alt_proto.SerializeToString(): 418 # OK, let's try to figure out what happened. 419 msg = "When I exported your model with different inputs, the result was different." 420 if not verbose: 421 msg += "\n(To get more information, run torch.onnx.verify(..., verbose=True))" 422 with Errors(msg, rtol=rtol, atol=atol) as errs: 423 # First, check if we have the same number of parameters, and 424 # that they"re the same order. If they don"t, something has *really* gone wrong. 425 initializer_order_hint = ( 426 "This is really strange! The second time I exported your model,\n" 427 "it had a different set of parameters. Are you assigning Parameters\n" 428 "in the forward() of your model definition?" 429 ) 430 with errs.addErrCtxt(initializer_order_hint): 431 errs.requireEqual( 432 [x.name for x in proto.graph.initializer], 433 [x.name for x in alt_proto.graph.initializer], 434 msg="Parameters list differs", 435 ) 436 437 # Now check if the embedded parameters are actually the same 438 initializer_hint = ( 439 "A difference in embedded parameters usually means that\n" 440 "your model is updating parameters/buffers even in inference\n" 441 "mode. Look for a buggy nn.Module which isn't respecting train().\n" 442 ) 443 with errs.recover(), errs.addErrCtxt(initializer_hint): 444 for x, y in zip( 445 proto.graph.initializer, alt_proto.graph.initializer 446 ): 447 errs.checkEqual(x, y) 448 449 # Next, check if the model structure lines up. 450 structure_hint = ( 451 "A difference in model structure usually means that\n" 452 "your model has dynamic control flow. These models are not\n" 453 "currently supported by the exporter." 454 ) 455 with errs.recover(), errs.addErrCtxt(structure_hint): 456 # Delete initializers since we already tested them 457 stripped_proto = onnx.ModelProto() 458 stripped_proto.CopyFrom(proto) 459 del stripped_proto.graph.initializer[:] 460 461 stripped_alt_proto = onnx.ModelProto() 462 stripped_alt_proto.CopyFrom(alt_proto) 463 del stripped_alt_proto.graph.initializer[:] 464 465 # Compare the printable graph representations first 466 errs.requireMultiLineEqual( 467 onnx.helper.printable_graph(stripped_proto.graph), 468 onnx.helper.printable_graph(stripped_alt_proto.graph), 469 ) 470 471 # Compare the actual protobuf text formats now (not 472 # very user-friendly!) 473 errs.requireMultiLineEqual( 474 str(stripped_proto), str(stripped_alt_proto) 475 ) 476 477 # One last ditch effort, using built-in equality on 478 # protobufs 479 errs.requireEqual(stripped_proto, stripped_alt_proto) 480 481 errs.failIfErrs() 482 483 # At this point, we should have figured out why the binary 484 # protobufs differed, and short-circuited out of this code 485 # with a helpful error message. But what if we didn't? 486 # We better still try to give a good error message in this 487 # case. We EXPECT these requires to fail. If they don't, 488 # that is a bug in verify 489 errs.requireEqual(proto, alt_proto) 490 errs.requireEqual( 491 proto_bytes.getvalue(), alt_proto_bytes.getvalue() 492 ) 493 raise AssertionError 494 495 # TODO: test that the traced model also returns the same thing... 496 run_helper(torch_out, args, remained_onnx_input_idx) 497 498 # Factored out so we can avoid one run of the model 499 def run_helper(torch_out, args, remained_onnx_input_idx): 500 onnx_input = backend_args(args) 501 if remained_onnx_input_idx is not None: 502 input_onnx = [] 503 for idx in remained_onnx_input_idx: 504 input_onnx.append(onnx_input[idx]) 505 onnx_input = tuple(input_onnx) 506 backend_out = prepared.run(onnx_input) 507 if isinstance(torch_out, torch.Tensor): 508 torch_out = (torch_out,) 509 torch_out, _ = torch.jit._flatten(torch_out) 510 # NB: onnx backend NEVER returns bare numpy array 511 msg = "ONNX backend returned different results from PyTorch" 512 result_hint = ( 513 "If you are not using trained parameters, a difference in results\n" 514 "could mean that your network is numerically unstable. Otherwise\n" 515 "it indicates a bug in PyTorch/ONNX; please file a bug report." 516 ) 517 with Errors(msg, rtol=rtol, atol=atol) as errs, errs.addErrCtxt( 518 result_hint 519 ): 520 for i, (x, y) in enumerate(zip(torch_out, backend_out)): 521 errs.checkAlmostEqual(x.data.cpu().numpy(), y, f"In output {i}") 522 523 run_helper(torch_out, args, remained_onnx_input_idx) 524 525 if isinstance(test_args, int): 526 for i in range(test_args): 527 run(randomize_args(args), remained_onnx_input_idx) 528 else: 529 for test_arg in test_args: 530 run(test_arg, remained_onnx_input_idx) 531