xref: /aosp_15_r20/external/pytorch/benchmarks/dynamo/microbenchmarks/analyze_templates.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""
This script uses linear programming to analyze the output of Triton mm config tuning.
To generate input for this script, set the environment variable TORCHINDUCTOR_MM_LOGGING_FILE.

The resulting file can be fed into this script to compute the minimum total weighted
matmul time as a function of the number of allowed templates.
6"""
7import json
8
9import click
10import pulp
11
12
def parse_log_file(file_path):
    """Read an mm tuning log and split it into invocation counts and timings.

    The file is expected to hold a JSON list of objects. An object that has
    an ``"invoke"`` key records one invocation of the shape stored under that
    key; any other object maps shape keys to lists of benchmark timing records.

    Args:
        file_path: Path to the JSON log file.

    Returns:
        ``(occurrence_count, benchmark_logs)``: ``occurrence_count`` maps
        shape -> number of invocations, and ``benchmark_logs`` maps shape ->
        all timing records for that shape concatenated in file order. Both
        dicts keep shapes in first-seen order.
    """
    with open(file_path) as log_file:
        entries = json.load(log_file)

    invocations = {}
    timings_by_shape = {}

    for entry in entries:
        if "invoke" not in entry:
            # Benchmark record: {shape: [timing, ...], ...}
            for shape_key, timing_list in entry.items():
                timings_by_shape.setdefault(shape_key, []).extend(timing_list)
            continue
        invoked_shape = entry["invoke"]
        invocations[invoked_shape] = invocations.get(invoked_shape, 0) + 1

    return invocations, timings_by_shape
34
35
def _triton_key(timing):
    """Return the template key (block sizes, stages, warps) of a triton timing record."""
    return (
        timing["BLOCK_M"],
        timing["BLOCK_N"],
        timing["BLOCK_K"],
        timing["num_stages"],
        timing["num_warps"],
    )


def optimize_templates(N, occurrence_count, benchmark_logs, verbose=False):
    """Pick the ``N`` triton templates that minimize total weighted matmul time.

    Builds a mixed-integer program with pulp and solves it with CBC. Every
    invoked shape must use exactly one option: its fastest cuBLAS timing or
    the fastest timing of one of the allowed triton templates benchmarked for
    it. The objective is the sum over shapes of
    ``occurrence_count[shape] * selected time``.

    Args:
        N: Number of triton templates to allow. Values larger than the number
            of distinct templates found in ``benchmark_logs`` are clamped to
            that number (i.e. every template is allowed).
        occurrence_count: Mapping shape -> number of invocations.
        benchmark_logs: Mapping shape -> list of timing dicts. Each dict has
            ``"type"`` (``"cublas"`` or ``"triton"``) and ``"time"``; triton
            entries also carry ``BLOCK_M``, ``BLOCK_N``, ``BLOCK_K``,
            ``num_stages`` and ``num_warps``.
        verbose: Print the model and the solution for debugging.

    Returns:
        ``(selected_templates, total_time)``: the list of selected template
        tuples and the optimal occurrence-weighted total time.

    Raises:
        ValueError: if ``N`` is negative, or an invoked shape has neither a
            cuBLAS nor a triton timing.
        RuntimeError: if the solver does not report an optimal solution
            (e.g. ``N == 0`` while some shape only has triton timings).
    """
    if N < 0:
        raise ValueError(f"N must be non-negative, got {N}")

    # Set of all possible Triton templates keyed by their attributes
    triton_templates = {
        _triton_key(timing)
        for timings in benchmark_logs.values()
        for timing in timings
        if timing["type"] == "triton"
    }
    # "Exactly N" is infeasible when fewer than N templates were benchmarked;
    # allow all of them in that case instead of failing inside the solver.
    num_allowed = min(N, len(triton_templates))

    if verbose:
        print("Occurrence Count:", occurrence_count)
        print("Triton Templates:", triton_templates)

    # Binary variable per template: 1 if the template is in the allowed set
    template_vars = {
        template: pulp.LpVariable(f"Template_{template}", 0, 1, pulp.LpBinary)
        for template in triton_templates
    }
    # Time of the option selected for each shape (per single invocation)
    min_time_vars = pulp.LpVariable.dicts(
        "MinTime", occurrence_count.keys(), 0, None, pulp.LpContinuous
    )

    prob = pulp.LpProblem("MatrixMultiplicationOptimization", pulp.LpMinimize)

    # Objective: minimize the occurrence-weighted total time
    prob += pulp.lpSum(
        occurrence_count[shape] * min_time_vars[shape] for shape in occurrence_count
    )

    # Select exactly num_allowed templates (skipped when there are no
    # templates, which would otherwise add an empty 0 == 0 row)
    if template_vars:
        prob += pulp.lpSum(template_vars.values()) == num_allowed

    # Per-shape selection variables, keyed by (shape, "cublas") or (shape, template)
    selection_vars = {}
    # Kept per shape so the verbose report shows each shape's own values
    min_cublas_time_per_shape = {}
    triton_options_per_shape = {}

    for shape in occurrence_count:
        timings = benchmark_logs.get(shape, [])

        cublas_times = [t["time"] for t in timings if t["type"] == "cublas"]
        min_cublas_time = min(cublas_times) if cublas_times else None
        min_cublas_time_per_shape[shape] = min_cublas_time

        # Fastest measured time per triton template, gathered in one pass
        best_triton_time = {}
        for timing in timings:
            if timing["type"] != "triton":
                continue
            key = _triton_key(timing)
            if key not in best_triton_time or timing["time"] < best_triton_time[key]:
                best_triton_time[key] = timing["time"]
        triton_options = [
            (best_triton_time[template], template)
            for template in triton_templates
            if template in best_triton_time
        ]
        triton_options_per_shape[shape] = triton_options

        if min_cublas_time is None and not triton_options:
            raise ValueError(f"No cuBLAS or Triton timings recorded for shape {shape}")

        # One binary selection variable per option available for this shape
        options = []  # (time, selection variable)
        if min_cublas_time is not None:
            cublas_var = pulp.LpVariable(f"Select_{shape}_cublas", 0, 1, pulp.LpBinary)
            selection_vars[(shape, "cublas")] = cublas_var
            options.append((min_cublas_time, cublas_var))
        for triton_time, template in triton_options:
            triton_var = pulp.LpVariable(
                f"Select_{shape}_{template}", 0, 1, pulp.LpBinary
            )
            selection_vars[(shape, template)] = triton_var
            options.append((triton_time, triton_var))

        # Ensure exactly one timing option is selected for each shape
        prob += pulp.lpSum(var for _, var in options) == 1

        # Ensure min_time_vars[shape] matches the selected timing option
        prob += min_time_vars[shape] == pulp.lpSum(
            option_time * var for option_time, var in options
        )

        # A Triton template can only be selected if it is one of the allowed ones
        for _, template in triton_options:
            prob += selection_vars[(shape, template)] <= template_vars[template]

    if verbose:
        print("Constraints:")
        for constraint in prob.constraints.values():
            print(constraint)

    # Solve the problem with suppressed solver output
    status = prob.solve(pulp.PULP_CBC_CMD(msg=False))
    if status != pulp.LpStatusOptimal:
        raise RuntimeError(
            f"Template optimization for N={N} did not reach an optimal "
            f"solution (status: {pulp.LpStatus[status]})"
        )

    # Solver values are floats (e.g. 0.9999999), so threshold instead of == 1
    selected_templates = [
        template
        for template in triton_templates
        if pulp.value(template_vars[template]) > 0.5
    ]
    total_time = sum(
        pulp.value(min_time_vars[shape]) * occurrence_count[shape]
        for shape in occurrence_count
    )

    if verbose:
        print("Decision Variable Values:")
        for var in prob.variables():
            print(f"{var.name} = {var.varValue}")

        for shape in occurrence_count:
            cublas_var = selection_vars.get((shape, "cublas"))
            cublas_selected = None if cublas_var is None else pulp.value(cublas_var)
            print(f"Shape: {shape}")
            print(f"  Min Time: {pulp.value(min_time_vars[shape])}")
            print(f"  Occurrences: {occurrence_count[shape]}")
            print(
                f"  Min CuBLAS Time: {min_cublas_time_per_shape[shape]} Selected: {cublas_selected}"
            )
            for triton_time, template in triton_options_per_shape[shape]:
                print(
                    f"  Triton Template: {template} Time: {triton_time} Selected: {pulp.value(selection_vars[(shape, template)])}"
                )

    return selected_templates, total_time
196
197
# CLI entry point: parse the log once, then sweep the allowed-template count.
# Note: no docstring on purpose, since click would turn it into --help text.
@click.command()
@click.argument("filename")
@click.option("--min-templates", default=0, help="Minimum number of templates.")
@click.option("--max-templates", default=10, help="Maximum number of templates.")
@click.option("--verbose", is_flag=True, help="Enable verbose output.")
def main(filename, min_templates, max_templates, verbose):
    shape_counts, shape_timings = parse_log_file(filename)

    # totals[i] is the optimal weighted time for N = min_templates + i
    totals = []
    for num_templates in range(min_templates, max_templates + 1):
        chosen, weighted_total = optimize_templates(
            num_templates, shape_counts, shape_timings, verbose
        )
        print(
            f"N = {num_templates}\n"
            f"Selected Templates: {chosen}\n"
            f"Total Weighted Time: {weighted_total}"
        )
        totals.append(weighted_total)

    print(totals)


if __name__ == "__main__":
    main()
220