xref: /aosp_15_r20/external/pytorch/torch/utils/_triton.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import functools
3import hashlib
4
5
@functools.lru_cache(maxsize=None)
def has_triton_package() -> bool:
    """Report whether ``triton_key`` is importable from Triton's compiler.

    The answer is memoized, so the import is attempted at most once per
    process.

    Returns:
        ``True`` when ``triton.compiler.compiler.triton_key`` imports and is
        not ``None``; ``False`` when the import raises ``ImportError`` or
        ``RuntimeError``.
    """
    try:
        from triton.compiler.compiler import triton_key
    except (ImportError, RuntimeError):
        # Both failure modes are treated the same way: Triton is unusable.
        return False
    return triton_key is not None
16
17
@functools.lru_cache(maxsize=None)
def has_triton() -> bool:
    """Report whether Triton can be used with a device present on this machine.

    A device qualifies when its interface reports it as available and its
    device-specific extra check passes. The Triton package check is only
    evaluated once some device has qualified. The result is memoized.

    Returns:
        ``True`` when at least one supported device qualifies and
        ``has_triton_package()`` is true, otherwise ``False``.
    """
    from torch._dynamo.device_interface import get_interface_for_device

    def _cuda_major_at_least_7(interface):
        # Requires the device properties' ``major`` field to be >= 7
        # (presumably the CUDA compute capability major version).
        return interface.Worker.get_device_properties().major >= 7

    def _always_ok(interface):
        return True

    # Device name -> additional predicate applied to that device's interface.
    extra_checks = {"cuda": _cuda_major_at_least_7, "xpu": _always_ok}

    def _qualifies(name, check):
        interface = get_interface_for_device(name)
        return interface.is_available() and check(interface)

    # The generator keeps evaluation lazy: later devices are not queried once
    # an earlier one qualifies.
    any_device_ok = any(
        _qualifies(name, check) for name, check in extra_checks.items()
    )
    return any_device_ok and has_triton_package()
38
39
@functools.lru_cache(maxsize=None)
def triton_backend():
    """Build the Triton backend for the active driver's current target.

    The result is memoized, so the backend is constructed once per process.

    Returns:
        Whatever ``make_backend`` returns for
        ``driver.active.get_current_target()``.
    """
    from triton.compiler.compiler import make_backend
    from triton.runtime.driver import driver

    return make_backend(driver.active.get_current_target())
47
48
@functools.lru_cache(maxsize=None)
def triton_hash_with_backend():
    """Return an upper-case SHA-256 hex digest of the Triton key and backend hash.

    The digest input is ``"<triton_key()>-<backend.hash()>"`` encoded as
    UTF-8, where ``backend`` comes from ``triton_backend()``. The result is
    memoized.
    """
    from triton.compiler.compiler import triton_key

    active_backend = triton_backend()
    compiler_key = triton_key()
    backend_hash = active_backend.hash()
    fingerprint = f"{compiler_key}-{backend_hash}"

    digest = hashlib.sha256(fingerprint.encode("utf-8")).hexdigest()
    # Hash is upper case so that it can't contain any Python keywords.
    return digest.upper()
58
59
def dtype_to_string(dtype):
    """Return the ``triton.language.<name>`` string for a dtype-like object.

    Args:
        dtype: Any object exposing a string ``name`` attribute.

    Returns:
        ``"triton.language."`` followed by the name, with a leading ``fp``
        expanded to ``float`` or a leading ``bf`` expanded to ``bfloat``;
        other names are appended unchanged.
    """
    short_name = dtype.name
    # Abbreviated two-character prefix -> spelled-out replacement.
    expansions = (("fp", "float"), ("bf", "bfloat"))
    for abbreviation, expanded in expansions:
        if short_name.startswith(abbreviation):
            return "triton.language." + expanded + short_name[2:]
    return "triton.language." + short_name
68
69
def patch_triton_dtype_repr():
    """Replace ``triton.language.dtype.__repr__`` with ``dtype_to_string``.

    This mutates the Triton class in place, so every dtype instance is
    affected from the moment of the call onwards.
    """
    import triton

    def _evaluatable_repr(self):
        return dtype_to_string(self)

    # Hack to get triton dtype repr to produce an evaluatable expression
    # triton.language.float32 emits triton.language.fp32 which does not
    # exist
    # REMOVE when https://github.com/openai/triton/pull/3342 lands
    triton.language.dtype.__repr__ = _evaluatable_repr
78