import logging

import torch
from torch.utils._pytree import tree_any


log = logging.getLogger(__name__)

from ._device_daemon import daemon
from ._meta_parser import prepare_for_sending, to_device_no_copy


_IMPL_REGISTRY = {}


# Define all the implementations in the registry
def _register_same_name(name, with_log=False):
    def _(*args, **kwargs):
        if with_log:
            log.info("Calling hook %s", name)
        return daemon.exec(name, *args, **kwargs)

    _IMPL_REGISTRY[name] = _


_register_same_name("deviceCount")
_register_same_name("getDevice")
_register_same_name("uncheckedSetDevice")
_register_same_name("exchangeDevice")
_register_same_name("malloc", True)
_register_same_name("free", True)

_openreg_lib = torch.library.Library("_", "IMPL")


def _openreg_kernel_fallback(op, *args, **kwargs):
    log.info("Calling kernel %s", op)

    # Special ops needed to avoid infinite recursion
    if op is torch.ops.aten._copy_from.default:
        from_, to_ = args
        if from_.device.type == to_.device.type:
            assert from_.device.type == "openreg"
            op = torch.ops.aten.copy_.default
            # handled below as a regular copy
        elif from_.device.type == "openreg":
            args, _ = prepare_for_sending((from_,), {})
            host_mem = daemon.exec("send_data", *args)
            return to_.copy_(host_mem)
        elif to_.device.type == "openreg":
            args, _ = prepare_for_sending((to_,), {})
            daemon.exec("recv_data", from_, *args)
            return to_
        else:
            raise RuntimeError("Should not happen")
    elif op is torch.ops.aten.set_.source_Tensor:
        return torch.ops.aten.set_.source_Storage_storage_offset(
            args[0],
            args[1].untyped_storage(),
            args[1].storage_offset(),
            args[1].size(),
            args[1].stride(),
        )
    elif op is torch.ops.aten._local_scalar_dense.default:
        args, _ = prepare_for_sending(args, {})
        host_mem = daemon.exec("send_data", *args)
        return host_mem.item()

    op_name = None
    post_process = None
    if "out" in op._overloadname:
        # Note that all structured native ops will land here
        if isinstance(kwargs["out"], tuple):
            raise RuntimeError(f"out= variant {op} with tuple out= not supported")
        if kwargs["out"].nelement() == 0:
            # Out variant that needs a resize: convert it to an out-of-place op
            # and handle it generically below
            orig_out = kwargs["out"]
            del kwargs["out"]
            if op._overloadname != "out":
                raise RuntimeError(
                    "Cannot retranslate non-default out= variant from 0 size"
                )
            op = op.overloadpacket.default

            def _post_process():
                nonlocal real_res
                orig_out.set_(real_res)
                real_res = orig_out

            post_process = _post_process

        else:
            # No metadata update to do, just run the op on the device
            op_name = op.overloadpacket._qualified_op_name
            real_res = kwargs["out"]
    elif not tree_any(lambda obj: isinstance(obj, torch.Tensor), (args, kwargs)):
        # No Tensor argument means factory function
        # These should decompose and be handled on our C++ side directly
        raise RuntimeError(f"{op} not handled yet.")
    elif op._schema.is_mutable or op is torch.ops.aten._copy_from.default:
        # Only handle inplace ops returning their first arg
        assert len(args) >= 1, f"Inplace {op} needs at least one arg"
        assert (
            len(op._schema.returns) == 1
        ), f"NYI Inplace {op} with more than one return"
        op_name = op.overloadpacket._qualified_op_name
        real_res = args[0]
    elif any(r.alias_info is not None for r in op._schema.returns):
        # View ops
        if op is torch.ops.aten.view.default:
            return torch.ops.aten._unsafe_view(*args, **kwargs)
        raise RuntimeError(f"{op} view op is not handled yet")

    if op_name is None:
        # 1. Compute updated metadata
        if torch.Tag.dynamic_output_shape not in op.tags:
            # Usual case: run the meta op to see the output metadata
            meta_args, meta_kwargs = to_device_no_copy("meta", args, kwargs)
            meta_res = op(*meta_args, **meta_kwargs)

            # 2. Allocate the output
            real_res, _ = to_device_no_copy("openreg", meta_res, {})
        else:
            # Slow version for data-dependent functions:
            # Run the op on the device just to get the output shape
            args_, kwargs_ = prepare_for_sending(args, kwargs)
            shape = daemon.exec(
                "get_op_output_shape",
                op.overloadpacket._qualified_op_name,
                args_,
                kwargs_,
            )

            # 2. Allocate the output
            real_res = args[0].new(shape)

        # 3. Move to out variant
        kwargs["out"] = real_res
        # Let overload resolution find the out= overload
        op_name = op.overloadpacket._qualified_op_name

    # 4. Run the compute and populate the output on the device
    args, kwargs = prepare_for_sending(args, kwargs)
    daemon.exec("run_op", op_name, args, kwargs)

    if post_process is not None:
        post_process()

    return real_res


_openreg_lib.fallback(_openreg_kernel_fallback, dispatch_key="PrivateUse1")
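
# Usage sketch (comment only, not executed; package name and setup are assumptions,
# the actual registration lives elsewhere in this extension): once the surrounding
# extension is imported and PrivateUse1 is renamed to "openreg", any ATen op on an
# "openreg" tensor without a dedicated kernel funnels through
# _openreg_kernel_fallback above. Roughly:
#
#   import torch
#   import pytorch_openreg  # hypothetical package name for this extension
#
#   x = torch.ones(4, device="openreg")  # allocation goes through the "malloc" hook
#   y = (x + x).cpu()                    # compute is forwarded via daemon.exec("run_op", ...)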