"""Profile RNN implementations under nvprof.

Run without ``--internal-run`` to shell out to ``nvprof``, which re-invokes
this module with ``--internal-run`` so that the RNNs are actually executed
inside the profiled process.
"""
import argparse
import datetime
import subprocess
import sys
import time

import torch

from .runner import get_nn_runners


def run_rnn(
    name,
    rnn_creator,
    nloops=5,
    seqLength=100,
    numLayers=1,
    inputSize=512,
    hiddenSize=512,
    miniBatch=64,
    device="cuda",
    seed=None,
):
    """Build a model with ``rnn_creator`` and run ``nloops`` iterations of it.

    Args:
        name: Label of the runner; accepted for interface symmetry, not used.
        rnn_creator: Callable taking the size/device/seed keyword arguments
            below and returning a model definition exposing ``forward``,
            ``inputs``, ``backward_setup``, ``backward`` and ``params``.
        nloops: Number of iterations to execute.
        seqLength, numLayers, inputSize, hiddenSize, miniBatch: Size
            parameters forwarded verbatim to ``rnn_creator``.
        device: Must be ``"cuda"``.
        seed: Forwarded verbatim to ``rnn_creator``.
    """

    def run_iter(modeldef):
        # Forward
        forward_output = modeldef.forward(*modeldef.inputs)

        # "loss computation" and backward
        if modeldef.backward_setup is not None:
            backward_input = modeldef.backward_setup(forward_output)
        else:
            backward_input = forward_output
        if modeldef.backward is not None:
            modeldef.backward(*backward_input)

        # "Update" parameters
        if modeldef.backward is not None:
            with torch.no_grad():
                for param in modeldef.params:
                    # A parameter that received no gradient has grad=None;
                    # there is nothing to reset in that case.
                    if param.grad is not None:
                        param.grad.zero_()
        # Wait for all queued GPU work so each iteration is fully executed
        # before the next one starts.
        torch.cuda.synchronize()

    assert device == "cuda"
    creator_args = dict(
        seqLength=seqLength,
        numLayers=numLayers,
        inputSize=inputSize,
        hiddenSize=hiddenSize,
        miniBatch=miniBatch,
        device=device,
        seed=seed,
    )
    modeldef = rnn_creator(**creator_args)

    for _ in range(nloops):
        run_iter(modeldef)


def profile(
    rnns,
    sleep_between_seconds=1,
    nloops=5,
    internal_run=True,  # Unused, get rid of this TODO
    seqLength=100,
    numLayers=1,
    inputSize=512,
    hiddenSize=512,
    miniBatch=64,
    device="cuda",
    seed=None,
):
    """Run every requested RNN in-process, sleeping between runs.

    Args:
        rnns: Names passed to ``get_nn_runners`` to select what to run.
        sleep_between_seconds: Pause, in seconds, after each runner.
        nloops: Iterations per runner.
        internal_run: Unused; accepted so ``profile(**vars(args))`` works.
        seqLength, numLayers, inputSize, hiddenSize, miniBatch, device, seed:
            Forwarded to ``run_rnn``.
    """
    params = dict(
        seqLength=seqLength,
        numLayers=numLayers,
        inputSize=inputSize,
        hiddenSize=hiddenSize,
        miniBatch=miniBatch,
        device=device,
        seed=seed,
    )
    for name, creator, context in get_nn_runners(*rnns):
        with context():
            run_rnn(name, creator, nloops, **params)
            time.sleep(sleep_between_seconds)


def system(command):
    """Returns (return-code, stdout, stderr)"""
    print(f"[system] {command}")
    p = subprocess.Popen(
        command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True
    )
    output, err = p.communicate()
    rc = p.returncode
    # Decode leniently: a stray non-ASCII byte in the child's output must not
    # raise here and mask the actual return code / error text.
    output = output.decode("utf-8", errors="replace")
    err = err.decode("utf-8", errors="replace")
    return rc, output, err


def describe_sizes(**sizes):
    """Return a compact tag such as ``s100-l1-i512-h512-b64`` for the sizes."""
    # seqLength, numLayers, inputSize, hiddenSize, miniBatch
    return "s{}-l{}-i{}-h{}-b{}".format(
        sizes["seqLength"],
        sizes["numLayers"],
        sizes["inputSize"],
        sizes["hiddenSize"],
        sizes["miniBatch"],
    )


# The leading "~" is left for the shell to expand (commands run with shell=True).
OUTPUT_DIR = "~/profout/"


def nvprof_output_filename(rnns, **params):
    """Build the ``.nvvp`` output path from the rnn names, sizes and time."""
    rnn_tag = "-".join(rnns)
    size_tag = describe_sizes(**params)
    date_tag = datetime.datetime.now().strftime("%m%d%y-%H%M")
    return f"{OUTPUT_DIR}prof_{rnn_tag}_{size_tag}_{date_tag}.nvvp"


def nvprof(cmd, outpath):
    """Run ``cmd`` under nvprof writing to ``outpath``; see ``system``."""
    return system(f"nvprof -o {outpath} {cmd}")


def full_profile(rnns, **args):
    """Re-run this module under nvprof with ``--internal-run`` set.

    Args:
        rnns: Names of the RNNs to profile.
        **args: Remaining options, forwarded to the child as ``--key=value``.
            Must include the size keys read by ``describe_sizes``.

    Raises:
        RuntimeError: If the nvprof command exits with a non-zero status.
    """
    # ``internal_run`` is a store_true flag: passing ``--internal_run=False``
    # is rejected by argparse in the child process, so it is never forwarded
    # as key=value. The bare flag is appended exactly once below.
    profile_args = [
        f"--{key}={value}" for key, value in args.items() if key != "internal_run"
    ]
    profile_args.append(f"--rnns {' '.join(rnns)}")
    profile_args.append("--internal-run")

    outpath = nvprof_output_filename(rnns, **args)

    cmd = f"{sys.executable} -m fastrnns.profile {' '.join(profile_args)}"
    rc, stdout, stderr = nvprof(cmd, outpath)
    if rc != 0:
        raise RuntimeError(f"stderr: {stderr}\nstdout: {stdout}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Profile RNNs")

    parser.add_argument("--seqLength", default=100, type=int)
    parser.add_argument("--numLayers", default=1, type=int)
    parser.add_argument("--inputSize", default=512, type=int)
    parser.add_argument("--hiddenSize", default=512, type=int)
    parser.add_argument("--miniBatch", default=64, type=int)
    parser.add_argument(
        "--sleep-between-seconds", "--sleep_between_seconds", default=1, type=int
    )
    parser.add_argument("--nloops", default=5, type=int)

    parser.add_argument("--rnns", nargs="*", help="What to run. cudnn, aten, jit, etc")

    # if internal_run, we actually run the rnns.
    # if not internal_run, we shell out to nvprof with internal_run=T
    parser.add_argument(
        "--internal-run",
        "--internal_run",
        default=False,
        action="store_true",
        help="Don't use this",
    )
    args = parser.parse_args()
    if args.rnns is None:
        args.rnns = ["cudnn", "aten", "jit"]
    print(args)

    if args.internal_run:
        profile(**vars(args))
    else:
        full_profile(**vars(args))