import logging

import torch

from ._meta_parser import (
    OpenRegTensorData,
    receive_after_sending,
    safe_str,
    validate_send_queue_args,
)


log = logging.getLogger(__name__)
mp_context = torch.multiprocessing.get_context("spawn")

# Constant properties of our device
NUM_DEVICES = 7

# Global state of our driver
CURR_DEVICE_IDX = 0
CURR_STREAM = 0


# Our allocator
class Allocator:
    def __init__(self):
        self.allocated = {}

    def malloc(self, size):
        new_data = torch.empty(size, dtype=torch.uint8)
        ptr = new_data.data_ptr()
        self.allocated[ptr] = new_data
        return ptr

    def free(self, ptr):
        if ptr not in self.allocated:
            return False
        else:
            del self.allocated[ptr]
            return True

    def tensor_from_meta(self, meta):
        # Usual case, we're receiving a known Tensor
        found_base = self.allocated.get(meta.data_ptr, None)

        # Might be a rewrap of another storage at a different offset
        # Slow path to try and find the corresponding storage
        if found_base is None:
            for tag, t in self.allocated.items():
                # t is always a 1D uint8 storage!
                if meta.data_ptr > tag and meta.data_ptr < tag + t.nelement():
                    # Blame @ngimel for this
                    slice_size = t.nelement() - (meta.data_ptr - tag)
                    found_base = torch.tensor((), dtype=torch.uint8).set_(
                        t.untyped_storage()[meta.data_ptr - tag :],
                        size=(slice_size,),
                        stride=(1,),
                        storage_offset=0,
                    )

        # This pointer is not allocated here, segfault !
        if found_base is None:
            log.info("Currently allocated blocks:\n %s", safe_str(self.allocated))
            log.info("Trying to access %s", meta)
            raise RuntimeError("SEGFAULT!")

        # Raw 1d uint8 data
        raw = found_base
        # Slice the right storage part
        raw_slice = raw.narrow(0, 0, meta.nelem_in_bytes)
        # Reinterpret cast in the right dtype
        as_dtype = raw_slice.view(dtype=meta.dtype)
        # View to the right shape/stride/offset
        view = as_dtype.as_strided(meta.size, meta.stride, meta.storage_offset)
        return view


def run_op(allocator, op_name, args, kwargs):
    op, _ = torch._C._jit_get_operation(op_name)
    args, kwargs = receive_after_sending(allocator, args, kwargs)
    return op(*args, **kwargs)


class _Daemon:
    def __init__(self):
        super().__init__()
        self.is_initialized = False

    def _lazy_init(self):
        if self.is_initialized:
            return
        self.req_queue = mp_context.Queue()
        self.ans_queue = mp_context.Queue()

        self.runner = mp_context.Process(
            target=self.run_forever, args=(self.req_queue, self.ans_queue), daemon=True
        )
        self.runner.start()
        self.is_initialized = True

    def exec(self, cmd, *args):
        self._lazy_init()
        log.info("Main process launched: %s(*%s)", cmd, safe_str(args))
        validate_send_queue_args(cmd, args)
        self.req_queue.put((cmd,) + args)
        res = self.ans_queue.get()
        log.info("Main process result for %s received: %s", cmd, safe_str(res))
        if res == "ERROR":
            raise RuntimeError(f"Error in daemon while executing {cmd}, see logs")
        else:
            return res

    @staticmethod
    def run_forever(req_queue, ans_queue):
        # Initialize our device
        global CURR_DEVICE_IDX
        empty_res = object()
        allocator = Allocator()

        # Serve all requests
        while True:
            cmd, *args = req_queue.get()
            log.info("Worker executing: %s", cmd)
            res = empty_res
            if cmd == "deviceCount":
                assert len(args) == 0
                res = NUM_DEVICES
            elif cmd == "getDevice":
                res = CURR_DEVICE_IDX
            elif cmd == "uncheckedSetDevice":
                assert len(args) == 1
                CURR_DEVICE_IDX = int(args[0])
                res = None
            elif cmd == "exchangeDevice":
                assert len(args) == 1
                res = CURR_DEVICE_IDX
                CURR_DEVICE_IDX = int(args[0])
            elif cmd == "malloc":
                res = allocator.malloc(*args)
            elif cmd == "free":
                res = allocator.free(*args)
            elif cmd == "run_op":
                op_name, args, kwargs = args
                run_op(allocator, op_name, args, kwargs)
                res = None
            elif cmd == "send_data":
                assert len(args) == 1
                res = OpenRegTensorData.from_meta(allocator, args[0])
            elif cmd == "recv_data":
                assert len(args) == 2
                host_tensor, dev_mem = args
                dev_tensor = OpenRegTensorData.from_meta(allocator, dev_mem)
                dev_tensor.copy_(host_tensor)
                res = None
            elif cmd == "get_op_output_shape":
                op_name, args, kwargs = args
                res = run_op(allocator, op_name, args, kwargs).size()
            else:
                log.warning("Bad command in worker")
                res = "ERROR"

            if res == empty_res:
                raise RuntimeError("Bad impl didn't return anything")
            log.info("Worker answering to: %s", cmd)
            ans_queue.put(res)


daemon = _Daemon()