xref: /aosp_15_r20/external/pytorch/benchmarks/fastrnns/profile.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport argparse
2*da0073e9SAndroid Build Coastguard Workerimport datetime
3*da0073e9SAndroid Build Coastguard Workerimport subprocess
4*da0073e9SAndroid Build Coastguard Workerimport sys
5*da0073e9SAndroid Build Coastguard Workerimport time
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerimport torch
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerfrom .runner import get_nn_runners
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker
def run_rnn(
    name,
    rnn_creator,
    nloops=5,
    seqLength=100,
    numLayers=1,
    inputSize=512,
    hiddenSize=512,
    miniBatch=64,
    device="cuda",
    seed=None,
):
    """Build one model definition and run ``nloops`` forward/backward iterations.

    Args:
        name: Label of the runner. Accepted but not used inside this function.
        rnn_creator: Callable invoked with the keyword arguments ``seqLength``,
            ``numLayers``, ``inputSize``, ``hiddenSize``, ``miniBatch``,
            ``device`` and ``seed``. The object it returns must expose
            ``forward``, ``inputs``, ``backward_setup``, ``backward`` and
            ``params``.
        nloops: Number of iterations to run.
        seqLength: Forwarded to ``rnn_creator``.
        numLayers: Forwarded to ``rnn_creator``.
        inputSize: Forwarded to ``rnn_creator``.
        hiddenSize: Forwarded to ``rnn_creator``.
        miniBatch: Forwarded to ``rnn_creator``.
        device: Must be ``"cuda"``; forwarded to ``rnn_creator``.
        seed: Forwarded to ``rnn_creator``.

    Raises:
        AssertionError: If ``device`` is not ``"cuda"``.
    """

    def run_iter(modeldef):
        # Forward
        forward_output = modeldef.forward(*modeldef.inputs)

        # "loss computation": backward_setup runs whenever it is provided,
        # even if there is no backward step afterwards.
        if modeldef.backward_setup is not None:
            backward_input = modeldef.backward_setup(forward_output)
        else:
            backward_input = forward_output

        if modeldef.backward is not None:
            modeldef.backward(*backward_input)

            # "Update" parameters: only the gradients are cleared.
            with torch.no_grad():
                for param in modeldef.params:
                    # A parameter that received no gradient has grad=None;
                    # skip it instead of failing on None.zero_().
                    if param.grad is not None:
                        param.grad.zero_()

        # Wait for the queued GPU work before the next iteration starts.
        torch.cuda.synchronize()

    # Checked before the model is created so a wrong device fails early.
    assert device == "cuda"
    creator_args = dict(
        seqLength=seqLength,
        numLayers=numLayers,
        inputSize=inputSize,
        hiddenSize=hiddenSize,
        miniBatch=miniBatch,
        device=device,
        seed=seed,
    )
    modeldef = rnn_creator(**creator_args)

    # Plain loop: the iterations are run for their side effects only.
    for _ in range(nloops):
        run_iter(modeldef)
56*da0073e9SAndroid Build Coastguard Worker
57*da0073e9SAndroid Build Coastguard Worker
def profile(
    rnns,
    sleep_between_seconds=1,
    nloops=5,
    internal_run=True,  # Unused, get rid of this TODO
    seqLength=100,
    numLayers=1,
    inputSize=512,
    hiddenSize=512,
    miniBatch=64,
    device="cuda",
    seed=None,
):
    """Run every requested RNN runner in turn, pausing after each one.

    Each ``(name, creator, context)`` triple produced by
    ``get_nn_runners(*rnns)`` is executed through ``run_rnn`` inside its own
    context, followed by a sleep of ``sleep_between_seconds`` seconds that
    also happens inside that context.

    Args:
        rnns: Runner names passed positionally to ``get_nn_runners``.
        sleep_between_seconds: Pause after each runner, in seconds.
        nloops: Iteration count handed to ``run_rnn``.
        internal_run: Accepted but ignored.
        seqLength: Forwarded to ``run_rnn``.
        numLayers: Forwarded to ``run_rnn``.
        inputSize: Forwarded to ``run_rnn``.
        hiddenSize: Forwarded to ``run_rnn``.
        miniBatch: Forwarded to ``run_rnn``.
        device: Forwarded to ``run_rnn``.
        seed: Forwarded to ``run_rnn``.
    """
    shared_kwargs = {
        "seqLength": seqLength,
        "numLayers": numLayers,
        "inputSize": inputSize,
        "hiddenSize": hiddenSize,
        "miniBatch": miniBatch,
        "device": device,
        "seed": seed,
    }
    for runner_name, runner_creator, runner_context in get_nn_runners(*rnns):
        with runner_context():
            run_rnn(runner_name, runner_creator, nloops, **shared_kwargs)
            time.sleep(sleep_between_seconds)
84*da0073e9SAndroid Build Coastguard Worker
85*da0073e9SAndroid Build Coastguard Worker
def system(command):
    """Run ``command`` through the shell and return (return-code, stdout, stderr).

    The command line is echoed to stdout before it runs. Both output streams
    are captured in full and decoded as UTF-8; bytes that are not valid UTF-8
    are replaced rather than raising, so unexpected tool output cannot abort
    the caller. Pure-ASCII output decodes exactly as before.

    Args:
        command: Shell command line as a single string.

    Returns:
        Tuple ``(rc, stdout_text, stderr_text)``.
    """
    print(f"[system] {command}")
    # NOTE: shell=True means the string is interpreted by the shell; callers
    # must not build ``command`` from untrusted input.
    completed = subprocess.run(
        command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True
    )
    stdout_text = completed.stdout.decode("utf-8", errors="replace")
    stderr_text = completed.stderr.decode("utf-8", errors="replace")
    return completed.returncode, stdout_text, stderr_text
