1""" 2This script uses linear programming to analyze outputs of triton mm config tuning. 3To generate output that can be fed into this script set the env varTORCHINDUCTOR_MM_LOGGING_FILE. 4 5That file can be fed into this script to generate the minimizes total, weighted matmul time as a function of allowed templates. 6""" 7import json 8 9import click 10import pulp 11 12 13def parse_log_file(file_path): 14 with open(file_path) as f: 15 logs = json.load(f) 16 17 occurrence_count = {} 18 benchmark_logs = {} 19 20 # Parse the logs 21 for entry in logs: 22 if "invoke" in entry: 23 shape = entry["invoke"] 24 if shape not in occurrence_count: 25 occurrence_count[shape] = 0 26 occurrence_count[shape] += 1 27 else: 28 for shape, timings in entry.items(): 29 if shape not in benchmark_logs: 30 benchmark_logs[shape] = [] 31 benchmark_logs[shape].extend(timings) 32 33 return occurrence_count, benchmark_logs 34 35 36def optimize_templates(N, occurrence_count, benchmark_logs, verbose=False): 37 # Set of all possible Triton templates keyed by their attributes 38 triton_templates = set() 39 for timings in benchmark_logs.values(): 40 for timing in timings: 41 if timing["type"] == "triton": 42 triton_templates.add( 43 ( 44 timing["BLOCK_M"], 45 timing["BLOCK_N"], 46 timing["BLOCK_K"], 47 timing["num_stages"], 48 timing["num_warps"], 49 ) 50 ) 51 52 # Print the initial data 53 if verbose: 54 print("Occurrence Count:", occurrence_count) 55 print("Triton Templates:", triton_templates) 56 57 # Create a dictionary to store template selection variables 58 template_vars = { 59 template: pulp.LpVariable(f"Template_{template}", 0, 1, pulp.LpBinary) 60 for template in triton_templates 61 } 62 63 # Variables to select specific timing option for each shape 64 selection_vars = { 65 (shape, "cublas"): pulp.LpVariable( 66 f"Select_{shape}_cublas", 0, 1, pulp.LpBinary 67 ) 68 for shape in occurrence_count 69 } 70 for shape in occurrence_count: 71 for template in triton_templates: 72 selection_vars[(shape, template)] = pulp.LpVariable( 73 f"Select_{shape}_{template}", 0, 1, pulp.LpBinary 74 ) 75 76 # Variables for the total time for each shape 77 min_time_vars = pulp.LpVariable.dicts( 78 "MinTime", occurrence_count.keys(), 0, None, pulp.LpContinuous 79 ) 80 81 # Define the problem 82 prob = pulp.LpProblem("MatrixMultiplicationOptimization", pulp.LpMinimize) 83 84 # Objective: Minimize the weighted total time 85 prob += pulp.lpSum( 86 [occurrence_count[shape] * min_time_vars[shape] for shape in occurrence_count] 87 ) 88 89 # Constraints to select exactly N templates 90 prob += pulp.lpSum([template_vars[template] for template in triton_templates]) == N 91 92 # Store triton options per shape for debugging 93 triton_options_per_shape = {} 94 95 # Constraints for the total time for each shape 96 for shape in occurrence_count: 97 # Get cuBLAS time 98 cublas_times = [ 99 timing["time"] 100 for timing in benchmark_logs[shape] 101 if timing["type"] == "cublas" 102 ] 103 min_cublas_time = min(cublas_times) 104 105 # Collect Triton options 106 triton_options = [] 107 for template in triton_templates: 108 triton_times = [ 109 timing["time"] 110 for timing in benchmark_logs[shape] 111 if timing["type"] == "triton" 112 and ( 113 timing["BLOCK_M"], 114 timing["BLOCK_N"], 115 timing["BLOCK_K"], 116 timing["num_stages"], 117 timing["num_warps"], 118 ) 119 == template 120 ] 121 if triton_times: 122 min_triton_time = min(triton_times) 123 triton_options.append((min_triton_time, template)) 124 125 # Save triton options for debugging 126 triton_options_per_shape[shape] = triton_options 127 128 # Ensure exactly one timing option is selected for each shape 129 prob += ( 130 pulp.lpSum( 131 [selection_vars[(shape, "cublas")]] 132 + [ 133 selection_vars[(shape, template)] 134 for triton_time, template in triton_options 135 ] 136 ) 137 == 1 138 ) 139 140 # Ensure min_time_vars[shape] matches the selected timing option 141 prob += min_time_vars[shape] == ( 142 selection_vars[(shape, "cublas")] * min_cublas_time 143 + pulp.lpSum( 144 [ 145 selection_vars[(shape, template)] * triton_time 146 for triton_time, template in triton_options 147 ] 148 ) 149 ) 150 151 # Ensure Triton templates can only be selected if they are included in the N allowed templates 152 for triton_time, template in triton_options: 153 prob += selection_vars[(shape, template)] <= template_vars[template] 154 155 # Print the constraints 156 if verbose: 157 print("Constraints:") 158 for constraint in prob.constraints.values(): 159 print(constraint) 160 161 # Solve the problem with suppressed output 162 prob.solve(pulp.PULP_CBC_CMD(msg=False)) 163 164 # Output the selected templates and their configurations 165 selected_templates = [ 166 template 167 for template in triton_templates 168 if pulp.value(template_vars[template]) == 1 169 ] 170 total_time = sum( 171 pulp.value(min_time_vars[shape]) * occurrence_count[shape] 172 for shape in occurrence_count 173 ) 174 175 # Print the values of the decision variables after solving 176 if verbose: 177 print("Decision Variable Values:") 178 for var in prob.variables(): 179 print(f"{var.name} = {var.varValue}") 180 181 # # Debugging information 182 if verbose: 183 for shape in occurrence_count: 184 print(f"Shape: {shape}") 185 print(f" Min Time: {pulp.value(min_time_vars[shape])}") 186 print(f" Occurrences: {occurrence_count[shape]}") 187 print( 188 f" Min CuBLAS Time: {min_cublas_time} Selected: {pulp.value(selection_vars[(shape, 'cublas')])}" 189 ) 190 for triton_time, template in triton_options_per_shape[shape]: 191 print( 192 f" Triton Template: {template} Time: {triton_time} Selected: {pulp.value(selection_vars[(shape, template)])}" 193 ) 194 195 return selected_templates, total_time 196 197 198# Main code to parse the log file and optimize templates 199@click.command() 200@click.argument("filename") 201@click.option("--min-templates", default=0, help="Minimum number of templates.") 202@click.option("--max-templates", default=10, help="Maximum number of templates.") 203@click.option("--verbose", is_flag=True, help="Enable verbose output.") 204def main(filename, min_templates, max_templates, verbose): 205 occurrence_count, benchmark_logs = parse_log_file(filename) 206 times = [] 207 for N in range(min_templates, max_templates + 1): 208 selected_templates, total_time = optimize_templates( 209 N, occurrence_count, benchmark_logs, verbose 210 ) 211 print(f"N = {N}") 212 print(f"Selected Templates: {selected_templates}") 213 print(f"Total Weighted Time: {total_time}") 214 times.append(total_time) 215 print(times) 216 217 218if __name__ == "__main__": 219 main() 220