# mypy: allow-untyped-defs
import functools
import hashlib


@functools.lru_cache(None)
def has_triton_package() -> bool:
    """Return True if ``triton_key`` can be imported from triton's compiler.

    An ImportError (triton missing, or lacking that symbol) or a RuntimeError
    raised during the import is reported as False. The result is cached for
    the lifetime of the process.
    """
    try:
        from triton.compiler.compiler import triton_key
    except (ImportError, RuntimeError):
        return False
    return triton_key is not None


@functools.lru_cache(None)
def has_triton() -> bool:
    """Return True if triton is installed and a supported device is available.

    Supported devices are looked up through torch's device interfaces:
    - "cuda": must be available, and its device properties must report
      ``major >= 7``.
    - "xpu": must be available; there is no extra requirement.

    The result is cached for the lifetime of the process.
    """
    # Check the package first. The check is cheap, and without triton there
    # is no reason to import torch._dynamo or query device properties.
    if not has_triton_package():
        return False

    from torch._dynamo.device_interface import get_interface_for_device

    def cuda_extra_check(device_interface):
        return device_interface.Worker.get_device_properties().major >= 7

    def _return_true(device_interface):
        return True

    # Maps a device type name to an extra predicate on its device interface.
    triton_supported_devices = {"cuda": cuda_extra_check, "xpu": _return_true}

    def is_device_compatible_with_triton():
        for device, extra_check in triton_supported_devices.items():
            device_interface = get_interface_for_device(device)
            if device_interface.is_available() and extra_check(device_interface):
                return True
        return False

    return is_device_compatible_with_triton()


@functools.lru_cache(None)
def triton_backend():
    """Return the triton compiler backend for the active driver's current target.

    The backend is built by ``make_backend`` and cached after the first call.
    """
    from triton.compiler.compiler import make_backend
    from triton.runtime.driver import driver

    target = driver.active.get_current_target()
    return make_backend(target)


@functools.lru_cache(None)
def triton_hash_with_backend():
    """Return an upper-case hex SHA-256 digest of the triton key plus backend hash.

    The hashed string is ``f"{triton_key()}-{backend.hash()}"``, where
    ``backend`` comes from ``triton_backend()``. The result is cached.
    """
    from triton.compiler.compiler import triton_key

    backend = triton_backend()
    key = f"{triton_key()}-{backend.hash()}"

    # Hash is upper case so that it can't contain any Python keywords.
    return hashlib.sha256(key.encode("utf-8")).hexdigest().upper()


def dtype_to_string(dtype) -> str:
    """Return the ``triton.language`` attribute path for ``dtype`` as a string.

    Only ``dtype.name`` is read. A name starting with "fp" is rewritten to
    "float..." and one starting with "bf" to "bfloat...". For example,
    "fp32" gives ``"triton.language.float32"`` and "bf16" gives
    ``"triton.language.bfloat16"``. Any other name is used unchanged.
    """
    if dtype.name.startswith("fp"):
        suffix = "float" + dtype.name[2:]
    elif dtype.name.startswith("bf"):
        suffix = "bfloat" + dtype.name[2:]
    else:
        suffix = dtype.name
    return "triton.language." + suffix


def patch_triton_dtype_repr():
    """Make ``repr()`` of triton dtypes return an evaluatable expression.

    This replaces ``triton.language.dtype.__repr__`` globally with
    ``dtype_to_string``.
    """
    import triton

    # Hack: without this patch, repr(triton.language.float32) produces
    # "triton.language.fp32", and that attribute does not exist.
    # REMOVE when https://github.com/openai/triton/pull/3342 lands
    triton.language.dtype.__repr__ = dtype_to_string