import logging

import torch

from ._meta_parser import (
    OpenRegTensorData,
    receive_after_sending,
    safe_str,
    validate_send_queue_args,
)


log = logging.getLogger(__name__)
mp_context = torch.multiprocessing.get_context("spawn")

# Constant properties of our device
NUM_DEVICES = 7

# Global state of our driver
CURR_DEVICE_IDX = 0
CURR_STREAM = 0


# Our allocator
class Allocator:
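    # Simulates device memory: every allocation is backed by a plain 1D uint8
    # CPU tensor kept alive in self.allocated, keyed by its data_ptr.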
    def __init__(self):
        self.allocated = {}

    def malloc(self, size):
        new_data = torch.empty(size, dtype=torch.uint8)
        ptr = new_data.data_ptr()
        self.allocated[ptr] = new_data
        return ptr

    def free(self, ptr):
        if ptr not in self.allocated:
            return False
        else:
            del self.allocated[ptr]
            return True

    def tensor_from_meta(self, meta):
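        # Reconstruct a local view of an allocation from the received tensor
        # metadata (data_ptr, dtype, size, stride, storage_offset).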
        # Usual case, we're receiving a known Tensor
        found_base = self.allocated.get(meta.data_ptr, None)

        # Might be a rewrap of another storage at a different offset
        # Slow path to try and find the corresponding storage
        if found_base is None:
            for tag, t in self.allocated.items():
                # t is always a 1D uint8 storage!
                if meta.data_ptr > tag and meta.data_ptr < tag + t.nelement():
                    # Blame @ngimel for this
                    slice_size = t.nelement() - (meta.data_ptr - tag)
                    found_base = torch.tensor((), dtype=torch.uint8).set_(
                        t.untyped_storage()[meta.data_ptr - tag :],
                        size=(slice_size,),
                        stride=(1,),
                        storage_offset=0,
                    )

        # This pointer is not allocated here, segfault!
        if found_base is None:
            log.info("Currently allocated blocks:\n %s", safe_str(self.allocated))
            log.info("Trying to access %s", meta)
            raise RuntimeError("SEGFAULT!")

        # Raw 1d uint8 data
        raw = found_base
        # Slice the right storage part
        raw_slice = raw.narrow(0, 0, meta.nelem_in_bytes)
        # Reinterpret cast in the right dtype
        as_dtype = raw_slice.view(dtype=meta.dtype)
        # View to the right shape/stride/offset
        view = as_dtype.as_strided(meta.size, meta.stride, meta.storage_offset)
        return view


def run_op(allocator, op_name, args, kwargs):
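    # Look up the requested operator by name, replace any tensor metadata in
    # args/kwargs with tensors reconstructed from the allocator, then call it.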
    op, _ = torch._C._jit_get_operation(op_name)
    args, kwargs = receive_after_sending(allocator, args, kwargs)
    return op(*args, **kwargs)


class _Daemon:
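    # Controls a worker process, spawned lazily on first use, that holds all
    # device state and is driven through a pair of multiprocessing queues.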
    def __init__(self):
        super().__init__()
        self.is_initialized = False

    def _lazy_init(self):
        if self.is_initialized:
            return
        self.req_queue = mp_context.Queue()
        self.ans_queue = mp_context.Queue()

        self.runner = mp_context.Process(
            target=self.run_forever, args=(self.req_queue, self.ans_queue), daemon=True
        )
        self.runner.start()
        self.is_initialized = True

    def exec(self, cmd, *args):
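        # Send one command to the worker process and block until it answers;
        # raises if the worker reported an error.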
        self._lazy_init()
        log.info("Main process launched: %s(*%s)", cmd, safe_str(args))
        validate_send_queue_args(cmd, args)
        self.req_queue.put((cmd,) + args)
        res = self.ans_queue.get()
        log.info("Main process result for %s received: %s", cmd, safe_str(res))
        if res == "ERROR":
            raise RuntimeError(f"Error in daemon while executing {cmd}, see logs")
        else:
            return res

    @staticmethod
    def run_forever(req_queue, ans_queue):
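        # Entry point of the worker process: owns the Allocator and serves
        # requests from the main process until it is killed (daemon=True).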
        # Initialize our device
        global CURR_DEVICE_IDX
        empty_res = object()
        allocator = Allocator()

        # Serve all requests
        while True:
            cmd, *args = req_queue.get()
            log.info("Worker executing: %s", cmd)
            res = empty_res
            if cmd == "deviceCount":
                assert len(args) == 0
                res = NUM_DEVICES
            elif cmd == "getDevice":
                res = CURR_DEVICE_IDX
            elif cmd == "uncheckedSetDevice":
                assert len(args) == 1
                CURR_DEVICE_IDX = int(args[0])
                res = None
            elif cmd == "exchangeDevice":
                assert len(args) == 1
                res = CURR_DEVICE_IDX
                CURR_DEVICE_IDX = int(args[0])
            elif cmd == "malloc":
                res = allocator.malloc(*args)
            elif cmd == "free":
                res = allocator.free(*args)
            elif cmd == "run_op":
                op_name, args, kwargs = args
                run_op(allocator, op_name, args, kwargs)
                res = None
            elif cmd == "send_data":
                assert len(args) == 1
                res = OpenRegTensorData.from_meta(allocator, args[0])
            elif cmd == "recv_data":
                assert len(args) == 2
                host_tensor, dev_mem = args
                dev_tensor = OpenRegTensorData.from_meta(allocator, dev_mem)
                dev_tensor.copy_(host_tensor)
                res = None
            elif cmd == "get_op_output_shape":
                op_name, args, kwargs = args
                res = run_op(allocator, op_name, args, kwargs).size()
            else:
                log.warning("Unknown command in worker: %s", cmd)
                res = "ERROR"

            if res is empty_res:
                raise RuntimeError("Bad impl didn't return anything")
            log.info("Worker answering to: %s", cmd)
            ans_queue.put(res)


daemon = _Daemon()