import enum
from typing import Any

import torch
from torch._C import _from_dlpack
from torch._C import _to_dlpack as to_dlpack


class DLDeviceType(enum.IntEnum):
    """Device type codes as defined by the DLPack specification.

    The values mirror the enum in ``aten/src/ATen/dlpack.h``. ``from_dlpack``
    compares them against the first element of the tuple returned by an
    object's ``__dlpack_device__`` method.
    """

    # Plain integers on purpose: a trailing comma (``kDLCPU = 1,``) would make
    # each value a one-element tuple that only works because IntEnum happens
    # to unpack tuple values into ``int(*value)``.
    kDLCPU = 1
    kDLGPU = 2
    kDLCPUPinned = 3
    kDLOpenCL = 4
    kDLVulkan = 7
    kDLMetal = 8
    kDLVPI = 9
    kDLROCM = 10
    kDLExtDev = 12
    kDLOneAPI = 14


# ``torch._C._add_docstr`` raises RuntimeError when the builtin already carries
# a docstring, so executing this module a second time (e.g. via
# ``importlib.reload``) would otherwise fail. Attach the docstring only once.
if to_dlpack.__doc__ is None:
    torch._C._add_docstr(to_dlpack, r"""to_dlpack(tensor) -> PyCapsule

Returns an opaque object (a "DLPack capsule") representing the tensor.

.. note::
  ``to_dlpack`` is a legacy DLPack interface. The capsule it returns
  cannot be used for anything in Python other than using it as input to
  ``from_dlpack``. The more idiomatic use of DLPack is to call
  ``from_dlpack`` directly on the tensor object - this works when that
  object has a ``__dlpack__`` method, which PyTorch and most other
  libraries indeed have now.

.. warning::
  Only call ``from_dlpack`` once per capsule produced with ``to_dlpack``.
  Behavior when a capsule is consumed multiple times is undefined.

Args:
    tensor: a tensor to be exported

The DLPack capsule shares the tensor's memory.
""")


# TODO: add a typing.Protocol to be able to tell Mypy that only objects with
# __dlpack__ and __dlpack_device__ methods are accepted.
def from_dlpack(ext_tensor: Any) -> 'torch.Tensor':
    """from_dlpack(ext_tensor) -> Tensor

    Converts a tensor from an external library into a ``torch.Tensor``.

    The returned PyTorch tensor will share the memory with the input tensor
    (which may have come from another library). Note that in-place operations
    will therefore also affect the data of the input tensor. This may lead to
    unexpected issues (e.g., other libraries may have read-only flags or
    immutable data structures), so the user should only do this if they know
    for sure that this is fine.

    Args:
        ext_tensor (object with ``__dlpack__`` attribute, or a DLPack capsule):
            The tensor or DLPack capsule to convert.

            If ``ext_tensor`` is a tensor (or ndarray) object, it must support
            the ``__dlpack__`` protocol (i.e., have an ``ext_tensor.__dlpack__``
            method). Otherwise ``ext_tensor`` may be a DLPack capsule, which is
            an opaque ``PyCapsule`` instance, typically produced by a
            ``to_dlpack`` function or method.

    Returns:
        A ``torch.Tensor`` that shares memory with ``ext_tensor``.

    Examples::

        >>> import torch.utils.dlpack
        >>> t = torch.arange(4)

        # Convert a tensor directly (supported in PyTorch >= 1.10)
        >>> t2 = torch.from_dlpack(t)
        >>> t2[:2] = -1  # show that memory is shared
        >>> t2
        tensor([-1, -1,  2,  3])
        >>> t
        tensor([-1, -1,  2,  3])

        # The old-style DLPack usage, with an intermediate capsule object
        >>> capsule = torch.utils.dlpack.to_dlpack(t)
        >>> capsule
        <capsule object "dltensor" at ...>
        >>> t3 = torch.from_dlpack(capsule)
        >>> t3
        tensor([-1, -1,  2,  3])
        >>> t3[0] = -9  # now we're sharing memory between 3 tensors
        >>> t3
        tensor([-9, -1,  2,  3])
        >>> t2
        tensor([-9, -1,  2,  3])
        >>> t
        tensor([-9, -1,  2,  3])

    """
    if hasattr(ext_tensor, '__dlpack__'):
        device = ext_tensor.__dlpack_device__()
        # The device is either CUDA or ROCm: the producer must be told which
        # stream we (the consumer) will use, so pass the current stream.
        if device[0] in (DLDeviceType.kDLGPU, DLDeviceType.kDLROCM):
            stream = torch.cuda.current_stream(f'cuda:{device[1]}')
            # ``cuda_stream`` is the raw pointer to the stream; it is a public
            # attribute, but it is not documented.
            # The array API specifies that the legacy default stream must be
            # passed with a value of 1 for CUDA:
            # https://data-apis.org/array-api/latest/API_specification/array_object.html?dlpack-self-stream-none#dlpack-self-stream-none
            is_cuda = device[0] == DLDeviceType.kDLGPU
            # PyTorch does not use per-thread default streams (PTDS) by
            # default, so a null current stream is the legacy stream; pass it
            # using the spec's value of 1 on CUDA.
            stream_ptr = 1 if is_cuda and stream.cuda_stream == 0 else stream.cuda_stream
            dlpack = ext_tensor.__dlpack__(stream=stream_ptr)
        else:
            dlpack = ext_tensor.__dlpack__()
    else:
        # Old-style usage: ``ext_tensor`` is already a DLPack capsule (e.g.
        # from ``to_dlpack``), so hand it straight to the converter.
        dlpack = ext_tensor
    return _from_dlpack(dlpack)