# Helpers that turn tensors into plain metadata records (and back) and that
# check the types of objects passed through the openreg device queue/pipe.
import pprint

import torch
from torch.utils._pytree import tree_map, tree_map_only


class OpenRegTensorMeta:
    """Plain record of a tensor's storage pointer, layout and dtype.

    Args:
        tensor: The tensor to describe.
        checked: When true (the default), the tensor's device type must be
            ``"openreg"``.

    Raises:
        RuntimeError: If ``checked`` is true and ``tensor.device.type`` is
            not ``"openreg"``.
    """

    def __init__(self, tensor, checked=True):
        if checked:
            on_openreg = tensor.device.type == "openreg"
            if not on_openreg:
                raise RuntimeError(
                    "Creating OpenRegTensorMeta is only for Tensors on openreg device"
                )

        storage = tensor.untyped_storage()
        # Pointer reported by the tensor's untyped storage.
        self.data_ptr = storage.data_ptr()
        self.size = tensor.size()
        self.stride = tensor.stride()
        self.storage_offset = tensor.storage_offset()
        self.dtype = tensor.dtype
        # Element count multiplied by the per-element size.
        self.nelem_in_bytes = tensor.nelement() * tensor.element_size()

    def __repr__(self):
        """Return a single-line summary listing every recorded field."""
        head = f"OpenRegTensorMeta({self.data_ptr=}, {self.size=}, {self.stride=}, "
        tail = f"{self.storage_offset=}, {self.dtype=}, {self.nelem_in_bytes=})"
        return head + tail


class OpenRegTensorData(torch.Tensor):
    """``torch.Tensor`` subclass built from an allocator-provided tensor."""

    @staticmethod
    def from_meta(allocator, tensor_meta):
        """Wrap ``allocator.tensor_from_meta(tensor_meta)`` in this class.

        Args:
            allocator: Object exposing ``tensor_from_meta(meta)``.
            tensor_meta: Metadata handed to the allocator unchanged.

        Returns:
            An ``OpenRegTensorData`` constructed from the allocator's result.
        """
        backing = allocator.tensor_from_meta(tensor_meta)
        return OpenRegTensorData(backing)


# Exact leaf types that prepare_for_sending() accepts.
VALID_QUEUE_TYPES_IN = {torch.Tensor, int, float}

# Exact leaf types that validate_send_queue_args() and
# receive_after_sending() accept.
VALID_QUEUE_TYPES_OUT = {OpenRegTensorMeta, int, float, str}


def safe_str(args):
    """Pretty-format ``args`` with every tensor replaced by its metadata text.

    Tensor leaves are swapped for ``str(OpenRegTensorMeta(t, checked=False))``;
    every other leaf is left as is. The mapped structure is then rendered with
    ``pprint.pformat``.

    Args:
        args: Arbitrary pytree of values.

    Returns:
        The formatted string.
    """

    def describe(leaf):
        if not isinstance(leaf, torch.Tensor):
            return leaf
        return str(OpenRegTensorMeta(leaf, checked=False))

    described = tree_map(describe, args)
    return pprint.pformat(described)


def validate_send_queue_args(cmd, args):
    """Check that every leaf of ``args`` may be sent through the queue.

    A leaf passes when its exact type is in ``VALID_QUEUE_TYPES_OUT``. In
    addition, a plain ``torch.Tensor`` on the ``"cpu"`` device passes when
    ``cmd`` is ``"recv_data"``.

    Args:
        cmd: Name of the command being sent.
        args: Pytree of values accompanying the command.

    Raises:
        RuntimeError: If any leaf is not allowed.
    """

    def check(obj):
        if type(obj) in VALID_QUEUE_TYPES_OUT:
            return

        # Only the HtoD copy command is allowed to send cpu Tensors over.
        cpu_tensor_for_copy = (
            cmd == "recv_data"
            and type(obj) is torch.Tensor
            and obj.device.type == "cpu"
        )
        if cpu_tensor_for_copy:
            return

        raise RuntimeError(
            f"Trying to send invalid object through queue: {type(obj)}"
        )

    tree_map(check, args)


def prepare_for_sending(args, kwargs):
    """Convert ``(args, kwargs)`` into a form with tensors replaced by metadata.

    Each leaf must have an exact type found in ``VALID_QUEUE_TYPES_IN``.
    Tensor leaves become ``OpenRegTensorMeta`` instances (built with the
    default device check); other leaves pass through unchanged.

    Args:
        args: Positional arguments pytree.
        kwargs: Keyword arguments pytree.

    Returns:
        The mapped ``(args, kwargs)`` pair.

    Raises:
        RuntimeError: If a leaf has a disallowed type, or if a tensor leaf
            fails the ``OpenRegTensorMeta`` device check.
    """

    def convert(obj):
        if type(obj) not in VALID_QUEUE_TYPES_IN:
            raise RuntimeError(
                f"Cannot send object of type {type(obj)} "
                "over openreg device pipe."
            )
        return OpenRegTensorMeta(obj) if isinstance(obj, torch.Tensor) else obj

    payload = (args, kwargs)
    return tree_map(convert, payload)


def receive_after_sending(allocator, args, kwargs):
    """Rebuild tensors from metadata in a received ``(args, kwargs)`` pair.

    Each leaf must have an exact type found in ``VALID_QUEUE_TYPES_OUT``.
    ``OpenRegTensorMeta`` leaves are replaced by
    ``allocator.tensor_from_meta(leaf)``; other leaves pass through unchanged.

    Args:
        allocator: Object exposing ``tensor_from_meta(meta)``.
        args: Positional arguments pytree.
        kwargs: Keyword arguments pytree.

    Returns:
        The mapped ``(args, kwargs)`` pair.

    Raises:
        RuntimeError: If a leaf has a disallowed type.
    """

    def convert(obj):
        if type(obj) not in VALID_QUEUE_TYPES_OUT:
            raise RuntimeError(
                f"Received invalid object of type {type(obj)} "
                "over openreg device pipe."
            )
        return allocator.tensor_from_meta(obj) if isinstance(obj, OpenRegTensorMeta) else obj

    payload = (args, kwargs)
    return tree_map(convert, payload)


def to_device_no_copy(device, args, kwargs):
    """Produce counterparts of all tensors in ``(args, kwargs)`` on ``device``.

    When ``device`` equals the string ``"meta"`` each tensor is converted with
    ``t.to(device=device)``; for any other value a new tensor is allocated
    with ``torch.empty_like(t, device=device)``. Non-tensor leaves are left
    untouched.

    Args:
        device: Target device.
        args: Positional arguments pytree.
        kwargs: Keyword arguments pytree.

    Returns:
        The mapped ``(args, kwargs)`` pair.
    """

    def relocate(t):
        return (
            t.to(device=device)
            if device == "meta"
            else torch.empty_like(t, device=device)
        )

    payload = (args, kwargs)
    return tree_map_only(torch.Tensor, relocate, payload)