# mypy: allow-untyped-defs
from typing import Optional, Sequence, Union

from ..scheduler import (
    BaseSchedulerNode,
    BaseScheduling,
    FusedSchedulerNode,
    Scheduler,
    SchedulerNode,
)
from .cuda.cuda_cpp_scheduling import CUDACPPScheduling
from .rocm.rocm_cpp_scheduling import ROCmCPPScheduling
from .triton import TritonScheduling


class CUDACombinedScheduling(BaseScheduling):
    """
    Scheduler for CUDA kernels, which delegates calls as appropriate to the
    CUDA C++, ROCm C++, and Triton schedulers, which all work for CUDA devices
    and use a unified wrapper for codegen.

    If scheduling code needs to be specialized for the case of mixed
    Triton / CUDA C++ code, this would also be the place to do it.
    """

    def __init__(self, scheduler: Scheduler) -> None:
        super().__init__()
        self._scheduler = scheduler
        self._triton_scheduling = TritonScheduling(scheduler)
        self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler)
        self._rocm_cpp_scheduling = ROCmCPPScheduling(scheduler)

    def get_backend_features(self, device):
        return self._triton_scheduling.get_backend_features(device)

    def choose_node_backend(self, node: BaseSchedulerNode) -> BaseScheduling:
        # Template nodes go to the matching C++ backend; everything else
        # falls through to Triton.
        if self._cuda_cpp_scheduling.is_cuda_cpp_template(node):
            return self._cuda_cpp_scheduling
        if self._rocm_cpp_scheduling.is_rocm_cpp_template(node):
            return self._rocm_cpp_scheduling
        return self._triton_scheduling

    def can_fuse_vertical(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
        if self._cuda_cpp_scheduling.can_fuse_vertical(node1, node2):
            return True
        return self._triton_scheduling.can_fuse_vertical(node1, node2)

    def can_fuse_horizontal(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
        for node in (node1, node2):
            if self._cuda_cpp_scheduling.is_cuda_cpp_template(node):
                # Always False at the moment.
                return self._cuda_cpp_scheduling.can_fuse_horizontal(node1, node2)
        return self._triton_scheduling.can_fuse_horizontal(node1, node2)

    def group_fn(self, sizes):
        return self._triton_scheduling.group_fn(sizes)

    def codegen_template(
        self,
        template_node: BaseSchedulerNode,
        epilogue_nodes: Optional[Sequence[BaseSchedulerNode]],
    ):
        # The C++ template backends do not support epilogue fusion, hence
        # the asserts below.
        if self._cuda_cpp_scheduling.is_cuda_cpp_template(template_node):
            assert epilogue_nodes is None or len(epilogue_nodes) == 0
            return self._cuda_cpp_scheduling.codegen_template(
                template_node, epilogue_nodes
            )
        elif self._rocm_cpp_scheduling.is_rocm_cpp_template(template_node):
            assert epilogue_nodes is None or len(epilogue_nodes) == 0
            return self._rocm_cpp_scheduling.codegen_template(
                template_node, epilogue_nodes
            )
        else:
            return self._triton_scheduling.codegen_template(
                template_node, epilogue_nodes
            )

    def codegen_node(self, node: Union[FusedSchedulerNode, SchedulerNode]):
        return self._triton_scheduling.codegen_node(node)

    def codegen_sync(self):
        return self._triton_scheduling.codegen_sync()

    def flush(self):
        return self._triton_scheduling.flush()

    def codegen_combo_kernel(self, *args, **kwargs):
        return self._triton_scheduling.codegen_combo_kernel(*args, **kwargs)

    def benchmark_fused_nodes(self, nodes):
        return self._triton_scheduling.benchmark_fused_nodes(nodes)

    def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False):
        return self._triton_scheduling.generate_kernel_code_from_nodes(
            nodes, benchmark_kernel
        )

    def benchmark_combo_kernel(self, node_list):
        return self._triton_scheduling.benchmark_combo_kernel(node_list)
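

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the upstream module: a minimal, standalone
# model of the delegation pattern used by CUDACombinedScheduling. Every name
# inside the guard (_DefaultBackend, _CppTemplateBackend, _CombinedBackend)
# is a hypothetical stand-in for TritonScheduling / CUDACPPScheduling /
# CUDACombinedScheduling, not a real inductor API; the point is only to show
# how template nodes are routed to a specialized backend while everything
# else falls through to a default one.
if __name__ == "__main__":

    class _DefaultBackend:
        # Stand-in for the Triton backend: handles any node.
        def codegen(self, node: str) -> str:
            return f"triton<{node}>"

    class _CppTemplateBackend(_DefaultBackend):
        # Stand-in for the CUDA C++ backend: handles only template nodes.
        def codegen(self, node: str) -> str:
            return f"cuda_cpp<{node}>"

    class _CombinedBackend:
        # Route template nodes to the C++ backend and everything else to the
        # default one, mirroring choose_node_backend() above.
        def __init__(self) -> None:
            self._default = _DefaultBackend()
            self._cpp = _CppTemplateBackend()

        def _is_cpp_template(self, node: str) -> bool:
            # The real check is CUDACPPScheduling.is_cuda_cpp_template(node);
            # a string prefix stands in for it here.
            return node.startswith("template:")

        def codegen(self, node: str) -> str:
            backend = self._cpp if self._is_cpp_template(node) else self._default
            return backend.codegen(node)

    combined = _CombinedBackend()
    print(combined.codegen("template:gemm"))  # -> cuda_cpp<template:gemm>
    print(combined.codegen("pointwise:add"))  # -> triton<pointwise:add>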