# mypy: allow-untyped-defs
from typing import Optional, Sequence, Union

from ..scheduler import (
    BaseSchedulerNode,
    BaseScheduling,
    FusedSchedulerNode,
    Scheduler,
    SchedulerNode,
)
from .cuda.cuda_cpp_scheduling import CUDACPPScheduling
from .rocm.rocm_cpp_scheduling import ROCmCPPScheduling
from .triton import TritonScheduling


class CUDACombinedScheduling(BaseScheduling):
    """
    Scheduler for CUDA kernels, which delegates calls as appropriate to the
    CUDA C++, ROCm C++, and Triton schedulers, all of which work for CUDA
    devices and use a unified wrapper for codegen.

    If scheduling code needs to be specialized for the case of mixed
    Triton / CUDA C++ code, this would also be the place to do it.

    (A purely illustrative sketch of the dispatch pattern is included as a
    comment at the end of this file.)
    """

    def __init__(self, scheduler: Scheduler) -> None:
        super().__init__()
        self._scheduler = scheduler
        self._triton_scheduling = TritonScheduling(scheduler)
        self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler)
        self._rocm_cpp_scheduling = ROCmCPPScheduling(scheduler)

    def get_backend_features(self, device):
        return self._triton_scheduling.get_backend_features(device)

    def choose_node_backend(self, node: BaseSchedulerNode) -> BaseScheduling:
        # C++ template nodes take precedence (CUDA first, then ROCm);
        # everything else is handled by the Triton scheduler.
        if self._cuda_cpp_scheduling.is_cuda_cpp_template(node):
            return self._cuda_cpp_scheduling
        if self._rocm_cpp_scheduling.is_rocm_cpp_template(node):
            return self._rocm_cpp_scheduling
        return self._triton_scheduling

    def can_fuse_vertical(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
        if self._cuda_cpp_scheduling.can_fuse_vertical(node1, node2):
            return True
        return self._triton_scheduling.can_fuse_vertical(node1, node2)

    def can_fuse_horizontal(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
        for node in (node1, node2):
            if self._cuda_cpp_scheduling.is_cuda_cpp_template(node):
                return self._cuda_cpp_scheduling.can_fuse_horizontal(
                    node1, node2
                )  # always False at the moment
        return self._triton_scheduling.can_fuse_horizontal(node1, node2)

    def group_fn(self, sizes):
        return self._triton_scheduling.group_fn(sizes)

    def codegen_template(
        self,
        template_node: BaseSchedulerNode,
        epilogue_nodes: Optional[Sequence[BaseSchedulerNode]],
    ):
        if self._cuda_cpp_scheduling.is_cuda_cpp_template(template_node):
            assert epilogue_nodes is None or len(epilogue_nodes) == 0
            return self._cuda_cpp_scheduling.codegen_template(
                template_node, epilogue_nodes
            )
        elif self._rocm_cpp_scheduling.is_rocm_cpp_template(template_node):
            assert epilogue_nodes is None or len(epilogue_nodes) == 0
            return self._rocm_cpp_scheduling.codegen_template(
                template_node, epilogue_nodes
            )
        else:
            return self._triton_scheduling.codegen_template(
                template_node, epilogue_nodes
            )

    def codegen_node(self, node: Union[FusedSchedulerNode, SchedulerNode]):
        return self._triton_scheduling.codegen_node(node)

    def codegen_sync(self):
        return self._triton_scheduling.codegen_sync()

    def flush(self):
        return self._triton_scheduling.flush()

    def codegen_combo_kernel(self, *args, **kwargs):
        return self._triton_scheduling.codegen_combo_kernel(*args, **kwargs)

    def benchmark_fused_nodes(self, nodes):
        return self._triton_scheduling.benchmark_fused_nodes(nodes)

    def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False):
        return self._triton_scheduling.generate_kernel_code_from_nodes(
            nodes, benchmark_kernel
        )

    def benchmark_combo_kernel(self, node_list):
        return self._triton_scheduling.benchmark_combo_kernel(node_list)
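
# A purely illustrative sketch, referenced from the class docstring above. It is
# not used by this module: it mirrors the backend-dispatch pattern of
# CUDACombinedScheduling.choose_node_backend with hypothetical stand-in classes.
# The names _Node, _TemplateBackend, _DefaultBackend, and choose_backend are
# assumptions for illustration only, not Inductor APIs.
#
#     class _Node:
#         def __init__(self, kind: str) -> None:
#             self.kind = kind  # e.g. "cuda_cpp_template", "rocm_cpp_template", "pointwise"
#
#     class _TemplateBackend:
#         def __init__(self, kind: str) -> None:
#             self.kind = kind
#
#         def handles(self, node: _Node) -> bool:
#             # A template backend only claims nodes of its own template kind.
#             return node.kind == self.kind
#
#     class _DefaultBackend:
#         def handles(self, node: _Node) -> bool:
#             # The default backend (Triton in the real class) accepts everything.
#             return True
#
#     def choose_backend(node, specialized_backends, default_backend):
#         # Specialized template backends are tried in order (CUDA C++ before
#         # ROCm C++ in the real class); the first one that claims the node
#         # wins, otherwise fall through to the default backend.
#         for backend in specialized_backends:
#             if backend.handles(node):
#                 return backend
#         return default_backend
#
#     backend = choose_backend(
#         _Node("pointwise"),
#         [_TemplateBackend("cuda_cpp_template"), _TemplateBackend("rocm_cpp_template")],
#         _DefaultBackend(),
#     )  # -> the _DefaultBackend instance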