1import logging
2
3import torch
4from torch.utils._pytree import tree_any
5
6
7log = logging.getLogger(__name__)
8
9from ._device_daemon import daemon
10from ._meta_parser import prepare_for_sending, to_device_no_copy
11
12
# Maps a hook name to the Python callable that forwards that hook to the
# device daemon. Populated below by _register_same_name().
_IMPL_REGISTRY = {}

15
16# Define all the implementations in the registry
def _register_same_name(name, with_log=False):
    """Register a hook that forwards its call verbatim to the device daemon.

    The registered callable is stored in ``_IMPL_REGISTRY`` under ``name``
    and executes the daemon command of the same name.

    Args:
        name: Hook name; used both as the registry key and the daemon command.
        with_log: When True, emit an info-level log line on every invocation.
    """

    def _forward(*args, **kwargs):
        if with_log:
            log.info("Calling hook %s", name)
        return daemon.exec(name, *args, **kwargs)

    _IMPL_REGISTRY[name] = _forward
24
25
# Register every daemon-backed hook; the allocator hooks (malloc/free)
# additionally log each call.
for _name, _verbose in (
    ("deviceCount", False),
    ("getDevice", False),
    ("uncheckedSetDevice", False),
    ("exchangeDevice", False),
    ("malloc", True),
    ("free", True),
):
    _register_same_name(_name, _verbose)
del _name, _verbose

# Library handle used to install the PrivateUse1 fallback kernel below.
_openreg_lib = torch.library.Library("_", "IMPL")
34
35
def _openreg_kernel_fallback(op, *args, **kwargs):
    """PrivateUse1 fallback kernel: execute ``op`` via the openreg daemon.

    Strategy:
      1. Handle the few special ops (cross-device ``_copy_from``,
         ``set_.source_Tensor``, ``_local_scalar_dense``) that would
         otherwise recurse back into this fallback.
      2. Normalize everything else into an out= call: compute the output
         metadata (meta device for static shapes, a daemon round-trip for
         data-dependent shapes), allocate the result on "openreg", then ask
         the daemon to run the op and fill the output.

    Args:
        op: The ``OpOverload`` being dispatched.
        *args: Positional operator arguments (tensors live on "openreg").
        **kwargs: Keyword operator arguments.

    Returns:
        The result tensor(s), or a Python scalar for ``_local_scalar_dense``.

    Raises:
        RuntimeError: For unsupported cases (tuple out=, factory functions,
            unhandled view ops, non-default out= overloads with 0-size out).
    """
    log.info("Calling kernel %s", op)

    # Special ops needed to avoid infinite recursion
    if op is torch.ops.aten._copy_from.default:
        from_, to_ = args
        if from_.device.type == to_.device.type:
            assert from_.device.type == "openreg"
            op = torch.ops.aten.copy_.default
            # handled below as a regular copy
        elif from_.device.type == "openreg":
            # Device -> host: pull the bytes from the daemon, then copy them
            # into the destination tensor on the host side.
            args, _ = prepare_for_sending((from_,), {})
            host_mem = daemon.exec("send_data", *args)
            return to_.copy_(host_mem)
        elif to_.device.type == "openreg":
            # Host -> device: push the host tensor into the daemon's storage.
            args, _ = prepare_for_sending((to_,), {})
            daemon.exec("recv_data", from_, *args)
            return to_
        else:
            raise RuntimeError("Should not happen")
    elif op is torch.ops.aten.set_.source_Tensor:
        # Re-express set_(Tensor) via the storage overload so that it does
        # not bounce back into this fallback.
        return torch.ops.aten.set_.source_Storage_storage_offset(
            args[0],
            args[1].untyped_storage(),
            args[1].storage_offset(),
            args[1].size(),
            args[1].stride(),
        )
    elif op is torch.ops.aten._local_scalar_dense.default:
        # .item(): fetch the data onto the host and extract the scalar there.
        args, _ = prepare_for_sending(args, {})
        host_mem = daemon.exec("send_data", *args)
        return host_mem.item()

    op_name = None
    post_process = None
    if "out" in op._overloadname:
        # Note that all structured native op will call here
        if isinstance(kwargs["out"], tuple):
            raise RuntimeError(f"out= variant {op} with tuple out= not supported")
        if kwargs["out"].nelement() == 0:
            # Out variant that needs a resize, convert to an out of place
            # and handle generically below
            orig_out = kwargs["out"]
            del kwargs["out"]
            if op._overloadname != "out":
                # BUGFIX: message previously read "variant form 0 size".
                raise RuntimeError(
                    "Cannot retranslate non-default out= variant from 0 size"
                )
            op = op.overloadpacket.default

            def _post_process():
                # Propagate the freshly allocated result into the caller's
                # original out= tensor so its aliases observe the data.
                nonlocal real_res
                orig_out.set_(real_res)
                real_res = orig_out

            post_process = _post_process

        else:
            # No metadata update to do, just run the op on the device
            op_name = op.overloadpacket._qualified_op_name
            real_res = kwargs["out"]
    elif not tree_any(lambda obj: isinstance(obj, torch.Tensor), (args, kwargs)):
        # No Tensor argument means factory function
        # They should decompose and be handled in our c++ side directly
        raise RuntimeError(f"{op} not handled yet.")
    elif op._schema.is_mutable or op is torch.ops.aten._copy_from.default:
        # Only handle inplace ops returning their first arg
        assert len(args) >= 1, f"Inplace {op} needs at least one arg"
        assert (
            len(op._schema.returns) == 1
        ), f"NYI Inplace {op} with more than one return"
        op_name = op.overloadpacket._qualified_op_name
        real_res = args[0]
    elif any(r.alias_info is not None for r in op._schema.returns):
        # View ops
        if op is torch.ops.aten.view.default:
            return torch.ops.aten._unsafe_view(*args, **kwargs)
        raise RuntimeError(f"{op} view op is not handled yet")

    if op_name is None:
        # 1. Compute updated metadata
        if torch.Tag.dynamic_output_shape not in op.tags:
            # Usual case: run the meta op to see the output metadata
            meta_args, meta_kwargs = to_device_no_copy("meta", args, kwargs)
            meta_res = op(*meta_args, **meta_kwargs)

            # 2. Allocate the output
            real_res, _ = to_device_no_copy("openreg", meta_res, {})
        else:
            # Slow version for data-dependent functions:
            # Run the op on the device just to get the output shape
            args_, kwargs_ = prepare_for_sending(args, kwargs)
            shape = daemon.exec(
                "get_op_output_shape",
                op.overloadpacket._qualified_op_name,
                args_,
                kwargs_,
            )

            # 2. Allocate the output
            real_res = args[0].new(shape)

        # 3. Move to out variant
        kwargs["out"] = real_res
        # Let overload resolution find the out= overload
        op_name = op.overloadpacket._qualified_op_name

    # 4. Run the compute and populate the output on the device
    args, kwargs = prepare_for_sending(args, kwargs)
    daemon.exec("run_op", op_name, args, kwargs)

    if post_process is not None:
        post_process()

    return real_res
151
152
# Install the kernel above as the catch-all fallback for every op dispatched
# to the PrivateUse1 (a.k.a. "openreg") dispatch key.
_openreg_lib.fallback(_openreg_kernel_fallback, dispatch_key="PrivateUse1")