"""Profile RNN implementations, optionally under nvprof.

Run as ``python -m fastrnns.profile``. Without ``--internal-run`` the script
re-invokes itself under ``nvprof``; with ``--internal-run`` it executes the
requested RNNs directly in this process.
"""

import argparse
import datetime
import subprocess
import sys
import time

import torch

from .runner import get_nn_runners


def run_rnn(
    name,
    rnn_creator,
    nloops=5,
    seqLength=100,
    numLayers=1,
    inputSize=512,
    hiddenSize=512,
    miniBatch=64,
    device="cuda",
    seed=None,
):
    """Build a model with ``rnn_creator`` and run it for ``nloops`` iterations.

    Args:
        name: Label of the runner; accepted for call-site symmetry but not
            used inside this function.
        rnn_creator: Callable invoked with the size/device/seed keyword
            arguments below; its result must expose ``forward``, ``inputs``,
            ``backward_setup``, ``backward`` and ``params``.
        nloops: Number of forward(/backward) iterations to execute.
        seqLength, numLayers, inputSize, hiddenSize, miniBatch, device, seed:
            Forwarded unchanged to ``rnn_creator``.

    Raises:
        AssertionError: If ``device`` is not ``"cuda"``.
    """

    def run_iter(modeldef):
        # Forward
        forward_output = modeldef.forward(*modeldef.inputs)

        # "loss computation" and backward
        if modeldef.backward_setup is not None:
            backward_input = modeldef.backward_setup(forward_output)
        else:
            backward_input = forward_output
        if modeldef.backward is not None:
            modeldef.backward(*backward_input)

        # "Update" parameters
        if modeldef.backward is not None:
            with torch.no_grad():
                for param in modeldef.params:
                    param.grad.zero_()
        # Wait for queued CUDA work so each iteration is fully finished
        # before the next one starts.
        torch.cuda.synchronize()

    # run_iter synchronizes through torch.cuda, so only CUDA is supported.
    assert device == "cuda"
    creator_args = dict(
        seqLength=seqLength,
        numLayers=numLayers,
        inputSize=inputSize,
        hiddenSize=hiddenSize,
        miniBatch=miniBatch,
        device=device,
        seed=seed,
    )
    modeldef = rnn_creator(**creator_args)

    # Plain loop: the iterations are run only for their side effects.
    for _ in range(nloops):
        run_iter(modeldef)


def profile(
    rnns,
    sleep_between_seconds=1,
    nloops=5,
    internal_run=True,  # Unused, get rid of this TODO
    seqLength=100,
    numLayers=1,
    inputSize=512,
    hiddenSize=512,
    miniBatch=64,
    device="cuda",
    seed=None,
):
    """Run every requested RNN in-process, sleeping after each one.

    Args:
        rnns: Runner names passed to ``get_nn_runners``.
        sleep_between_seconds: Seconds to sleep after each runner.
        nloops: Iterations per runner, forwarded to ``run_rnn``.
        internal_run: Ignored; present so ``profile(**vars(args))`` works.
        seqLength, numLayers, inputSize, hiddenSize, miniBatch, device, seed:
            Forwarded to ``run_rnn``.
    """
    params = dict(
        seqLength=seqLength,
        numLayers=numLayers,
        inputSize=inputSize,
        hiddenSize=hiddenSize,
        miniBatch=miniBatch,
        device=device,
        seed=seed,
    )
    for name, creator, context in get_nn_runners(*rnns):
        with context():
            run_rnn(name, creator, nloops, **params)
            time.sleep(sleep_between_seconds)


def system(command):
    """Run ``command`` through the shell.

    Returns (return-code, stdout, stderr), with both streams decoded to str.
    """
    print(f"[system] {command}")
    p = subprocess.Popen(
        command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True
    )
    output, err = p.communicate()
    rc = p.returncode
    # Decode leniently: a stray non-ASCII byte in the child's output must not
    # raise UnicodeDecodeError and hide the child's actual result.
    output = output.decode("utf-8", errors="replace")
    err = err.decode("utf-8", errors="replace")
    return rc, output, err


def describe_sizes(**sizes):
    """Return a compact tag encoding the problem sizes.

    Reads the ``seqLength``, ``numLayers``, ``inputSize``, ``hiddenSize`` and
    ``miniBatch`` keys; any other keyword arguments are ignored.
    """
    # seqLength, numLayers, inputSize, hiddenSize, miniBatch
    return "s{}-l{}-i{}-h{}-b{}".format(
        sizes["seqLength"],
        sizes["numLayers"],
        sizes["inputSize"],
        sizes["hiddenSize"],
        sizes["miniBatch"],
    )


# Prefix for nvprof output files. The "~" is left for the shell to expand
# (commands are run with shell=True in `system`).
OUTPUT_DIR = "~/profout/"


def nvprof_output_filename(rnns, **params):
    """Build the nvprof output path from runner names, sizes and current time."""
    rnn_tag = "-".join(rnns)
    size_tag = describe_sizes(**params)
    date_tag = datetime.datetime.now().strftime("%m%d%y-%H%M")
    return f"{OUTPUT_DIR}prof_{rnn_tag}_{size_tag}_{date_tag}.nvvp"


def nvprof(cmd, outpath):
    """Run ``cmd`` under nvprof writing to ``outpath``; returns ``system``'s tuple."""
    return system(f"nvprof -o {outpath} {cmd}")


def full_profile(rnns, **args):
    """Re-invoke this module under nvprof with ``--internal-run``.

    Args:
        rnns: Runner names, forwarded via ``--rnns``.
        **args: Remaining options; each is forwarded as ``--key=value``
            except ``internal_run`` (see below). The size keys are also used
            to name the output file.

    Raises:
        RuntimeError: If the nvprof command exits with a non-zero code; the
            message carries the captured stderr and stdout.
    """
    profile_args = []
    for k, v in args.items():
        # `internal_run` is a store_true flag: argparse rejects an explicit
        # value such as `--internal_run=False`, which made the child process
        # exit with a usage error. The flag is appended explicitly below.
        if k == "internal_run":
            continue
        profile_args.append(f"--{k}={v}")
    profile_args.append(f"--rnns {' '.join(rnns)}")
    profile_args.append("--internal-run")

    outpath = nvprof_output_filename(rnns, **args)

    cmd = f"{sys.executable} -m fastrnns.profile {' '.join(profile_args)}"
    rc, stdout, stderr = nvprof(cmd, outpath)
    if rc != 0:
        raise RuntimeError(f"stderr: {stderr}\nstdout: {stdout}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Profile RNNs")

    parser.add_argument("--seqLength", default=100, type=int)
    parser.add_argument("--numLayers", default=1, type=int)
    parser.add_argument("--inputSize", default=512, type=int)
    parser.add_argument("--hiddenSize", default=512, type=int)
    parser.add_argument("--miniBatch", default=64, type=int)
    parser.add_argument(
        "--sleep-between-seconds", "--sleep_between_seconds", default=1, type=int
    )
    parser.add_argument("--nloops", default=5, type=int)

    parser.add_argument("--rnns", nargs="*", help="What to run. cudnn, aten, jit, etc")

    # if internal_run, we actually run the rnns.
    # if not internal_run, we shell out to nvprof with internal_run=T
    parser.add_argument(
        "--internal-run",
        "--internal_run",
        default=False,
        action="store_true",
        help="Don't use this",
    )
    args = parser.parse_args()
    if args.rnns is None:
        args.rnns = ["cudnn", "aten", "jit"]
    print(args)

    if args.internal_run:
        profile(**vars(args))
    else:
        full_profile(**vars(args))