1import pprint
2
3import torch
4from torch.utils._pytree import tree_map, tree_map_only
5
6
class OpenRegTensorMeta:
    """Plain record of a tensor's storage pointer, layout and dtype.

    Only the values read from the tensor are kept as attributes; the tensor
    object itself is not stored.
    """

    def __init__(self, tensor, checked=True):
        """Read the metadata of ``tensor`` into attributes.

        Args:
            tensor: Tensor to describe.
            checked: When true, refuse tensors whose device type is not
                ``"openreg"``.

        Raises:
            RuntimeError: If ``checked`` is true and ``tensor`` does not live
                on an openreg device.
        """
        if checked and tensor.device.type != "openreg":
            raise RuntimeError(
                "Creating OpenRegTensorMeta is only for Tensors on openreg device"
            )

        # Address of the untyped storage backing the tensor.
        storage = tensor.untyped_storage()
        self.data_ptr = storage.data_ptr()

        # Layout of the tensor within that storage.
        self.size = tensor.size()
        self.stride = tensor.stride()
        self.storage_offset = tensor.storage_offset()
        self.dtype = tensor.dtype

        # Element count multiplied by the per-element size in bytes.
        self.nelem_in_bytes = tensor.element_size() * tensor.nelement()

    def __repr__(self):
        """Return ``OpenRegTensorMeta(...)`` listing every stored attribute."""
        described = (
            f"{self.data_ptr=}",
            f"{self.size=}",
            f"{self.stride=}",
            f"{self.storage_offset=}",
            f"{self.dtype=}",
            f"{self.nelem_in_bytes=}",
        )
        return "OpenRegTensorMeta(" + ", ".join(described) + ")"
25
26
class OpenRegTensorData(torch.Tensor):
    """``torch.Tensor`` subclass wrapping tensors produced by an allocator."""

    @staticmethod
    def from_meta(allocator, tensor_meta):
        """Build an ``OpenRegTensorData`` for ``tensor_meta``.

        Args:
            allocator: Object exposing ``tensor_from_meta``.
            tensor_meta: Metadata handed to ``allocator.tensor_from_meta``.

        Returns:
            An ``OpenRegTensorData`` constructed from the tensor the allocator
            returned.
        """
        backing = allocator.tensor_from_meta(tensor_meta)
        return OpenRegTensorData(backing)
31
32
# Exact leaf types that ``prepare_for_sending`` accepts from callers.
VALID_QUEUE_TYPES_IN = {
    torch.Tensor,
    int,
    float,
}

# Exact leaf types accepted by ``validate_send_queue_args`` and
# ``receive_after_sending``.
VALID_QUEUE_TYPES_OUT = {
    OpenRegTensorMeta,
    int,
    float,
    str,
}
36
37
def safe_str(args):
    """Pretty-format ``args`` with tensors replaced by their metadata text.

    Every ``torch.Tensor`` leaf of the pytree is swapped for the ``str`` of an
    unchecked ``OpenRegTensorMeta`` built from it; all other leaves are kept
    untouched. The resulting tree is rendered with ``pprint.pformat``.

    Args:
        args: Arbitrary pytree of arguments.

    Returns:
        The formatted string.
    """

    def describe(tensor):
        # checked=False so tensors on any device can be described.
        return str(OpenRegTensorMeta(tensor, checked=False))

    printable = tree_map_only(torch.Tensor, describe, args)
    return pprint.pformat(printable)
47
48
def validate_send_queue_args(cmd, args):
    """Check that every leaf of ``args`` may be sent through the queue.

    Args:
        cmd: Name of the command the arguments belong to.
        args: Pytree of arguments to inspect.

    Raises:
        RuntimeError: If a leaf's exact type is not in
            ``VALID_QUEUE_TYPES_OUT``, unless ``cmd`` is ``"recv_data"`` and
            the leaf is a plain ``torch.Tensor`` on the cpu.
    """

    def ensure_sendable(leaf):
        leaf_type = type(leaf)
        if leaf_type in VALID_QUEUE_TYPES_OUT:
            return

        # Only HtoD copy command can send cpu Tensors over
        htod_cpu_tensor = (
            cmd == "recv_data"
            and leaf_type is torch.Tensor
            and leaf.device.type == "cpu"
        )
        if htod_cpu_tensor:
            return

        raise RuntimeError(
            f"Trying to send invalid object through queue: {leaf_type}"
        )

    # tree_map is used purely to visit each leaf; its result is discarded.
    tree_map(ensure_sendable, args)
64
65
def prepare_for_sending(args, kwargs):
    """Turn call arguments into objects that can go over the device pipe.

    Args:
        args: Positional arguments (a pytree).
        kwargs: Keyword arguments (a pytree).

    Returns:
        The ``(args, kwargs)`` tree with each tensor leaf replaced by an
        ``OpenRegTensorMeta`` and every other leaf passed through.

    Raises:
        RuntimeError: If a leaf's exact type is not in
            ``VALID_QUEUE_TYPES_IN``, or if ``OpenRegTensorMeta`` rejects a
            tensor (it is constructed with its default checking enabled).
    """

    def to_wire(leaf):
        if type(leaf) in VALID_QUEUE_TYPES_IN:
            return OpenRegTensorMeta(leaf) if isinstance(leaf, torch.Tensor) else leaf

        raise RuntimeError(
            f"Cannot send object of type {type(leaf)} over openreg device pipe."
        )

    return tree_map(to_wire, (args, kwargs))
79
80
def receive_after_sending(allocator, args, kwargs):
    """Rebuild call arguments from the objects received over the pipe.

    Args:
        allocator: Object exposing ``tensor_from_meta``; used to turn each
            ``OpenRegTensorMeta`` leaf back into a tensor.
        args: Received positional arguments (a pytree).
        kwargs: Received keyword arguments (a pytree).

    Returns:
        The ``(args, kwargs)`` tree with each ``OpenRegTensorMeta`` leaf
        replaced by ``allocator.tensor_from_meta(leaf)`` and every other leaf
        passed through.

    Raises:
        RuntimeError: If a leaf's exact type is not in
            ``VALID_QUEUE_TYPES_OUT``.
    """

    def from_wire(leaf):
        if type(leaf) in VALID_QUEUE_TYPES_OUT:
            if isinstance(leaf, OpenRegTensorMeta):
                return allocator.tensor_from_meta(leaf)
            return leaf

        raise RuntimeError(
            f"Received invalid object of type {type(leaf)} over openreg device pipe."
        )

    return tree_map(from_wire, (args, kwargs))
95
96
def to_device_no_copy(device, args, kwargs):
    """Replace every tensor in ``(args, kwargs)`` with a stand-in on ``device``.

    When ``device`` equals ``"meta"`` each tensor is converted with
    ``Tensor.to``; for any other device ``torch.empty_like`` allocates a new
    tensor on that device instead. Non-tensor leaves are left as they are.

    Args:
        device: Target device.
        args: Positional arguments (a pytree).
        kwargs: Keyword arguments (a pytree).

    Returns:
        The ``(args, kwargs)`` tree with its tensor leaves replaced.
    """

    def stand_in(tensor):
        return (
            tensor.to(device=device)
            if device == "meta"
            else torch.empty_like(tensor, device=device)
        )

    return tree_map_only(torch.Tensor, stand_in, (args, kwargs))
105