"""Selection and lookup of the module-wide tensor engine."""
import functools

# Engine instance installed by set_engine_mode(); None until it is called.
tensor_engine = None


def unsupported(func):
    """Mark ``func`` as unsupported while keeping it callable.

    The returned wrapper forwards every positional and keyword argument to
    ``func`` and carries ``is_supported = False`` so that ``is_supported()``
    reports it as unsupported.

    Args:
        func: The method to wrap; it is called with ``self`` first.

    Returns:
        A wrapper with the same name/docstring as ``func`` and an
        ``is_supported`` attribute set to ``False``.
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        return func(self, *args, **kwargs)

    # Set after functools.wraps so it wins over any attribute copied from func.
    wrapper.is_supported = False
    return wrapper


def is_supported(method):
    """Return ``method.is_supported`` if present, otherwise ``True``."""
    return getattr(method, "is_supported", True)


def set_engine_mode(mode):
    """Create the engine for ``mode`` and install it as the module engine.

    The backend module is imported only when its mode is requested, so the
    other backends do not need to be importable.

    Args:
        mode: One of ``"tf"``, ``"pt"``, ``"topi"``, ``"relay"`` or ``"nnc"``.

    Raises:
        ValueError: If ``mode` is not one of the recognised names.
    """
    global tensor_engine
    if mode == "tf":
        from . import tf_engine

        engine = tf_engine.TensorFlowEngine()
    elif mode == "pt":
        from . import pt_engine

        engine = pt_engine.TorchTensorEngine()
    elif mode == "topi":
        from . import topi_engine

        engine = topi_engine.TopiEngine()
    elif mode == "relay":
        from . import relay_engine

        engine = relay_engine.RelayEngine()
    elif mode == "nnc":
        from . import nnc_engine

        engine = nnc_engine.NncEngine()
    else:
        raise ValueError(f"invalid tensor engine mode: {mode}")
    # Tag the engine before publishing it so get_engine() never returns an
    # engine that is missing its ``mode`` attribute.
    engine.mode = mode
    tensor_engine = engine


def get_engine():
    """Return the engine installed by ``set_engine_mode``.

    Raises:
        ValueError: If ``set_engine_mode`` has not been called yet.
    """
    if tensor_engine is None:
        raise ValueError("use of get_engine, before calling set_engine_mode is illegal")
    return tensor_engine