xref: /aosp_15_r20/external/tensorflow/tensorflow/python/client/session.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A client interface for TensorFlow."""
16
17import collections
18import functools
19import re
20import threading
21import warnings
22
23import numpy as np
24import wrapt
25
26from tensorflow.core.protobuf import config_pb2
27from tensorflow.core.protobuf import rewriter_config_pb2
28from tensorflow.python.client import pywrap_tf_session as tf_session
29from tensorflow.python.eager import context
30from tensorflow.python.eager import monitoring
31from tensorflow.python.framework import device
32from tensorflow.python.framework import error_interpolation
33from tensorflow.python.framework import errors
34from tensorflow.python.framework import indexed_slices
35from tensorflow.python.framework import ops
36from tensorflow.python.framework import sparse_tensor
37from tensorflow.python.ops import session_ops
38from tensorflow.python.platform import tf_logging as logging
39from tensorflow.python.training.experimental import mixed_precision_global_state
40from tensorflow.python.util import compat
41from tensorflow.python.util import nest
42from tensorflow.python.util.compat import collections_abc
43from tensorflow.python.util.tf_export import tf_export
44
# Process-wide counter of Python-level session objects; incremented once per
# `BaseSession.__init__` call and exported as a TF monitoring metric.
_python_session_create_counter = monitoring.Counter(
    '/tensorflow/api/python/session_create_counter',
    'Counter for number of sessions created in Python.')
48
49
class SessionInterface(object):
  """Abstract interface implemented by TensorFlow client sessions.

  Concrete subclasses (see `BaseSession`) provide the graph, the target
  description, and the run / partial-run entry points declared here. Every
  member of this base class simply raises `NotImplementedError`.
  """

  @property
  def graph(self):
    """The underlying TensorFlow graph, to be used in building Operations."""
    raise NotImplementedError('graph')

  @property
  def sess_str(self):
    """The TensorFlow process to which this session will connect."""
    raise NotImplementedError('sess_str')

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """Runs operations in the session. See `BaseSession.run()` for details."""
    raise NotImplementedError('run')

  def partial_run_setup(self, fetches, feeds=None):
    """Sets up the feeds and fetches for partial runs in the session."""
    raise NotImplementedError('partial_run_setup')

  def partial_run(self, handle, fetches, feed_dict=None):
    """Continues the execution with additional feeds and fetches."""
    raise NotImplementedError('partial_run')
74
75
def _get_indexed_slices_value_from_fetches(fetched_vals):
  """Packs fetched component arrays into an `IndexedSlicesValue`.

  Args:
    fetched_vals: A list of two or three fetched ndarrays: values, indices,
      and (optionally) a dense shape.

  Returns:
    An `indexed_slices.IndexedSlicesValue`; its `dense_shape` is None when
    only two components were fetched.
  """
  dense_shape = fetched_vals[2] if len(fetched_vals) == 3 else None
  return indexed_slices.IndexedSlicesValue(fetched_vals[0], fetched_vals[1],
                                           dense_shape)
80
81
82def _get_feeds_for_indexed_slices(feed, feed_val):
83  return list(
84      zip([feed.values, feed.indices] if feed.dense_shape is None else
85          [feed.values, feed.indices, feed.dense_shape], feed_val))
86
87
88# List of extensions supported to convert run arguments into actual fetches and
89# feeds.
90#
91# Each element in the list is a tuple of (Type, fetch_fn, feed_fn1, feed_fn2),
92# where the function signatures are:
93#   fetch_fn : Type -> (list of Tensors,
94#                       lambda: list of fetched np.ndarray -> TypeVal)
95#   feed_fn1 : Type, TypeVal -> list of (Tensor, value)
96#   feed_fn2 : Type -> list of Tensors
97#
98# `fetch_fn` describes how to expand fetch into its
99# component Tensors and how to contract the fetched results back into
100# a single return value.
101#
102# Each feed function describes how to unpack a single fed value and map it to
103# feeds of one or more tensors and their corresponding values: `feed_fn1` is
104# used to feed a run, `feed_fn2` to set up a partial run.
105#
106# TODO(touts): We could reimplement these as specialized _FeedMapper
107# implementations after we refactor the feed handling code to use them.
108#
109# Eventually, this registration could be opened up to support custom Tensor
110# expansions.
111# pylint: disable=g-long-lambda
# Each entry is a tuple (Type, fetch_fn, feed_fn1, feed_fn2) — see the
# signature description in the comment block above. Entries are matched in
# order by `_FetchMapper.for_fetch`, and new registrations (see
# `register_session_run_conversion_functions`) are inserted at index 0, so
# the catch-all `object` entry must remain last.
_REGISTERED_EXPANSIONS = [
    # SparseTensors are fetched as SparseTensorValues. They can be fed
    # SparseTensorValues or normal tuples.
    (sparse_tensor.SparseTensor, lambda fetch: ([
        fetch.indices, fetch.values, fetch.dense_shape
    ], lambda fetched_vals: sparse_tensor.SparseTensorValue(*fetched_vals)),
     lambda feed, feed_val: list(
         zip([feed.indices, feed.values, feed.dense_shape], feed_val)),
     lambda feed: [feed.indices, feed.values, feed.dense_shape]),
    # IndexedSlices are fetched as IndexedSlicesValues. They can be fed
    # IndexedSlicesValues or normal tuples. The dense_shape component is
    # optional, so the component lists contain either two or three tensors.
    (indexed_slices.IndexedSlices,
     lambda fetch: ([fetch.values, fetch.indices] if fetch.dense_shape is None
                    else [fetch.values, fetch.indices, fetch.dense_shape
                         ], _get_indexed_slices_value_from_fetches),
     _get_feeds_for_indexed_slices,
     lambda feed: [feed.values, feed.indices] if feed.dense_shape is None else
     [feed.values, feed.indices, feed.dense_shape]),
    # The default catches all other types and performs no expansions.
    (object, lambda fetch: ([fetch], lambda fetched_vals: fetched_vals[0]),
     lambda feed, feed_val: [(feed, feed_val)], lambda feed: [feed])
]
134
135# pylint: enable=g-long-lambda
136
137
138def _convert_to_numpy_obj(numpy_dtype, obj):
139  """Explicitly convert obj based on numpy type except for string type."""
140  return numpy_dtype(obj) if numpy_dtype is not object else str(obj)
141
142
def register_session_run_conversion_functions(
    tensor_type,
    fetch_function,
    feed_function=None,
    feed_function_for_partial_run=None):
  """Register fetch and feed conversion functions for `tf.Session.run()`.

  Registers a triple of conversion functions that teach `tf.Session.run()`
  how to fetch and/or feed values of a user-defined type.

  Example:

  ```python
     class SquaredTensor(object):
       def __init__(self, tensor):
         self.sq = tf.square(tensor)
     # Conversion functions for the type:
     fetch_function = lambda squared_tensor:([squared_tensor.sq],
                                             lambda val: val[0])
     feed_function = lambda feed, feed_val: [(feed.sq, feed_val)]
     feed_function_for_partial_run = lambda feed: [feed.sq]
     # After registration, SquaredTensor objects can be used directly:
     session.run(squared_tensor1,
                 feed_dict = {squared_tensor2 : some_numpy_array})
  ```

  Args:
    tensor_type: The type for which you want to register a conversion function.
    fetch_function: A callable that takes an object of type `tensor_type` and
      returns a tuple `(tensors, contraction_fn)`, where `tensors` is a list
      of `tf.Tensor` objects to fetch and `contraction_fn` maps the list of
      fetched ndarrays back into a single value of the user type. It thus
      describes how to expand a fetch into component tensors and how to
      contract the fetched results into one return value.
    feed_function: A callable that takes `(feed_key, feed_value)` as input —
      `feed_key` must have type `tensor_type` — and returns a list of
      `(feed_tensor, feed_val)` tuples, where each `feed_tensor` has type
      `tf.Tensor`. It describes how to unpack one fed value into feeds of one
      or more tensors and their corresponding values.
    feed_function_for_partial_run: A callable for specifying tensor values to
      feed when setting up a partial run: it takes a `tensor_type` object and
      returns a list of Tensors.

  Raises:
    ValueError: If `tensor_type` has already been registered.
  """
  # Reject the registration if any existing entry would shadow or duplicate
  # it. Note the catch-all `object` entry only matches when `tensor_type`
  # is `object` itself.
  for registered_type, _, _, _ in _REGISTERED_EXPANSIONS:
    if issubclass(registered_type, tensor_type):
      raise ValueError(f'{tensor_type} has already been registered so ignore '
                       'it.')

  # Insert at the front so the new entry takes precedence over the
  # catch-all `object` entry during fetch/feed dispatch.
  _REGISTERED_EXPANSIONS.insert(0, (tensor_type, fetch_function, feed_function,
                                    feed_function_for_partial_run))
196
197
198def _is_attrs_instance(obj):
199  """Returns True if the given obj is an instance of attrs-decorated class."""
200  return getattr(obj.__class__, '__attrs_attrs__', None) is not None
201
202
203def _get_attrs_values(obj):
204  """Returns the list of values from an attrs instance."""
205  attrs = getattr(obj.__class__, '__attrs_attrs__')
206  return [getattr(obj, a.name) for a in attrs]
207
208
209class _FetchMapper(object):
210  """Definition of the interface provided by fetch mappers.
211
212  Fetch mappers are utility classes used by the _FetchHandler to handle
213  arbitrary structures for the `fetch` argument to `Session.run()`.
214
215  The `fetch` argument can be of various shapes: single tensor or op, list of
216  fetches, tuple of fetches, namedtuple of fetches, or dict of fetches.  The
217  structures can be arbitrarily nested.
218
219  The low level run() API only wants a list of tensor or op names.  The various
220  `_FetchMapper` subclasses below take care of handling the different shapes:
221  uniquifying the fetches, and constructing results with the original shape.
222  """
223
224  def unique_fetches(self):
225    """Return the list of unique tensors or ops needed by this fetch mapper.
226
227    Returns:
228      A list of tensors or ops.
229    """
230    raise NotImplementedError(
231        'unique_fetches must be implemented by subclasses')
232
233  def build_results(self, values):
234    """Build results that match the original shape of the fetch.
235
236    Args:
237      values: List of values returned by run(). The values correspond exactly to
238        the list tensors or ops returned by unique_fetches().
239
240    Returns:
241      A struct of the same shape as the original fetch object handled by
242      this fetch mapper.  In the returned struct, the original fetches are
243      replaced by their fetched values.
244    """
245    raise NotImplementedError('build_results must be implemented by subclasses')
246
247  @staticmethod
248  def for_fetch(fetch):
249    """Creates fetch mapper that handles the structure of `fetch`.
250
251    The default graph must be the one from which we want to fetch values when
252    this function is called.
253
254    Args:
255      fetch: An arbitrary fetch structure: singleton, list, tuple, namedtuple,
256        or dict.
257
258    Returns:
259      An instance of a subclass of `_FetchMapper` that handles the shape.
260    """
261    if fetch is None:
262      raise TypeError(f'Argument `fetch` = {fetch} has invalid type '
263                      f'"{type(fetch).__name__}". Cannot be None')
264    elif isinstance(fetch, (list, tuple)):
265      # NOTE(touts): This is also the code path for namedtuples.
266      return _ListFetchMapper(fetch)
267    elif isinstance(fetch, collections_abc.Mapping):
268      return _DictFetchMapper(fetch)
269    elif _is_attrs_instance(fetch):
270      return _AttrsFetchMapper(fetch)
271    else:
272      # Look for a handler in the registered expansions.
273      for tensor_type, fetch_fn, _, _ in _REGISTERED_EXPANSIONS:
274        if isinstance(fetch, tensor_type):
275          fetches, contraction_fn = fetch_fn(fetch)
276          return _ElementFetchMapper(fetches, contraction_fn)
277    # Did not find anything.
278    raise TypeError(f'Argument `fetch` = {fetch} has invalid type '
279                    f'"{type(fetch).__name__}"')
280
281
class _ElementFetchMapper(_FetchMapper):
  """Fetch mapper for singleton tensors and ops."""

  def __init__(self, fetches, contraction_fn):
    """Creates an _ElementFetchMapper.

    This is the fetch mapper used for leaves in the fetch struct.  Because of
    the expansions mechanism, a leaf can actually fetch more than one tensor.

    The fetches may be plain strings (tensor or op names) or any object the
    graph knows how to convert to a tensor (e.g. a Variable), so each one is
    resolved through `as_graph_element()`.

    Args:
      fetches: List of objects, as returned by a fetch_fn defined in
        _REGISTERED_EXPANSIONS.
      contraction_fn: Callable as returned by a fetch_fn.

    Raises:
      TypeError: If a fetch has a type the graph cannot convert.
      ValueError: If a fetch cannot be interpreted as a Tensor.
    """
    self._unique_fetches = []
    for fetch in fetches:
      try:
        element = ops.get_default_graph().as_graph_element(
            fetch, allow_tensor=True, allow_operation=True)
      except TypeError as e:
        raise TypeError(f'Argument `fetch` = {fetch} has invalid type '
                        f'"{type(fetch).__name__}" must be a string or Tensor. '
                        f'({str(e)})')
      except (ValueError, KeyError) as e:
        raise ValueError(f'Argument `fetch` = {fetch} cannot be interpreted as '
                         f'a Tensor. ({str(e)})')
      self._unique_fetches.append(element)
    self._contraction_fn = contraction_fn

  def unique_fetches(self):
    return self._unique_fetches

  def build_results(self, values):
    # An empty value list corresponds to an Operation fetch, which produces
    # no output value.
    return self._contraction_fn(values) if values else None
327
328
329def _uniquify_fetches(fetch_mappers):
330  """Uniquifies fetches from a list of fetch_mappers.
331
332  This is a utility function used by _ListFetchMapper and _DictFetchMapper.  It
333  gathers all the unique fetches from a list of mappers and builds a list
334  containing all of them but without duplicates (unique_fetches).
335
336  It also returns a 2-D list of integers (values_indices) indicating at which
337  index in unique_fetches the fetches of the mappers are located.
338
339  This list is as follows:
340    values_indices[mapper_index][mapper_fetch_index] = unique_fetches_index
341
342  Args:
343    fetch_mappers: list of fetch mappers.
344
345  Returns:
346    A list of fetches.
347    A 2-D list of integers.
348  """
349  unique_fetches = []
350  value_indices = []
351  seen_fetches = {}
352  for m in fetch_mappers:
353    m_value_indices = []
354    for f in m.unique_fetches():
355      j = seen_fetches.get(id(f))
356      if j is None:
357        j = len(seen_fetches)
358        seen_fetches[id(f)] = j
359        unique_fetches.append(f)
360      m_value_indices.append(j)
361    value_indices.append(m_value_indices)
362  return unique_fetches, value_indices
363
364
class _ListFetchMapper(_FetchMapper):
  """Fetch mapper for lists, tuples, and namedtuples."""

  def __init__(self, fetches):
    """Creates a _ListFetchMapper.

    Args:
      fetches: List, tuple, or namedtuple of fetches.
    """
    # Unwrap object proxies so the reconstructed result uses the real
    # underlying type.
    if isinstance(fetches, wrapt.ObjectProxy):
      self._fetch_type = type(fetches.__wrapped__)
    else:
      self._fetch_type = type(fetches)
    self._mappers = [_FetchMapper.for_fetch(f) for f in fetches]
    self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)

  def unique_fetches(self):
    return self._unique_fetches

  def build_results(self, values):
    # Rebuild each element's result from its slice of the unique values.
    results = [
        mapper.build_results([values[i] for i in indices])
        for mapper, indices in zip(self._mappers, self._value_indices)
    ]
    # Return a value of the original type of the fetches.
    if issubclass(self._fetch_type, list):
      return results
    if self._fetch_type == tuple:
      return tuple(results)
    # Namedtuples are reconstructed from positional arguments.
    return self._fetch_type(*results)
397
398
class _DictFetchMapper(_FetchMapper):
  """Fetch mapper for dicts."""

  def __init__(self, fetches):
    """Creates a _DictFetchMapper.

    Args:
      fetches: Dict of fetches.
    """
    self._fetch_type = type(fetches)
    # A defaultdict cannot be rebuilt from key/value pairs alone, so its
    # default_factory is captured here.
    if isinstance(fetches, collections.defaultdict):
      self._type_ctor = functools.partial(collections.defaultdict,
                                          fetches.default_factory)
    else:
      self._type_ctor = self._fetch_type

    self._keys = fetches.keys()
    self._mappers = [
        _FetchMapper.for_fetch(value) for value in fetches.values()
    ]
    self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)

  def unique_fetches(self):
    return self._unique_fetches

  def build_results(self, values):
    # Lazily produce (key, result) pairs and let the dict constructor
    # consume them.
    pairs = ((key, mapper.build_results([values[i] for i in indices]))
             for key, mapper, indices in zip(self._keys, self._mappers,
                                             self._value_indices))
    return self._type_ctor(pairs)
431
432
class _AttrsFetchMapper(_FetchMapper):
  """Fetch mapper for attrs decorated classes."""

  def __init__(self, fetches):
    """Creates a _AttrsFetchMapper.

    Args:
      fetches: An instance of an attrs decorated class.
    """
    self._fetch_type = type(fetches)
    self._mappers = [
        _FetchMapper.for_fetch(value) for value in _get_attrs_values(fetches)
    ]
    self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)

  def unique_fetches(self):
    return self._unique_fetches

  def build_results(self, values):
    # Reconstruct the attrs instance positionally, in attribute definition
    # order (the same order _get_attrs_values produced).
    results = [
        mapper.build_results([values[i] for i in indices])
        for mapper, indices in zip(self._mappers, self._value_indices)
    ]
    return self._fetch_type(*results)
455
456
class _FetchHandler(object):
  """Handler for structured fetches.

  Given a graph, a user-provided structure for fetches, and a feed dict, this
  class takes care of generating a list of tensor names to fetch and op names
  to run for a low level `run()` call.

  Given the results of the low level run call, this class can also rebuild a
  result structure matching the user-provided structure for fetches, but
  containing the corresponding results.
  """

  # TODO(touts): Make this class also take care of destructuring the feed
  # dict instead of doing it in the callers.

  def __init__(self, graph, fetches, feeds, feed_handles=None):
    """Creates a fetch handler.

    Args:
      graph: Graph of the fetches.   Used to check for fetchability and to
        convert all fetches to tensors or ops as needed.
      fetches: An arbitrary fetch structure: singleton, list, tuple, namedtuple,
        or dict.
      feeds: A feed dict where keys are Tensors.
      feed_handles: A dict from feed Tensors to TensorHandle objects used as
        direct feeds.
    """
    # The mapper resolves fetches via `as_graph_element`, which consults the
    # default graph, so `graph` must be made default while building it.
    with graph.as_default():
      self._fetch_mapper = _FetchMapper.for_fetch(fetches)
    self._fetches = []
    self._targets = []
    self._feeds = feeds
    self._feed_handles = feed_handles or {}
    # _ops[k] is True when the k-th unique fetch is an Operation (a target
    # to run) and False when it is a Tensor (a value to fetch); build_results
    # relies on this parallel structure.
    self._ops = []
    self._fetch_handles = {}
    for fetch in self._fetch_mapper.unique_fetches():
      if isinstance(fetch, ops.Operation):
        self._assert_fetchable(graph, fetch)
        self._targets.append(fetch)
        self._ops.append(True)
      else:
        self._assert_fetchable(graph, fetch.op)
        self._fetches.append(fetch)
        self._ops.append(False)
      # Remember the fetch if it is for a tensor handle.
      if (isinstance(fetch, ops.Tensor) and
          (fetch.op.type == 'GetSessionHandle' or
           fetch.op.type == 'GetSessionHandleV2')):
        self._fetch_handles[fetch.ref()] = fetch.op.inputs[0].dtype
    # Tensors that are also directly fed need not be fetched from the
    # runtime; build_results() reuses their fed values instead.
    self._final_fetches = [x for x in self._fetches if x.ref() not in feeds]

  def _assert_fetchable(self, graph, op):
    # Raises InaccessibleTensorError if `op` cannot be fetched from `graph`.
    if not graph.is_fetchable(op):
      raise errors.InaccessibleTensorError(
          f'Operation {op.name} has been marked as not fetchable. Typically '
          'this happens when it is defined in another function or code block. '
          'Use return values, explicit Python locals or TensorFlow collections '
          'to access it.')

  def fetches(self):
    """Return the unique tensors to fetch from the runtime.

    Returns:
      A list of `Tensor`s (the fetches that are not directly fed).
    """
    return self._final_fetches

  def targets(self):
    """Return the unique ops to run.

    Returns:
      A list of `Operation`s.
    """
    return self._targets

  def build_results(self, session, tensor_values):
    """Build results matching the original fetch shape.

    `tensor_values` must be a list of the same length as
    the one returned by `fetches()`, and holding the requested
    fetch values.

    This method builds a struct with the same shape as the original `fetches`
    passed to the constructor, in which the fetches are replaced by their
    fetched value.

    Args:
      session: The enclosing session.  Used for tensor handles.
      tensor_values: List of values matching the list returned by fetches().

    Returns:
      A structure of the same shape as the original `fetches` argument but
        containing tensors or None (for fetched ops).
    """
    full_values = []
    assert len(self._final_fetches) == len(tensor_values)
    # `i` walks self._fetches (every Tensor fetch, in order); `j` walks
    # tensor_values (only the fetches that were actually run, i.e. those in
    # self._final_fetches).
    i = 0
    j = 0
    for is_op in self._ops:
      if is_op:
        # Operations produce no value.
        full_values.append(None)
      else:
        # If the fetch was in the feeds, use the fed value, otherwise
        # use the returned value.
        if self._fetches[i].ref() in self._feed_handles:
          # A fetch had a corresponding direct TensorHandle feed. Call eval()
          # to obtain the Tensor value from the TensorHandle.
          value = self._feed_handles[self._fetches[i].ref()].eval()
        else:
          value = self._feeds.get(self._fetches[i].ref())
        if value is None:
          value = tensor_values[j]
          j += 1
        dtype = self._fetch_handles.get(self._fetches[i].ref())
        if dtype:
          # The fetch is a session-handle op; wrap the raw value so the
          # caller receives a TensorHandle object.
          full_values.append(session_ops.TensorHandle(value, dtype, session))
        else:
          full_values.append(value)
        i += 1
    assert j == len(tensor_values)
    return self._fetch_mapper.build_results(full_values)
578
579
def _name_list(tensor_list):
  """Utility function for transitioning to the new session API.

  Args:
    tensor_list: a list of `Tensor`s.

  Returns:
    A list containing each `Tensor`'s name encoded as bytes.
  """
  return [compat.as_bytes(tensor.name) for tensor in tensor_list]
590
591
592class _DeviceAttributes(object):
593  """Struct-like object describing a device's attributes.
594
595  Each device has 3 key properties:
596   - name: the fully-qualified TensorFlow path to the device. For
597        example: /job:worker/replica:0/task:3/device:CPU:0
598   - device_type: the type of the device (e.g. CPU, GPU, TPU, etc.)
599   - memory_limit_bytes: the maximum amount of memory available on the device
600        (in bytes).
601  """
602
603  def __init__(self, name, device_type, memory_limit_bytes, incarnation):
604    self._name = device.canonical_name(name)
605    self._device_type = device_type
606    self._memory_limit_bytes = memory_limit_bytes
607    self._incarnation = incarnation
608
609  @property
610  def name(self):
611    return self._name
612
613  @property
614  def device_type(self):
615    return self._device_type
616
617  @property
618  def memory_limit_bytes(self):
619    return self._memory_limit_bytes
620
621  @property
622  def incarnation(self):
623    return self._incarnation
624
625  def __repr__(self):
626    return '_DeviceAttributes(%s, %s, %d, %d)' % (
627        self.name,
628        self.device_type,
629        self.memory_limit_bytes,
630        self.incarnation,
631    )
632
633
class BaseSession(SessionInterface):
  """A class for interacting with a TensorFlow computation.

  The BaseSession enables incremental graph building with inline
  execution of Operations and evaluation of Tensors.  Each instance owns a
  C-level session handle created against its graph (see `__init__`).
  """
640
  def __init__(self, target='', graph=None, config=None):
    """Constructs a new TensorFlow session.

    Args:
      target: (Optional) The TensorFlow execution engine to connect to.
      graph: (Optional) The graph to be used. If this argument is None, the
        default graph will be used.
      config: (Optional) ConfigProto proto used to configure the session. If no
        config is specified, the global default will be used. The global default
        can be configured via the tf.config APIs.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        creating the TensorFlow session.
      TypeError: If one of the arguments has the wrong type.
    """
    _python_session_create_counter.get_cell().increase_by(1)
    if graph is None:
      self._graph = ops.get_default_graph()
    else:
      if not isinstance(graph, ops.Graph):
        raise TypeError('Argument `graph` must be a tf.Graph, but got '
                        f'"{type(graph).__name__}"')
      self._graph = graph

    self._closed = False

    if target is not None:
      try:
        self._target = compat.as_bytes(target)
      except TypeError:
        # A common mistake is passing the ConfigProto positionally
        # ("Session(config)"); give a specific hint for that case.
        if isinstance(target, config_pb2.ConfigProto):
          raise TypeError('Argument `target` must be a string, but got '
                          f'"{type(target).__name__}". Did you do '
                          '"Session(config)" instead of '
                          '"Session(config=config)"?')
        raise TypeError('Argument `target` must be a string, but got '
                        f'"{type(target).__name__}"')
    else:
      self._target = None

    # NOTE(review): _dead_handles appears to buffer tensor handles pending
    # deletion, guarded by _delete_lock; the consuming code is outside this
    # view — confirm before relying on it.
    self._delete_lock = threading.Lock()
    self._dead_handles = []

    if config is None:
      config = context.context().config

    if not isinstance(config, config_pb2.ConfigProto):
      raise TypeError('Argument `config` must be a tf.ConfigProto, but got '
                      f'"{type(config).__name__}"')

    # If the mixed-precision graph rewrite is globally enabled, force it ON in
    # a copy of the config — unless the caller explicitly set it to OFF.
    if (mixed_precision_global_state.is_mixed_precision_graph_rewrite_enabled()
        and config.graph_options.rewrite_options.auto_mixed_precision !=
        rewriter_config_pb2.RewriterConfig.OFF):
      new_config = config_pb2.ConfigProto()
      new_config.CopyFrom(config)
      new_config.graph_options.rewrite_options.auto_mixed_precision = (
          rewriter_config_pb2.RewriterConfig.ON)
      config = new_config
    elif (config.graph_options.rewrite_options.auto_mixed_precision !=
          rewriter_config_pb2.RewriterConfig.ON):
      # Record that a session without the rewrite was created.
      mixed_precision_global_state.set_non_mixed_precision_session_created(True)

    self._config = config
    # Controls whether graph_def attaches inferred output shapes to nodes.
    self._add_shapes = config.graph_options.infer_shapes

    self._session = None
    opts = tf_session.TF_NewSessionOptions(target=self._target, config=config)
    try:
      # pylint: disable=protected-access
      with self._graph._c_graph.get() as c_graph:
        self._session = tf_session.TF_NewSessionRef(c_graph, opts)
      # pylint: enable=protected-access
    finally:
      # Always free the options struct, even if session creation failed.
      tf_session.TF_DeleteSessionOptions(opts)
716
717  def list_devices(self):
718    """Lists available devices in this session.
719
720    ```python
721    devices = sess.list_devices()
722    for d in devices:
723      print(d.name)
724    ```
725
726    Where:
727      Each element in the list has the following properties
728      name: A string with the full name of the device. ex:
729          `/job:worker/replica:0/task:3/device:CPU:0`
730      device_type: The type of the device (e.g. `CPU`, `GPU`, `TPU`.)
731      memory_limit: The maximum amount of memory available on the device.
732          Note: depending on the device, it is possible the usable memory could
733          be substantially less.
734
735    Raises:
736      tf.errors.OpError: If it encounters an error (e.g. session is in an
737      invalid state, or network errors occur).
738
739    Returns:
740      A list of devices in the session.
741    """
742    raw_device_list = tf_session.TF_SessionListDevices(self._session)
743    device_list = []
744    size = tf_session.TF_DeviceListCount(raw_device_list)
745    for i in range(size):
746      name = tf_session.TF_DeviceListName(raw_device_list, i)
747      device_type = tf_session.TF_DeviceListType(raw_device_list, i)
748      memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i)
749      incarnation = tf_session.TF_DeviceListIncarnation(raw_device_list, i)
750      device_list.append(
751          _DeviceAttributes(name, device_type, memory, incarnation))
752    tf_session.TF_DeleteDeviceList(raw_device_list)
753    return device_list
754
755  def close(self):
756    """Closes this session.
757
758    Calling this method frees all resources associated with the session.
759
760    Raises:
761      tf.errors.OpError: Or one of its subclasses if an error occurs while
762        closing the TensorFlow session.
763    """
764    if self._session and not self._closed:
765      self._closed = True
766      tf_session.TF_CloseSession(self._session)
767
  def __del__(self):
    """Releases the underlying C session, tolerating interpreter shutdown."""
    # cleanly ignore all exceptions
    try:
      self.close()
    except Exception:  # pylint: disable=broad-except
      pass
    if self._session is not None:
      try:
        tf_session.TF_DeleteSession(self._session)
      except (AttributeError, TypeError):
        # At shutdown, `c_api_util`, `tf_session`, or
        # `tf_session.TF_DeleteSession` may have been garbage collected, causing
        # the above method calls to fail. In this case, silently leak since the
        # program is about to terminate anyway.
        pass
      # Clear the handle so a repeated finalization cannot double-delete.
      self._session = None
784
  @property
  def graph(self):
    """The graph that was launched in this session."""
    return self._graph

  @property
  def graph_def(self):
    """A serializable version of the underlying TensorFlow graph.

    Returns:
      A graph_pb2.GraphDef proto containing nodes for all of the Operations in
      the underlying TensorFlow graph.
    """
    # `add_shapes` mirrors the session config's `infer_shapes` graph option
    # (captured in __init__).
    return self._graph.as_graph_def(add_shapes=self._add_shapes)

  @property
  def sess_str(self):
    """The TensorFlow process to which this session will connect."""
    return self._target
803
  def as_default(self):
    """Returns a context manager that makes this object the default session.

    Use with the `with` keyword to specify that calls to
    `tf.Operation.run` or `tf.Tensor.eval` should be executed in
    this session.

    ```python
    c = tf.constant(..)
    sess = tf.compat.v1.Session()

    with sess.as_default():
      assert tf.compat.v1.get_default_session() is sess
      print(c.eval())
    ```

    To get the current default session, use `tf.compat.v1.get_default_session`.

    *N.B.* The `as_default` context manager *does not* close the
    session when you exit the context, and you must close the session
    explicitly.

    ```python
    c = tf.constant(...)
    sess = tf.compat.v1.Session()
    with sess.as_default():
      print(c.eval())
    # ...
    with sess.as_default():
      print(c.eval())

    sess.close()
    ```

    Alternatively, you can use `with tf.compat.v1.Session():` to create a
    session that is automatically closed on exiting the context,
    including when an uncaught exception is raised.

    *N.B.* The default session is a property of the current thread. If you
    create a new thread, and wish to use the default session in that
    thread, you must explicitly add a `with sess.as_default():` in that
    thread's function.

    *N.B.* Entering a `with sess.as_default():` block does not affect
    the current default graph. If you are using multiple graphs, and
    `sess.graph` is different from the value of
    `tf.compat.v1.get_default_graph`, you must explicitly enter a
    `with sess.graph.as_default():` block to make `sess.graph` the default
    graph.

    Returns:
      A context manager using this session as the default session.
    """
    # Delegates to the framework's default-session machinery, which
    # implements the semantics documented above.
    return ops.default_session(self)
858
859  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
860    """Runs operations and evaluates tensors in `fetches`.
861
862    This method runs one "step" of TensorFlow computation, by
863    running the necessary graph fragment to execute every `Operation`
864    and evaluate every `Tensor` in `fetches`, substituting the values in
865    `feed_dict` for the corresponding input values.
866
867    The `fetches` argument may be a single graph element, or an arbitrarily
868    nested list, tuple, namedtuple, dict, or OrderedDict containing graph
869    elements at its leaves.  A graph element can be one of the following types:
870
871    * A `tf.Operation`.
872      The corresponding fetched value will be `None`.
873    * A `tf.Tensor`.
874      The corresponding fetched value will be a numpy ndarray containing the
875      value of that tensor.
876    * A `tf.sparse.SparseTensor`.
877      The corresponding fetched value will be a
878      `tf.compat.v1.SparseTensorValue`
879      containing the value of that sparse tensor.
880    * A `get_tensor_handle` op.  The corresponding fetched value will be a
881      numpy ndarray containing the handle of that tensor.
882    * A `string` which is the name of a tensor or operation in the graph.
883
884    The value returned by `run()` has the same shape as the `fetches` argument,
885    where the leaves are replaced by the corresponding values returned by
886    TensorFlow.
887
888    Example:
889
890    ```python
891       a = tf.constant([10, 20])
892       b = tf.constant([1.0, 2.0])
893       # 'fetches' can be a singleton
894       v = session.run(a)
895       # v is the numpy array [10, 20]
896       # 'fetches' can be a list.
897       v = session.run([a, b])
898       # v is a Python list with 2 numpy arrays: the 1-D array [10, 20] and the
899       # 1-D array [1.0, 2.0]
900       # 'fetches' can be arbitrary lists, tuples, namedtuple, dicts:
901       MyData = collections.namedtuple('MyData', ['a', 'b'])
902       v = session.run({'k1': MyData(a, b), 'k2': [b, a]})
903       # v is a dict with
904       # v['k1'] is a MyData namedtuple with 'a' (the numpy array [10, 20]) and
905       # 'b' (the numpy array [1.0, 2.0])
906       # v['k2'] is a list with the numpy array [1.0, 2.0] and the numpy array
907       # [10, 20].
908    ```
909
910    The optional `feed_dict` argument allows the caller to override
911    the value of tensors in the graph. Each key in `feed_dict` can be
912    one of the following types:
913
914    * If the key is a `tf.Tensor`, the
915      value may be a Python scalar, string, list, or numpy ndarray
916      that can be converted to the same `dtype` as that
917      tensor. Additionally, if the key is a
918      `tf.compat.v1.placeholder`, the shape of
919      the value will be checked for compatibility with the placeholder.
920    * If the key is a
921      `tf.sparse.SparseTensor`,
922      the value should be a
923      `tf.compat.v1.SparseTensorValue`.
924    * If the key is a nested tuple of `Tensor`s or `SparseTensor`s, the value
925      should be a nested tuple with the same structure that maps to their
926      corresponding values as above.
927
928    Each value in `feed_dict` must be convertible to a numpy array of the dtype
929    of the corresponding key.
930
931    The optional `options` argument expects a [`RunOptions`] proto. The options
932    allow controlling the behavior of this particular step (e.g. turning tracing
933    on).
934
935    The optional `run_metadata` argument expects a [`RunMetadata`] proto. When
936    appropriate, the non-Tensor output of this step will be collected there. For
937    example, when users turn on tracing in `options`, the profiled info will be
938    collected into this argument and passed back.
939
940    Args:
941      fetches: A single graph element, a list of graph elements, or a dictionary
942        whose values are graph elements or lists of graph elements (described
943        above).
944      feed_dict: A dictionary that maps graph elements to values (described
945        above).
946      options: A [`RunOptions`] protocol buffer
947      run_metadata: A [`RunMetadata`] protocol buffer
948
949    Returns:
950      Either a single value if `fetches` is a single graph element, or
951      a list of values if `fetches` is a list, or a dictionary with the
952      same keys as `fetches` if that is a dictionary (described above).
953      Order in which `fetches` operations are evaluated inside the call
954      is undefined.
955
956    Raises:
957      RuntimeError: If this `Session` is in an invalid state (e.g. has been
958        closed).
959      TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type.
960      ValueError: If `fetches` or `feed_dict` keys are invalid or refer to a
961        `Tensor` that doesn't exist.
962    """
963    options_ptr = tf_session.TF_NewBufferFromString(
964        compat.as_bytes(options.SerializeToString())) if options else None
965    run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None
966
967    try:
968      result = self._run(None, fetches, feed_dict, options_ptr,
969                         run_metadata_ptr)
970      if run_metadata:
971        proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
972        run_metadata.ParseFromString(compat.as_bytes(proto_data))
973    finally:
974      if run_metadata_ptr:
975        tf_session.TF_DeleteBuffer(run_metadata_ptr)
976      if options:
977        tf_session.TF_DeleteBuffer(options_ptr)
978    return result
979
980  def partial_run(self, handle, fetches, feed_dict=None):
981    """Continues the execution with more feeds and fetches.
982
983    This is EXPERIMENTAL and subject to change.
984
985    To use partial execution, a user first calls `partial_run_setup()` and
986    then a sequence of `partial_run()`. `partial_run_setup` specifies the
987    list of feeds and fetches that will be used in the subsequent
988    `partial_run` calls.
989
990    The optional `feed_dict` argument allows the caller to override
991    the value of tensors in the graph. See run() for more information.
992
993    Below is a simple example:
994
995    ```python
996    a = array_ops.placeholder(dtypes.float32, shape=[])
997    b = array_ops.placeholder(dtypes.float32, shape=[])
998    c = array_ops.placeholder(dtypes.float32, shape=[])
999    r1 = math_ops.add(a, b)
1000    r2 = math_ops.multiply(r1, c)
1001
1002    h = sess.partial_run_setup([r1, r2], [a, b, c])
1003    res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
1004    res = sess.partial_run(h, r2, feed_dict={c: res})
1005    ```
1006
1007    Args:
1008      handle: A handle for a sequence of partial runs.
1009      fetches: A single graph element, a list of graph elements, or a dictionary
1010        whose values are graph elements or lists of graph elements (see
1011        documentation for `run`).
1012      feed_dict: A dictionary that maps graph elements to values (described
1013        above).
1014
1015    Returns:
1016      Either a single value if `fetches` is a single graph element, or
1017      a list of values if `fetches` is a list, or a dictionary with the
1018      same keys as `fetches` if that is a dictionary
1019      (see documentation for `run`).
1020
1021    Raises:
1022      tf.errors.OpError: Or one of its subclasses on error.
1023    """
1024    # TODO(touts): Support feeding and fetching the same tensor.
1025    return self._run(handle, fetches, feed_dict, None, None)
1026
1027  def partial_run_setup(self, fetches, feeds=None):
1028    """Sets up a graph with feeds and fetches for partial run.
1029
1030    This is EXPERIMENTAL and subject to change.
1031
1032    Note that contrary to `run`, `feeds` only specifies the graph elements.
1033    The tensors will be supplied by the subsequent `partial_run` calls.
1034
1035    Args:
1036      fetches: A single graph element, or a list of graph elements.
1037      feeds: A single graph element, or a list of graph elements.
1038
1039    Returns:
1040      A handle for partial run.
1041
1042    Raises:
1043      RuntimeError: If this `Session` is in an invalid state (e.g. has been
1044        closed).
1045      TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type.
1046      tf.errors.OpError: Or one of its subclasses if a TensorFlow error happens.
1047    """
1048
1049    def _feed_fn(feed):
1050      for tensor_type, _, _, feed_fn in _REGISTERED_EXPANSIONS:
1051        if isinstance(feed, tensor_type):
1052          return feed_fn(feed)
1053      raise TypeError(f'Feed argument {feed} has invalid type '
1054                      f'"{type(feed).__name__}"')
1055
1056    # Check session.
1057    if self._closed:
1058      raise RuntimeError('Attempted to use a closed Session.')
1059    if self.graph.version == 0:
1060      raise RuntimeError('The Session graph is empty. Add operations to the '
1061                         'graph before calling run().')
1062
1063    if feeds is None:
1064      feeds = []
1065    # Create request.
1066    feed_list = []
1067
1068    # Validate and process feed_list.
1069    is_list_feed = isinstance(feeds, (list, tuple))
1070    if not is_list_feed:
1071      feeds = [feeds]
1072    for feed in feeds:
1073      for subfeed in _feed_fn(feed):
1074        try:
1075          subfeed_t = self.graph.as_graph_element(
1076              subfeed, allow_tensor=True, allow_operation=False)
1077          # pylint: disable=protected-access
1078          feed_list.append(subfeed_t._as_tf_output())
1079          # pylint: enable=protected-access
1080        except Exception as e:
1081          e.message = ('Cannot interpret argument `feed` key as Tensor: '
1082                       f'{e.message}')
1083          e.args = (e.message,)
1084          raise e
1085
1086    # Validate and process fetches.
1087    # TODO(touts): Support feeding and fetching the same tensor.
1088    fetch_handler = _FetchHandler(self._graph, fetches, {})
1089
1090    # Set up a graph with feeds and fetches for partial run.
1091    def _setup_fn(session, feed_list, fetch_list, target_list):
1092      self._extend_graph()
1093      return tf_session.TF_SessionPRunSetup_wrapper(session, feed_list,
1094                                                    fetch_list, target_list)
1095
1096    # pylint: disable=protected-access
1097    final_fetches = [t._as_tf_output() for t in fetch_handler.fetches()]
1098    final_targets = [op._c_op for op in fetch_handler.targets()]
1099    # pylint: enable=protected-access
1100
1101    return self._do_call(_setup_fn, self._session, feed_list, final_fetches,
1102                         final_targets)
1103
  def _run(self, handle, fetches, feed_dict, options, run_metadata):
    """Perform either run or partial_run, depending the presence of `handle`.

    Args:
      handle: A partial-run handle from `partial_run_setup`, or None for a
        regular `run` call.
      fetches: A single graph element or nested structure of graph elements
        (see `run`).
      feed_dict: A dictionary mapping graph elements to values, or None.
      options: A (pointer to a) `RunOptions` protocol buffer, or None.
      run_metadata: A (pointer to a) `RunMetadata` protocol buffer, or None.

    Returns:
      The fetched values, with the same structure as `fetches`.

    Raises:
      RuntimeError: If the session is closed or its graph is empty.
      TypeError: If a feed key or value has an unsupported type.
      ValueError: If a fed value has an incompatible shape or the target
        tensor is not feedable.
    """

    def _feed_fn(feed, feed_val):
      # Expands a (feed, value) pair into its component (subfeed, value)
      # pairs using the registered expansions (e.g. SparseTensor feeds).
      for tensor_type, _, feed_fn, _ in _REGISTERED_EXPANSIONS:
        if isinstance(feed, tensor_type):
          return feed_fn(feed, feed_val)
      raise TypeError(f'{feed} in argument `feed_dict` has invalid type '
                      f'"{type(feed).__name__}"')

    # Check session.
    if self._closed:
      raise RuntimeError('Attempted to use a closed Session.')
    if self.graph.version == 0:
      raise RuntimeError('The Session graph is empty. Add operations to the '
                         'graph before calling run().')

    # Create request.
    feed_dict_tensor = {}
    feed_map = {}

    # Validate and process feed_dict.
    feed_handles = {}
    if feed_dict:
      # Flatten nested-tuple keys so each key is a single feedable element.
      feed_dict = nest.flatten_dict_items(feed_dict)
      for feed, feed_val in feed_dict.items():
        for subfeed, subfeed_val in _feed_fn(feed, feed_val):
          try:
            subfeed_t = self.graph.as_graph_element(
                subfeed, allow_tensor=True, allow_operation=False)
          except Exception as e:
            raise TypeError(
                f'Cannot interpret feed_dict key as Tensor: {e.args[0]}')

          if isinstance(subfeed_val, ops.Tensor):
            raise TypeError(
                'The value of a feed cannot be a tf.Tensor object. Acceptable '
                'feed values include Python scalars, strings, lists, numpy '
                'ndarrays, or TensorHandles. For reference, the tensor object '
                f'was {str(feed_val)} which was passed to the argument '
                f'`feed_dict` with key {str(feed)}.')

          subfeed_dtype = subfeed_t.dtype.as_numpy_dtype
          # Reject Python ints whose value would silently change when cast to
          # the tensor's numpy dtype (e.g. overflow of an int32 feed).
          if isinstance(subfeed_val, int) and _convert_to_numpy_obj(
              subfeed_dtype, subfeed_val) != subfeed_val:
            raise TypeError(
                f'Type of feed value {str(subfeed_val)} with type ' +
                f'{str(type(subfeed_val))} is not compatible with Tensor type '
                f'{str(subfeed_dtype)}. Try explicitly setting the type of the '
                'feed tensor to a larger type (e.g. int64).')

          is_tensor_handle_feed = isinstance(subfeed_val,
                                             session_ops.TensorHandle)
          if is_tensor_handle_feed:
            # TensorHandle feeds carry a handle rather than the tensor data
            # itself, so the shape-compatibility check below is skipped.
            np_val = subfeed_val.to_numpy_array()
            feed_handles[subfeed_t.ref()] = subfeed_val
          else:
            np_val = np.asarray(subfeed_val, dtype=subfeed_dtype)

          if (not is_tensor_handle_feed and
              not subfeed_t.get_shape().is_compatible_with(np_val.shape)):
            raise ValueError(
                f'Cannot feed value of shape {str(np_val.shape)} for Tensor '
                f'{subfeed_t.name}, which has shape '
                f'{str(subfeed_t.get_shape())}')
          if not self.graph.is_feedable(subfeed_t):
            raise ValueError(f'Tensor {subfeed_t.name} may not be fed.')

          feed_dict_tensor[subfeed_t.ref()] = np_val
          feed_map[compat.as_bytes(subfeed_t.name)] = (subfeed_t, subfeed_val)

    # Create a fetch handler to take care of the structure of fetches.
    fetch_handler = _FetchHandler(
        self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)

    # Run request and get response.
    # We need to keep the returned movers alive for the following _do_run().
    # These movers are no longer needed when _do_run() completes, and
    # are deleted when `movers` goes out of scope when this _run() ends.
    # TODO(yuanbyu, keveman): Revisit whether we should just treat feeding
    # of a handle from a different device as an error.
    _ = self._update_with_movers(feed_dict_tensor, feed_map)
    final_fetches = fetch_handler.fetches()
    final_targets = fetch_handler.targets()
    # We only want to really perform the run if fetches or targets are provided,
    # or if the call is a partial run that specifies feeds.
    if final_fetches or final_targets or (handle and feed_dict_tensor):
      results = self._do_run(handle, final_targets, final_fetches,
                             feed_dict_tensor, options, run_metadata)
    else:
      results = []
    return fetch_handler.build_results(self, results)
1196
  def make_callable(self, fetches, feed_list=None, accept_options=False):
    """Returns a Python callable that runs a particular step.

    The returned callable will take `len(feed_list)` arguments whose types
    must be compatible feed values for the respective elements of `feed_list`.
    For example, if element `i` of `feed_list` is a `tf.Tensor`, the `i`th
    argument to the returned callable must be a numpy ndarray (or something
    convertible to an ndarray) with matching element type and shape. See
    `tf.Session.run` for details of the allowable feed key and value types.

    The returned callable will have the same return type as
    `tf.Session.run(fetches, ...)`. For example, if `fetches` is a `tf.Tensor`,
    the callable will return a numpy ndarray; if `fetches` is a `tf.Operation`,
    it will return `None`.

    Args:
      fetches: A value or list of values to fetch. See `tf.Session.run` for
        details of the allowable fetch types.
      feed_list: (Optional.) A list of `feed_dict` keys. See `tf.Session.run`
        for details of the allowable feed key types.
      accept_options: (Optional.) If `True`, the returned `Callable` will be
        able to accept `tf.compat.v1.RunOptions` and `tf.compat.v1.RunMetadata`
        as optional keyword arguments `options` and `run_metadata`,
        respectively, with the same syntax and semantics as `tf.Session.run`,
        which is useful for certain use cases (profiling and debugging) but will
        result in measurable slowdown of the `Callable`'s
        performance. Default: `False`.

    Returns:
      A function that when called will execute the step defined by
      `feed_list` and `fetches` in this session.

    Raises:
      TypeError: If `fetches` or `feed_list` cannot be interpreted
        as arguments to `tf.Session.run`.
    """
    if feed_list is not None:
      if not isinstance(feed_list, (list, tuple)):
        raise TypeError('Argument `feed_list` must be a list or tuple. '
                        f'Received: feed_list={feed_list}')
      # Delegate any non-empty feed lists to the existing `run()` logic.
      # TODO(mrry): Refactor the feed handling logic from
      # `Session._run()` so that we can convert the feeds to a list of
      # strings here.
      def _generic_run(*feed_args, **kwargs):
        # Positional args are matched to feed_list entries by position.
        feed_dict = {
            feed: feed_val for feed, feed_val in zip(feed_list, feed_args)
        }
        return self.run(fetches, feed_dict=feed_dict, **kwargs)

      return _generic_run

    # Ensure any changes to the graph are reflected in the runtime.
    # Note that we don't need to do this on subsequent calls to the
    # returned object, because the arguments to `fetches` must already be
    # in the graph.
    self._extend_graph()

    # Create a fetch handler to take care of the structure of fetches.
    fetch_handler = _FetchHandler(self._graph, fetches, {})
    # pylint: disable=protected-access
    fetch_list = [t._as_tf_output() for t in fetch_handler.fetches()]
    target_list = [op._c_op for op in fetch_handler.targets()]

    # pylint: enable=protected-access

    def _callable_template_with_options_and_metadata(fetch_list,
                                                     target_list,
                                                     fetch_handler,
                                                     options=None,
                                                     run_metadata=None):
      """Template callable that accepts RunOptions and RunMetadata."""
      # Both buffers are allocated before the `try`, so `finally` always
      # sees them bound; they are freed there on every path.
      options_ptr = tf_session.TF_NewBufferFromString(
          compat.as_bytes(options.SerializeToString())) if options else None
      run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None
      try:
        results = self._call_tf_sessionrun(options_ptr, {}, fetch_list,
                                           target_list, run_metadata_ptr)
        if fetch_handler:
          results = fetch_handler.build_results(self, results)
        else:
          # Without a fetch handler there is at most one raw fetched value.
          results = results[0] if results else None
        if run_metadata:
          proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
          run_metadata.ParseFromString(compat.as_bytes(proto_data))
      finally:
        if run_metadata_ptr:
          tf_session.TF_DeleteBuffer(run_metadata_ptr)
        if options:
          tf_session.TF_DeleteBuffer(options_ptr)
      return results

    # Pick the cheapest specialization that can satisfy the fetch structure.
    if accept_options:
      return functools.partial(_callable_template_with_options_and_metadata,
                               fetch_list, target_list, fetch_handler)
    elif isinstance(fetches, ops.Operation):
      # Special case for fetching a single operation, because the
      # function will have no return value.
      assert not fetch_list
      assert len(target_list) == 1

      def _single_operation_run():
        self._call_tf_sessionrun(None, {}, [], target_list, None)

      return _single_operation_run
    elif isinstance(fetches, ops.Tensor):
      # Special case for fetching a single tensor, because the
      # function can return the result of `TF_Run()` directly.
      assert len(fetch_list) == 1
      assert not target_list

      def _single_tensor_run():
        results = self._call_tf_sessionrun(None, {}, fetch_list, [], None)
        return results[0]

      return _single_tensor_run
    else:
      # In all other cases, we must use `fetch_handler` to build the
      # results for us.
      def _fetch_handler_run():
        results = self._call_tf_sessionrun(None, {}, fetch_list, target_list,
                                           None)
        return fetch_handler.build_results(self, results)

      return _fetch_handler_run
1322
  # Captures the name of a node in an error status. The regex below matches
  # both the old and the new formats:
  # Old format: [[Node: <node_name> = ...]]
  # New format: [[{{node <node_name>}} = ...]]
  # Group 3 of a match holds <node_name>; `_do_call` uses it to look up the
  # failing operation for error interpolation.
  _NODEDEF_NAME_RE = re.compile(
      r'\[\[(Node: )?(\{\{node )?([^\} ]*)(\}\})?\s*=*')
1329
1330  def _do_run(self, handle, target_list, fetch_list, feed_dict, options,
1331              run_metadata):
1332    """Runs a step based on the given fetches and feeds.
1333
1334    Args:
1335      handle: a handle for partial_run. None if this is just a call to run().
1336      target_list: A list of operations to be run, but not fetched.
1337      fetch_list: A list of tensors to be fetched.
1338      feed_dict: A dictionary that maps tensors to numpy ndarrays.
1339      options: A (pointer to a) [`RunOptions`] protocol buffer, or None
1340      run_metadata: A (pointer to a) [`RunMetadata`] protocol buffer, or None
1341
1342    Returns:
1343      A list of numpy ndarrays, corresponding to the elements of
1344      `fetch_list`.  If the ith element of `fetch_list` contains the
1345      name of an operation, the first Tensor output of that operation
1346      will be returned for that element.
1347
1348    Raises:
1349      tf.errors.OpError: Or one of its subclasses on error.
1350    """
1351    # pylint: disable=protected-access
1352    feeds = dict((t.deref()._as_tf_output(), v) for t, v in feed_dict.items())
1353    fetches = [t._as_tf_output() for t in fetch_list]
1354    targets = [op._c_op for op in target_list]
1355
1356    # pylint: enable=protected-access
1357
1358    def _run_fn(feed_dict, fetch_list, target_list, options, run_metadata):
1359      # Ensure any changes to the graph are reflected in the runtime.
1360      self._extend_graph()
1361      return self._call_tf_sessionrun(options, feed_dict, fetch_list,
1362                                      target_list, run_metadata)
1363
1364    def _prun_fn(handle, feed_dict, fetch_list):
1365      if target_list:
1366        raise RuntimeError('partial_run() requires empty `target_list`. '
1367                           f'Received: target_list={target_list} (non-empty)')
1368      return self._call_tf_sessionprun(handle, feed_dict, fetch_list)
1369
1370    if handle is None:
1371      return self._do_call(_run_fn, feeds, fetches, targets, options,
1372                           run_metadata)
1373    else:
1374      return self._do_call(_prun_fn, handle, feeds, fetches)
1375
  def _do_call(self, fn, *args):
    """Calls `fn(*args)`, enriching any raised `OpError` with graph context.

    Raises:
      tf.errors.OpError: Re-raised as the same concrete subclass, with the
        failing node's `NodeDef`/`Operation` attached (when it can be found
        in the client graph) and the message interpolated with graph
        information.
    """
    try:
      return fn(*args)
    except errors.OpError as e:
      message = compat.as_text(e.message)
      m = BaseSession._NODEDEF_NAME_RE.search(message)
      node_def = None
      op = None
      if m is not None:
        # Group 3 of _NODEDEF_NAME_RE captures the node name.
        node_name = m.group(3)
        try:
          op = self._graph.get_operation_by_name(node_name)
          node_def = op.node_def
        except KeyError:
          # The named node may not exist in the client graph; re-raise
          # without attaching an op in that case.
          pass
      message = error_interpolation.interpolate(message, self._graph)
      if 'only supports NHWC tensor format' in message:
        message += ('\nA possible workaround: Try disabling Grappler optimizer'
                    '\nby modifying the config for creating the session eg.'
                    '\nsession_config.graph_options.rewrite_options.'
                    'disable_meta_optimizer = True')
      raise type(e)(node_def, op, message)  # pylint: disable=no-value-for-parameter
1398
  def _extend_graph(self):
    """Ensures recent additions to the graph are reflected in the C session."""
    # Hold the graph's session-run lock so the graph is not mutated while the
    # C session is being extended.
    with self._graph._session_run_lock():  # pylint: disable=protected-access
      tf_session.ExtendSession(self._session)
1402
  # The threshold to run garbage collection to delete dead tensors.
  # Once this many dead handles have accumulated in `_register_dead_handle`,
  # a single `run()` deletes them all in one batch.
  _DEAD_HANDLES_THRESHOLD = 10
1405
1406  def _register_dead_handle(self, handle):
1407    # Register a dead handle in the session. Delete the dead tensors when
1408    # the number of dead tensors exceeds certain threshold.
1409    tensors_to_delete = None
1410    with self._delete_lock:
1411      self._dead_handles.append(handle)
1412      if len(self._dead_handles) == BaseSession._DEAD_HANDLES_THRESHOLD:
1413        tensors_to_delete = self._dead_handles
1414        self._dead_handles = []
1415    # Delete the dead tensors.
1416    if tensors_to_delete:
1417      feeds = {}
1418      fetches = []
1419      for deleter_key, tensor_handle in enumerate(tensors_to_delete):
1420        holder, deleter = session_ops._get_handle_deleter(
1421            self.graph, deleter_key, tensor_handle)
1422        feeds[holder] = tensor_handle
1423        fetches.append(deleter)
1424      self.run(fetches, feed_dict=feeds)
1425
1426  def _update_with_movers(self, feed_dict, feed_map):
1427    # If a tensor handle that is fed to a device incompatible placeholder,
1428    # we move the tensor to the right device, generate a new tensor handle,
1429    # and update `feed_dict` to use the new handle.
1430    handle_movers = []
1431    for feed_name, val in feed_map.items():
1432      mover = session_ops._get_handle_mover(self.graph, *val)
1433      if mover:
1434        handle_movers.append((feed_name, val[1], mover))
1435    # Transfer a tensor to the right device if needed.
1436    if not handle_movers:
1437      return []
1438    else:
1439      feeds = {}
1440      fetches = []
1441      for _, handle, mover in handle_movers:
1442        feeds[mover[0]] = handle
1443        fetches.append(mover[1])
1444      handles = self.run(fetches, feed_dict=feeds)
1445      for handle_mover, handle in zip(handle_movers, handles):
1446        np_val = np.array(handle.handle, dtype=np.object_)
1447        feed_name = handle_mover[0]
1448        feed_tensor = feed_map[feed_name][0]
1449        feed_dict[feed_tensor.ref()] = np_val
1450      return handles
1451
  def _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list,
                          run_metadata):
    """Invokes the C API wrapper to run one full step on this session."""
    return tf_session.TF_SessionRun_wrapper(self._session, options, feed_dict,
                                            fetch_list, target_list,
                                            run_metadata)
1457
  def _call_tf_sessionprun(self, handle, feed_dict, fetch_list):
    """Invokes the C API wrapper to continue the partial run for `handle`."""
    return tf_session.TF_SessionPRun_wrapper(self._session, handle, feed_dict,
                                             fetch_list)
1461
1462  # pylint: disable=protected-access
1463  class _Callable(object):
1464    """Experimental wrapper for the C++ `Session::MakeCallable()` API."""
1465
1466    def __init__(self, session, callable_options):
1467      self._session = session
1468      self._handle = None
1469      options_ptr = tf_session.TF_NewBufferFromString(
1470          compat.as_bytes(callable_options.SerializeToString()))
1471      try:
1472        self._handle = tf_session.TF_SessionMakeCallable(
1473            session._session, options_ptr)
1474      finally:
1475        tf_session.TF_DeleteBuffer(options_ptr)
1476
1477    def __call__(self, *args, **kwargs):
1478      run_metadata = kwargs.get('run_metadata', None)
1479      try:
1480        run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None
1481        ret = tf_session.TF_SessionRunCallable(self._session._session,
1482                                               self._handle, args,
1483                                               run_metadata_ptr)
1484        if run_metadata:
1485          proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
1486          run_metadata.ParseFromString(compat.as_bytes(proto_data))
1487      finally:
1488        if run_metadata_ptr:
1489          tf_session.TF_DeleteBuffer(run_metadata_ptr)
1490      return ret
1491
1492    def __del__(self):
1493      # NOTE(mrry): It is possible that `self._session.__del__()` could be
1494      # called before this destructor, in which case `self._session._session`
1495      # will be `None`.
1496      if (self._handle is not None and self._session._session is not None and
1497          not self._session._closed):
1498        tf_session.TF_SessionReleaseCallable(self._session._session,
1499                                             self._handle)
1500
1501  # pylint: enable=protected-access
1502
  def _make_callable_from_options(self, callable_options):
    """Returns a handle to a "callable" with the given options.

    Args:
      callable_options: A `CallableOptions` protocol buffer message describing
        the computation that will be performed by the callable.

    Returns:
      A `BaseSession._Callable` object that can be invoked to run the
      described computation.
    """
    # Flush pending graph changes into the C session before the callable is
    # created, so its fetches/feeds resolve against the up-to-date graph.
    self._extend_graph()
    return BaseSession._Callable(self, callable_options)
1515
1516
@tf_export(v1=['Session'])
class Session(BaseSession):
  """A class for running TensorFlow operations.

  A `Session` object encapsulates the environment in which `Operation`
  objects are executed, and `Tensor` objects are evaluated. For
  example:

  ```python
  tf.compat.v1.disable_eager_execution() # need to disable eager in TF2.x
  # Build a graph.
  a = tf.constant(5.0)
  b = tf.constant(6.0)
  c = a * b

  # Launch the graph in a session.
  sess = tf.compat.v1.Session()

  # Evaluate the tensor `c`.
  print(sess.run(c)) # prints 30.0
  ```

  A session may own resources, such as
  `tf.Variable`, `tf.queue.QueueBase`,
  and `tf.compat.v1.ReaderBase`. It is important to release
  these resources when they are no longer required. To do this, either
  invoke the `tf.Session.close` method on the session, or use
  the session as a context manager. The following two examples are
  equivalent:

  ```python
  # Using the `close()` method.
  sess = tf.compat.v1.Session()
  sess.run(...)
  sess.close()

  # Using the context manager.
  with tf.compat.v1.Session() as sess:
    sess.run(...)
  ```

  The
  [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto)
  protocol buffer exposes various configuration options for a
  session. For example, to create a session that uses soft constraints
  for device placement, and log the resulting placement decisions,
  create a session as follows:

  ```python
  # Launch the graph in a session that allows soft device placement and
  # logs the placement decisions.
  sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(
      allow_soft_placement=True,
      log_device_placement=True))
  ```

  @compatibility(TF2)
  `Session` does not work with either eager execution or `tf.function`, and you
  should not invoke it directly. To migrate code that uses sessions to TF2,
  rewrite the code without it. See the
  [migration
  guide](https://www.tensorflow.org/guide/migrate#1_replace_v1sessionrun_calls)
  on replacing `Session.run` calls.
  @end_compatibility
  """

  def __init__(self, target='', graph=None, config=None):
    """Creates a new TensorFlow session.

    If no `graph` argument is specified when constructing the session,
    the default graph will be launched in the session. If you are
    using more than one graph (created with `tf.Graph()`) in the same
    process, you will have to use different sessions for each graph,
    but each graph can be used in multiple sessions. In this case, it
    is often clearer to pass the graph to be launched explicitly to
    the session constructor.

    Args:
      target: (Optional.) The execution engine to connect to. Defaults to using
        an in-process engine. See
        [Distributed TensorFlow](https://tensorflow.org/deploy/distributed) for
          more examples.
      graph: (Optional.) The `Graph` to be launched (described above).
      config: (Optional.) A
        [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto)
          protocol buffer with configuration options for the session.
    """
    super(Session, self).__init__(target, graph, config=config)
    # NOTE(mrry): Create these on first `__enter__` to avoid a reference cycle.
    self._default_graph_context_manager = None
    self._default_session_context_manager = None

  def __enter__(self):
    """Enters this session (and its graph) as the process default."""
    if self._default_graph_context_manager is None:
      self._default_graph_context_manager = self.graph.as_default()
    else:
      # A non-None graph context manager means `__enter__` was already called
      # without a matching `__exit__`.
      raise RuntimeError('Session context managers are not re-entrant. '
                         'Use `Session.as_default()` if you want to enter '
                         'a session multiple times.')
    if self._default_session_context_manager is None:
      self._default_session_context_manager = self.as_default()
    self._default_graph_context_manager.__enter__()
    return self._default_session_context_manager.__enter__()

  def __exit__(self, exec_type, exec_value, exec_tb):
    """Exits the session context, closing the session.

    Args:
      exec_type: The exception class if an exception is propagating, else None.
      exec_value: The exception instance, else None.
      exec_tb: The exception traceback, else None.
    """
    # Log errors raised by TensorFlow ops. `exec_type` is the exception
    # *class*; use `issubclass` so that the concrete `OpError` subclasses
    # (e.g. `InvalidArgumentError`) that are actually raised are logged too —
    # the previous `exec_type is errors.OpError` identity check could never
    # match them.
    if exec_type is not None and issubclass(exec_type, errors.OpError):
      logging.error('Session closing due to OpError: %s', (exec_value,))
    try:
      self._default_session_context_manager.__exit__(exec_type, exec_value,
                                                     exec_tb)
    except RuntimeError as error:
      # Identity comparison is the intent here: we are checking whether the
      # very same exception object is being re-raised (exceptions compare by
      # identity under the default `__eq__` anyway).
      if error is exec_value:
        # NOTE(skyewm): for some reason, in Python3,
        # _default_session_context_manager.__exit__ will re-raise the "not
        # re-entrant" exception raised in __enter__ above (note that if we're
        # here, we're in the outer session context manager, since __exit__ is
        # not called when __enter__ raises an exception). We still want to
        # continue cleaning up this context manager before the exception is
        # further propagated, so we ignore it here (note that it'll continue
        # being propagated after this method completes).
        pass
      else:
        raise
    self._default_graph_context_manager.__exit__(exec_type, exec_value, exec_tb)

    # Reset so the session can be entered again with a fresh pair of context
    # managers.
    self._default_session_context_manager = None
    self._default_graph_context_manager = None

    # If we are closing due to an exception, set a time limit on our Close() to
    # avoid blocking forever.
    # TODO(b/120204635) remove this when deadlock is fixed.
    if exec_type:
      close_thread = threading.Thread(
          name='SessionCloseThread', target=self.close)
      close_thread.daemon = True
      close_thread.start()
      close_thread.join(30.0)
      if close_thread.is_alive():
        logging.error(
            'Session failed to close after 30 seconds. Continuing after this '
            'point may leave your program in an undefined state.')
    else:
      self.close()

  @staticmethod
  def reset(target, containers=None, config=None):
    """Resets resource containers on `target`, and close all connected sessions.

    A resource container is distributed across all workers in the
    same cluster as `target`.  When a resource container on `target`
    is reset, resources associated with that container will be cleared.
    In particular, all Variables in the container will become undefined:
    they lose their values and shapes.

    NOTE:
    (i) reset() is currently only implemented for distributed sessions.
    (ii) Any sessions on the master named by `target` will be closed.

    If no resource containers are provided, all containers are reset.

    Args:
      target: The execution engine to connect to.
      containers: A list of resource container name strings, or `None` if all of
        all the containers are to be reset.
      config: (Optional.) Protocol buffer with configuration options.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        resetting containers.
    """
    # The C API expects bytes for the target and each container name.
    if target is not None:
      target = compat.as_bytes(target)
    if containers is not None:
      containers = [compat.as_bytes(c) for c in containers]
    else:
      containers = []
    tf_session.TF_Reset(target, containers, config)
1694
1695
@tf_export(v1=['InteractiveSession'])
class InteractiveSession(BaseSession):
  """A TensorFlow `Session` for use in interactive contexts, such as a shell.

  The only difference with a regular `Session` is that an `InteractiveSession`
  installs itself as the default session on construction.
  The methods `tf.Tensor.eval`
  and `tf.Operation.run`
  will use that session to run ops.

  This is convenient in interactive shells and [IPython
  notebooks](http://ipython.org), as it avoids having to pass an explicit
  `Session` object to run ops.

  For example:

  ```python
  sess = tf.compat.v1.InteractiveSession()
  a = tf.constant(5.0)
  b = tf.constant(6.0)
  c = a * b
  # We can just use 'c.eval()' without passing 'sess'
  print(c.eval())
  sess.close()
  ```

  Note that a regular session installs itself as the default session when it
  is created in a `with` statement.  The common usage in non-interactive
  programs is to follow that pattern:

  ```python
  a = tf.constant(5.0)
  b = tf.constant(6.0)
  c = a * b
  with tf.compat.v1.Session():
    # We can also use 'c.eval()' here.
    print(c.eval())
  ```
  """

  # Class-level lock protecting `_active_session_count`, which tracks how many
  # InteractiveSessions are currently open process-wide (used only to warn
  # about potential GPU out-of-memory situations).
  _count_lock = threading.Lock()
  _active_session_count = 0  # GUARDED_BY(_count_lock)

  def __init__(self, target='', graph=None, config=None):
    """Creates a new interactive TensorFlow session.

    If no `graph` argument is specified when constructing the session,
    the default graph will be launched in the session. If you are
    using more than one graph (created with `tf.Graph()`) in the same
    process, you will have to use different sessions for each graph,
    but each graph can be used in multiple sessions. In this case, it
    is often clearer to pass the graph to be launched explicitly to
    the session constructor.

    Args:
      target: (Optional.) The execution engine to connect to. Defaults to using
        an in-process engine.
      graph: (Optional.) The `Graph` to be launched (described above).
      config: (Optional) `ConfigProto` proto used to configure the session.
    """
    if not config:
      # If config is not provided, choose some reasonable defaults for
      # interactive use:
      #
      #   - Grow GPU memory as needed at the cost of fragmentation.
      gpu_options = config_pb2.GPUOptions(allow_growth=True)
      config = config_pb2.ConfigProto(gpu_options=gpu_options)
    # Interactive sessions always place pruned graphs.
    config.graph_options.place_pruned_graph = True

    super(InteractiveSession, self).__init__(target, graph, config)
    # Warn (under the lock, so the count is read consistently) if another
    # InteractiveSession is still open; multiple live sessions can each hold
    # GPU memory.
    with InteractiveSession._count_lock:
      if InteractiveSession._active_session_count > 0:
        warnings.warn('An interactive session is already active. This can '
                      'cause out-of-memory errors in some cases. You must '
                      'explicitly call `InteractiveSession.close()` to release '
                      'resources held by the other session(s).')
      InteractiveSession._active_session_count += 1
    # NOTE(mrry): We do not use `Session._closed` here because it has unhelpful
    # semantics (in particular, it is not set to true if `Session.close()` is
    # called on a session that has not been "opened" by running a step) and we
    # cannot change those semantics without breaking existing code.
    self._explicitly_closed = False

    # Install this session (and, if an explicit graph was given, that graph)
    # as the process defaults for the lifetime of the session. Nesting
    # enforcement is disabled because `close()` may be called at any point in
    # an interactive session, not necessarily in LIFO order with respect to
    # other context managers.
    self._default_session = self.as_default()
    self._default_session.enforce_nesting = False
    self._default_session.__enter__()
    self._explicit_graph = graph
    if self._explicit_graph is not None:
      self._default_graph = graph.as_default()
      self._default_graph.enforce_nesting = False
      self._default_graph.__enter__()

  def close(self):
    """Closes an `InteractiveSession`."""
    super(InteractiveSession, self).close()
    # Decrement the active-session count exactly once per session: the
    # `_explicitly_closed` flag makes repeated `close()` calls idempotent
    # (the early `return` skips the context-manager teardown below, which
    # has already happened on the first call).
    with InteractiveSession._count_lock:
      if not self._explicitly_closed:
        InteractiveSession._active_session_count -= 1
        self._explicitly_closed = True
      else:
        return
    # Tear down the default-graph/default-session contexts entered in
    # `__init__`, outside the lock.
    if self._explicit_graph is not None:
      self._default_graph.__exit__(None, None, None)
      self._default_graph = None
    self._default_session.__exit__(None, None, None)
    self._default_session = None
1803