97*da0073e9SAndroid Build Coastguard Worker
98*da0073e9SAndroid Build Coastguard Worker
def describe_sizes(**sizes):
    """Return a compact tag such as ``s100-l1-i512-h512-b64`` for the sizes.

    Reads the keys ``seqLength``, ``numLayers``, ``inputSize``,
    ``hiddenSize`` and ``miniBatch`` from the keyword arguments; any other
    keys are ignored. A missing key raises ``KeyError``.
    """
    # (single-letter prefix, keyword name), in the order they appear in the tag
    tag_layout = (
        ("s", "seqLength"),
        ("l", "numLayers"),
        ("i", "inputSize"),
        ("h", "hiddenSize"),
        ("b", "miniBatch"),
    )
    return "-".join(f"{prefix}{sizes[key]}" for prefix, key in tag_layout)
108*da0073e9SAndroid Build Coastguard Worker
109*da0073e9SAndroid Build Coastguard Worker
# Prefix for nvprof output paths. nvprof_output_filename() concatenates the
# file name directly onto this string, so the trailing slash is required.
# NOTE(review): nothing in this file expands the "~" in Python; presumably it
# is left for the shell that runs the nvprof command -- confirm.
OUTPUT_DIR = "~/profout/"
111*da0073e9SAndroid Build Coastguard Worker
112*da0073e9SAndroid Build Coastguard Worker
def nvprof_output_filename(rnns, **params):
    """Build the ``.nvvp`` output path for one profiling run.

    The file name joins, with underscores, the literal ``prof``, the runner
    names (dash-separated), the size tag from ``describe_sizes`` and the
    current local time formatted as ``%m%d%y-%H%M``; the result is appended
    to ``OUTPUT_DIR``.
    """
    name_parts = ["prof", "-".join(rnns), describe_sizes(**params)]
    name_parts.append(datetime.datetime.now().strftime("%m%d%y-%H%M"))
    return OUTPUT_DIR + "_".join(name_parts) + ".nvvp"
118*da0073e9SAndroid Build Coastguard Worker
119*da0073e9SAndroid Build Coastguard Worker
def nvprof(cmd, outpath):
    """Run ``cmd`` under ``nvprof -o outpath`` and return ``system``'s result."""
    wrapped_command = "nvprof -o {} {}".format(outpath, cmd)
    return system(wrapped_command)
122*da0073e9SAndroid Build Coastguard Worker
123*da0073e9SAndroid Build Coastguard Worker
def full_profile(rnns, **args):
    """Re-launch this module under nvprof with ``--internal-run`` set.

    The keyword arguments are turned back into ``--key=value`` command-line
    options for ``python -m fastrnns.profile``, the runner names are passed
    via ``--rnns``, and the whole command is executed through ``nvprof``.

    Args:
        rnns: Runner names to profile.
        **args: Remaining options, forwarded as ``--key=value``. They are
            also passed to ``nvprof_output_filename`` to name the output.

    Raises:
        RuntimeError: If the nvprof command exits with a non-zero code; the
            message carries the captured stderr and stdout.
    """
    # ``internal_run`` is a store_true flag: argparse rejects an explicit
    # value such as ``--internal_run=False``. It is appended as a bare flag
    # below, so it must not also be forwarded as key=value.
    profile_args = [
        f"--{key}={value}" for key, value in args.items() if key != "internal_run"
    ]
    profile_args.append(f"--rnns {' '.join(rnns)}")
    profile_args.append("--internal-run")

    outpath = nvprof_output_filename(rnns, **args)

    cmd = f"{sys.executable} -m fastrnns.profile {' '.join(profile_args)}"
    rc, stdout, stderr = nvprof(cmd, outpath)
    if rc != 0:
        raise RuntimeError(f"stderr: {stderr}\nstdout: {stdout}")
137*da0073e9SAndroid Build Coastguard Worker
138*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Profile RNNs")

    # Model-size options, registered in the order they appear in --help.
    for size_flag, size_default in (
        ("--seqLength", "100"),
        ("--numLayers", "1"),
        ("--inputSize", "512"),
        ("--hiddenSize", "512"),
        ("--miniBatch", "64"),
    ):
        parser.add_argument(size_flag, default=size_default, type=int)

    parser.add_argument(
        "--sleep-between-seconds", "--sleep_between_seconds", default="1", type=int
    )
    parser.add_argument("--nloops", default="5", type=int)

    parser.add_argument("--rnns", nargs="*", help="What to run. cudnn, aten, jit, etc")

    # With --internal-run the RNNs are executed in this process.
    # Without it, the script shells out to nvprof, which re-runs this module
    # with --internal-run set.
    parser.add_argument(
        "--internal-run",
        "--internal_run",
        default=False,
        action="store_true",
        help="Don't use this",
    )

    args = parser.parse_args()
    if args.rnns is None:
        args.rnns = ["cudnn", "aten", "jit"]
    print(args)

    entry_point = profile if args.internal_run else full_profile
    entry_point(**vars(args))
172