import argparse

import torch
import torch.nn as nn

from .factory import pytorch_lstm_creator, varlen_pytorch_lstm_creator
from .runner import get_nn_runners


def barf():
    # Drop into the debugger at the point of failure.
    import pdb

    pdb.set_trace()


def assertEqual(tensor, expected, threshold=0.001):
    # Recursively compare tensors (or nested lists/tuples of tensors)
    # within an absolute threshold.
    if isinstance(tensor, (list, tuple)):
        for t, e in zip(tensor, expected):
            assertEqual(t, e)
    else:
        if (tensor - expected).abs().max() > threshold:
            barf()


def filter_requires_grad(tensors):
    return [t for t in tensors if t.requires_grad]


def test_rnns(
    experim_creator,
    control_creator,
    check_grad=True,
    verbose=False,
    seqLength=100,
    numLayers=1,
    inputSize=512,
    hiddenSize=512,
    miniBatch=64,
    device="cuda",
    seed=17,
):
    creator_args = dict(
        seqLength=seqLength,
        numLayers=numLayers,
        inputSize=inputSize,
        hiddenSize=hiddenSize,
        miniBatch=miniBatch,
        device=device,
        seed=seed,
    )

    print("Setting up...")
    control = control_creator(**creator_args)
    experim = experim_creator(**creator_args)

    # Precondition: both runners must start from identical inputs and params.
    assertEqual(experim.inputs, control.inputs)
    assertEqual(experim.params, control.params)

    print("Checking outputs...")
    control_outputs = control.forward(*control.inputs)
    experim_outputs = experim.forward(*experim.inputs)
    assertEqual(experim_outputs, control_outputs)

    print("Checking grads...")
    # NOTE: `check_grad` is accepted so callers can forward the full CLI arg
    # dict, but the gradient check below always runs.
    assert control.backward_setup is not None
    assert experim.backward_setup is not None
    assert control.backward is not None
    assert experim.backward is not None
    control_backward_inputs = control.backward_setup(control_outputs, seed)
    experim_backward_inputs = experim.backward_setup(experim_outputs, seed)

    control.backward(*control_backward_inputs)
    experim.backward(*experim_backward_inputs)

    control_grads = [p.grad for p in control.params]
    experim_grads = [p.grad for p in experim.params]
    assertEqual(experim_grads, control_grads)

    if verbose:
        print(experim.forward.graph_for(*experim.inputs))
    print()


def test_vl_py(**test_args):
    # XXX: This compares vl_py with vl_lstm.
    # It's done this way because those two don't give the same outputs so
    # the result isn't an apples-to-apples comparison right now.
    control_creator = varlen_pytorch_lstm_creator
    name, experim_creator, context = get_nn_runners("vl_py")[0]
    with context():
        print(f"testing {name}...")
        creator_keys = [
            "seqLength",
            "numLayers",
            "inputSize",
            "hiddenSize",
            "miniBatch",
            "device",
            "seed",
        ]
        creator_args = {key: test_args[key] for key in creator_keys}

        print("Setting up...")
        control = control_creator(**creator_args)
        experim = experim_creator(**creator_args)

        # Precondition: the varlen python runner takes only the first two
        # control inputs.
        assertEqual(experim.inputs, control.inputs[:2])
        assertEqual(experim.params, control.params)

        print("Checking outputs...")
        control_out, control_hiddens = control.forward(*control.inputs)
        control_hx, control_cx = control_hiddens
        experim_out, experim_hiddens = experim.forward(*experim.inputs)
        experim_hx, experim_cx = experim_hiddens

        # vl_py returns per-sequence outputs; pad them to compare against the
        # control's padded output tensor.
        experim_padded = nn.utils.rnn.pad_sequence(experim_out).squeeze(-2)
        assertEqual(experim_padded, control_out)
        assertEqual(torch.cat(experim_hx, dim=1), control_hx)
        assertEqual(torch.cat(experim_cx, dim=1), control_cx)

        print("Checking grads...")
        assert control.backward_setup is not None
        assert experim.backward_setup is not None
        assert control.backward is not None
        assert experim.backward is not None
        control_backward_inputs = control.backward_setup(
            (control_out, control_hiddens), test_args["seed"]
        )
        experim_backward_inputs = experim.backward_setup(
            (experim_out, experim_hiddens), test_args["seed"]
        )

        control.backward(*control_backward_inputs)
        experim.backward(*experim_backward_inputs)

        control_grads = [p.grad for p in control.params]
        experim_grads = [p.grad for p in experim.params]
        assertEqual(experim_grads, control_grads)

        if test_args["verbose"]:
            print(experim.forward.graph_for(*experim.inputs))
        print()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test lstm correctness")

    parser.add_argument("--seqLength", default=100, type=int)
    parser.add_argument("--numLayers", default=1, type=int)
    parser.add_argument("--inputSize", default=512, type=int)
    parser.add_argument("--hiddenSize", default=512, type=int)
    parser.add_argument("--miniBatch", default=64, type=int)
    parser.add_argument("--device", default="cuda", type=str)
    # `type=bool` would treat any non-empty string (including "False") as
    # True, so parse the flag value explicitly.
    parser.add_argument(
        "--check-grad",
        "--check_grad",
        default=True,
        type=lambda s: s.lower() in ("true", "1", "yes"),
    )
    parser.add_argument("--variable-lstms", "--variable_lstms", action="store_true")
    parser.add_argument("--seed", default=17, type=int)
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--rnns", nargs="*", help="What to run. jit_premul, jit, etc")
    args = parser.parse_args()
    if args.rnns is None:
        args.rnns = ["jit_premul", "jit"]
    print(args)

    if "cuda" in args.device:
        assert torch.cuda.is_available()

    rnn_runners = get_nn_runners(*args.rnns)

    should_test_varlen_lstms = args.variable_lstms
    test_args = vars(args)
    del test_args["rnns"]
    del test_args["variable_lstms"]

    if should_test_varlen_lstms:
        test_vl_py(**test_args)

    for name, creator, context in rnn_runners:
        with context():
            print(f"testing {name}...")
            test_rnns(creator, pytorch_lstm_creator, **test_args)
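
# Example invocation (a sketch, not part of the test itself): the relative
# imports above mean this file must be run as a module inside its package.
# Assuming the package is named `fastrnns`, something like
#
#   python -m fastrnns.test --rnns jit --device cpu
#
# would check the `jit` runner against the `pytorch_lstm_creator` control
# without requiring a CUDA device.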