# mypy: allow-untyped-defs
import sys
import types
from typing import List, Optional

import torch


# Mapping between quantized-engine ids and their user-facing names.
# This table should correspond to the enums present in c10/core/QEngine.h.
# Both _get_qengine_id and _get_qengine_str are derived from it so the two
# directions of the mapping cannot drift apart.
_QENGINE_ID_TO_STR = {0: "none", 1: "fbgemm", 2: "qnnpack", 3: "onednn", 4: "x86"}
_QENGINE_STR_TO_ID = {name: qid for qid, name in _QENGINE_ID_TO_STR.items()}


# This function should correspond to the enums present in c10/core/QEngine.h
def _get_qengine_id(qengine: Optional[str]) -> int:
    """Return the integer id of the quantized engine named ``qengine``.

    ``"none"``, the empty string, and ``None`` all map to ``0``.

    Raises:
        RuntimeError: if ``qengine`` is not a recognized engine name.
    """
    # "" and None are accepted as aliases for "none".
    if qengine is None or qengine == "":
        return 0
    try:
        return _QENGINE_STR_TO_ID[qengine]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs so they are rejected with the
        # same RuntimeError as unknown names.
        raise RuntimeError(
            f"{qengine} is not a valid value for quantized engine"
        ) from None


# This function should correspond to the enums present in c10/core/QEngine.h
def _get_qengine_str(qengine: int) -> str:
    """Return the name of the quantized engine with id ``qengine``.

    Unknown ids map to the placeholder string ``"*undefined"`` instead of
    raising.
    """
    return _QENGINE_ID_TO_STR.get(qengine, "*undefined")


class _QEngineProp:
    """Data descriptor exposing the active quantized engine as a string.

    Reading returns the name of the engine id reported by
    ``torch._C._get_qengine()``. Assigning a name converts it with
    ``_get_qengine_id`` and passes the id to ``torch._C._set_qengine``.
    """

    def __get__(self, obj, objtype) -> str:
        return _get_qengine_str(torch._C._get_qengine())

    def __set__(self, obj, val: str) -> None:
        # _get_qengine_id raises RuntimeError for unrecognized names.
        torch._C._set_qengine(_get_qengine_id(val))


class _SupportedQEnginesProp:
    """Read-only data descriptor listing the supported quantized engines by name."""

    def __get__(self, obj, objtype) -> List[str]:
        qengines = torch._C._supported_qengines()
        return [_get_qengine_str(qe) for qe in qengines]

    def __set__(self, obj, val) -> None:
        # Defining __set__ makes this a data descriptor. Assignment on the
        # module object is therefore intercepted here rather than storing a
        # value that would shadow the attribute.
        raise RuntimeError("Assignment not supported")


class QuantizedEngine(types.ModuleType):
    """Module wrapper that adds ``engine`` and ``supported_engines`` attributes.

    An instance is installed in ``sys.modules`` in place of this module.
    Attributes that the wrapper does not define itself are looked up on the
    wrapped module ``m``.
    """

    def __init__(self, m, name):
        super().__init__(name)
        self.m = m

    def __getattr__(self, attr):
        # Only invoked when normal attribute lookup on the wrapper fails.
        return self.m.__getattribute__(attr)

    engine = _QEngineProp()
    supported_engines = _SupportedQEnginesProp()


# This is the sys.modules replacement trick, see
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
sys.modules[__name__] = QuantizedEngine(sys.modules[__name__], __name__)
# Annotations for static type checkers only. At runtime these names are
# served by the descriptors defined on QuantizedEngine.
engine: str
supported_engines: List[str]