xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/session_ops.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Tensor Handle Operations."""
17
18# pylint: disable=g-bad-name
19import numpy as np
20
21from tensorflow.core.framework import resource_handle_pb2
22from tensorflow.python.client import pywrap_tf_session
23from tensorflow.python.framework import device as pydev
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import gen_data_flow_ops
28from tensorflow.python.util import compat
29from tensorflow.python.util.tf_export import tf_export
30
31
def encode_resource_handle(resource_handle):
  """Packs a serialized ResourceHandle proto into a numpy resource value.

  Args:
    resource_handle: A ResourceHandle proto message (anything exposing
      `SerializeToString()`).

  Returns:
    A numpy array of dtype `dtypes.np_resource` holding the serialized bytes
    of `resource_handle`.
  """
  serialized = resource_handle.SerializeToString()
  return np.asarray(bytearray(serialized), dtype=dtypes.np_resource)
36
37
class TensorHandle:
  """Represents a handle for a live tensor in a session."""

  def __init__(self, handle, dtype, session):
    """Constructs a new tensor handle.

    A tensor handle for a persistent tensor is a python string
    that has the form of "tensor_name;unique_id;device_name".

    Args:
      handle: A tensor handle.
      dtype: The data type of the tensor represented by `handle`.
      session: The session in which the tensor is produced.
    """
    self._handle = compat.as_str_any(handle)
    # ResourceHandleProto view of the handle, built lazily on first use.
    self._resource_handle = None
    self._dtype = dtype
    self._session = session
    # While True, the handle is handed back to the session for deletion when
    # this object is finalized; delete() and get_raw_handle() turn it off.
    self._auto_gc_enabled = True

  def __del__(self):
    # getattr guards against a partially constructed instance (e.g. when
    # compat.as_str_any raised inside __init__), which would otherwise raise
    # AttributeError from inside the finalizer.
    if getattr(self, "_auto_gc_enabled", False):
      self._session._register_dead_handle(self.handle)

  def __str__(self):
    return self._handle

  def _get_resource_handle(self):
    """The ResourceHandle representation of this handle.

    The proto is built on first call and cached on the instance.
    """
    if self._resource_handle is None:
      # Fill a local first so a failure cannot leave a half-populated cache.
      proto = resource_handle_pb2.ResourceHandleProto()
      # The device is the last ";"-separated field of the handle string.
      proto.device = self._handle.split(";")[-1]
      proto.container = pywrap_tf_session.TENSOR_HANDLE_KEY
      proto.name = self._handle
      self._resource_handle = proto
    return self._resource_handle

  def to_numpy_array(self):
    """Convert a TensorHandle object to a feedable numpy value.

    Returns:
      A numpy array of a custom struct type that can be used as a feed value
      to run().
    """
    return encode_resource_handle(self._get_resource_handle())

  @property
  def handle(self):
    """The string representation of this handle."""
    return self._handle

  def eval(self):
    """Return the value of the tensor represented by this handle.

    Returns:
      The result of running this handle's (cached) reader subgraph in the
      owning session.

    Raises:
      TypeError: If automatic garbage collection has been disabled by
        `delete()` or `get_raw_handle()`, since the tensor may already have
        been deleted.
    """
    if not self._auto_gc_enabled:
      raise TypeError("Persistent tensor %s may have already been deleted."
                      % self.handle)
    holder, reader = _get_handle_reader(self._session.graph, self._handle,
                                        self._dtype)
    return self._session.run(reader, feed_dict={holder: self._handle})

  def delete(self):
    """Force the deletion of this persistent tensor.

    Raises:
      TypeError: If automatic garbage collection has already been disabled
        by a previous `delete()` or by `get_raw_handle()`.
    """
    if not self._auto_gc_enabled:
      raise TypeError("Persistent tensor %s may have already been deleted."
                      % self.handle)
    # Disable auto-GC before running so __del__ never registers this handle
    # a second time, even if the deletion run raises.
    self._auto_gc_enabled = False
    holder, deleter = _get_handle_deleter(self._session.graph, 0, self._handle)
    self._session.run(deleter, feed_dict={holder: self.handle})

  def get_raw_handle(self):
    """Return the raw handle of the tensor.

    Note that the method disables the automatic garbage collection of this
    persistent tensor. The caller is now responsible for managing the life
    time of the tensor.
    """
    self._auto_gc_enabled = False
    return self._handle

  @staticmethod
  def _get_device_name(handle):
    """The canonical device name encoded in the handle (its last field)."""
    handle_str = compat.as_str_any(handle)
    return pydev.canonical_name(handle_str.split(";")[-1])

  @staticmethod
  def _get_reader_key(handle):
    """The graph key for reader: the handle's first and last fields.

    Uses `compat.as_str_any`, like `_get_device_name`, so a bytes handle and
    the equivalent str handle map to the same key. `str()` on bytes would
    yield their repr (e.g. "b'...'") and split the reader cache.
    """
    handle_parts = compat.as_str_any(handle).split(";")
    return handle_parts[0] + ";" + handle_parts[-1]

  @staticmethod
  def _get_mover_key(feeder, handle):
    """The graph key for mover: feeder op name plus the reader key."""
    return feeder.op.name + ";" + TensorHandle._get_reader_key(handle)
132
133
@tf_export(v1=["get_session_handle"])
def get_session_handle(data, name=None):
  """Return the handle of `data`.

  This is EXPERIMENTAL and subject to change.

  Keeps `data` resident in the runtime ("in place") and produces a handle
  through which a later run() can retrieve it again.

  Used together with `get_session_tensor`, this lets a tensor computed in one
  run call be kept alive and reused as an input in a subsequent run call.

  Args:
    data: A tensor to be stored in the session.
    name: Optional name prefix for the return tensor.

  Returns:
    A scalar string tensor representing a unique handle for `data`.

  Raises:
    TypeError: if `data` is not a Tensor.

  Example:

  ```python
  c = tf.multiply(a, b)
  h = tf.compat.v1.get_session_handle(c)
  h = sess.run(h)

  p, a = tf.compat.v1.get_session_tensor(h.handle, tf.float32)
  b = tf.multiply(a, 10)
  c = sess.run(b, feed_dict={p: h.handle})
  ```

  """
  if isinstance(data, ops.Tensor):
    # The handle op is colocated with the tensor it refers to.
    with ops.colocate_with(data):
      return gen_data_flow_ops.get_session_handle(data, name=name)
  raise TypeError("`data` must be of type Tensor.")
175
176
@tf_export(v1=["get_session_tensor"])
def get_session_tensor(handle, dtype, name=None):
  """Get the tensor of type `dtype` by feeding a tensor handle.

  This is EXPERIMENTAL and subject to change.

  Builds the ops that look up, in the session state, the tensor that a
  previous run() stored under a tensor handle.

  Args:
    handle: The string representation of a persistent tensor handle.
    dtype: The type of the output tensor.
    name: Optional name prefix for the return tensor.

  Returns:
    A pair of tensors: a string placeholder into which the tensor handle is
    fed, and the tensor from the session state keyed by that handle.

  Example:

  ```python
  c = tf.multiply(a, b)
  h = tf.compat.v1.get_session_handle(c)
  h = sess.run(h)

  p, a = tf.compat.v1.get_session_tensor(h.handle, tf.float32)
  b = tf.multiply(a, 10)
  c = sess.run(b, feed_dict={p: h.handle})
  ```

  """
  device_name = TensorHandle._get_device_name(handle)
  with ops.device(device_name):
    placeholder = array_ops.placeholder(dtypes.string)
    # Record which dtype flows through this placeholder; _get_handle_mover
    # consults this registry when a fed handle lives on another device.
    _register_handle_feeder(placeholder.graph, placeholder, dtype)
    value = gen_data_flow_ops.get_session_tensor(placeholder, dtype, name=name)
  return placeholder, value
216
217
@tf_export(v1=["delete_session_tensor"])
def delete_session_tensor(handle, name=None):
  """Delete the tensor for the given tensor handle.

  This is EXPERIMENTAL and subject to change.

  Builds the ops that remove, from the session state, the tensor that a
  previous run() stored under a tensor handle.

  Args:
    handle: The string representation of a persistent tensor handle.
    name: Optional name prefix for the deletion op.

  Returns:
    A pair of graph elements: a string placeholder into which the tensor
    handle is fed, and the deletion operation.
  """
  device_name = TensorHandle._get_device_name(handle)
  with ops.device(device_name):
    placeholder = array_ops.placeholder(dtypes.string)
    delete_op = gen_data_flow_ops.delete_session_tensor(placeholder, name=name)
  return placeholder, delete_op
240
241
def _register_handle_feeder(graph, feeder, dtype):
  """Records `dtype` as the dtype fed through placeholder `feeder`.

  The entry is stored in `graph._handle_feeders`, keyed by the name of
  `feeder`'s op.
  """
  registry = graph._handle_feeders
  registry[feeder.op.name] = dtype
244
245
def _get_handle_feeder(graph, feeder):
  """Returns the dtype registered for placeholder `feeder`, or None if unset."""
  registry = graph._handle_feeders
  feeder_name = feeder.op.name
  return registry.get(feeder_name)
248
249
def _get_handle_reader(graph, handle, dtype):
  """Return a read subgraph for this handle.

  Returns the cached `(placeholder, reader)` pair stored in
  `graph._handle_readers` under the handle's reader key. On a cache miss it
  builds the pair on the handle's device, registers the placeholder as a
  handle feeder for `dtype`, and caches the pair.
  """
  graph_key = TensorHandle._get_reader_key(handle)
  cached = graph._handle_readers.get(graph_key)
  if cached is not None:
    return cached
  device_name = TensorHandle._get_device_name(handle)
  with graph.as_default(), graph.device(device_name):
    placeholder = array_ops.placeholder(dtypes.string)
    _register_handle_feeder(placeholder.graph, placeholder, dtype)
    reader = gen_data_flow_ops.get_session_tensor(placeholder, dtype)
  reader_pair = (placeholder, reader)
  graph._handle_readers[graph_key] = reader_pair
  return reader_pair
264
265
def _get_handle_mover(graph, feeder, handle):
  """Return a move subgraph for this pair of feeder and handle.

  Returns None when `feeder` is not a registered handle placeholder, or when
  it is already placed on the handle's device. Otherwise returns a cached
  `(placeholder, mover)` pair. The mover reads the tensor through the
  handle's reader and re-registers it via `get_session_handle` on the
  feeder's device.
  """
  dtype = _get_handle_feeder(graph, feeder)
  if dtype is None:
    return None
  target_device = feeder.op.device
  if target_device == TensorHandle._get_device_name(handle):
    # The tensor already lives where the feeder expects it.
    return None
  graph_key = TensorHandle._get_mover_key(feeder, handle)
  cached = graph._handle_movers.get(graph_key)
  if cached is None:
    holder, reader = _get_handle_reader(graph, handle, dtype)
    with graph.as_default(), graph.device(target_device):
      mover = gen_data_flow_ops.get_session_handle(reader)
    cached = (holder, mover)
    graph._handle_movers[graph_key] = cached
  return cached
285
286
def _get_handle_deleter(graph, deleter_key, handle):
  """Return a deletion subgraph for this handle.

  Returns the `(placeholder, delete_op)` pair cached in
  `graph._handle_deleters` under `deleter_key`, building it on `handle`'s
  device on the first request for that key.

  NOTE(review): the cache is keyed only by `deleter_key`, so a cached pair
  stays on the device of whichever handle first created it for that key.
  """
  cached = graph._handle_deleters.get(deleter_key)
  if cached is not None:
    return cached
  device_name = TensorHandle._get_device_name(handle)
  with graph.as_default(), graph.device(device_name):
    placeholder = array_ops.placeholder(dtypes.string)
    delete_op = gen_data_flow_ops.delete_session_tensor(placeholder)
  cached = (placeholder, delete_op)
  graph._handle_deleters[deleter_key] = cached
  return cached
299