# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DLPack modules for Tensorflow."""

from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import context
from tensorflow.python.util.tf_export import tf_export


@tf_export("experimental.dlpack.to_dlpack", v1=[])
def to_dlpack(tf_tensor):
  """Returns the dlpack capsule representing the tensor.

  This operation ensures the underlying data memory is ready when it returns.

  ```python
  a = tf.constant([1, 10])
  dlcapsule = tf.experimental.dlpack.to_dlpack(a)
  # dlcapsule represents the dlpack data structure
  ```

  Args:
    tf_tensor: TensorFlow eager tensor, to be converted to a dlpack capsule.

  Returns:
    A PyCapsule named "dltensor", which shares the underlying memory with
    other frameworks. This PyCapsule can be consumed only once.
  """
  # The conversion is implemented entirely in the C API wrapper.
  return pywrap_tfe.TFE_ToDlpackCapsule(tf_tensor)


@tf_export("experimental.dlpack.from_dlpack", v1=[])
def from_dlpack(dlcapsule):
  """Returns the TensorFlow eager tensor.

  The returned tensor uses the memory shared by dlpack capsules from other
  frameworks.

  ```python
  a = tf.experimental.dlpack.from_dlpack(dlcapsule)
  # `a` uses the memory shared by dlpack
  ```

  Args:
    dlcapsule: A PyCapsule named "dltensor".

  Returns:
    A TensorFlow eager tensor.
  """
  # Look up the eager context once. Initialize it before reading its handle,
  # because that handle is passed to the C API below.
  ctx = context.context()
  ctx.ensure_initialized()
  return pywrap_tfe.TFE_FromDlpackCapsule(dlcapsule, ctx._handle)  # pylint: disable=protected-access