# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A client interface for TensorFlow."""

import collections
import functools
import re
import threading
import warnings

import numpy as np
import wrapt

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import pywrap_tf_session as tf_session
from tensorflow.python.eager import context
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import device
from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import errors
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import session_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.experimental import mixed_precision_global_state
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export

_python_session_create_counter = monitoring.Counter(
    '/tensorflow/api/python/session_create_counter',
    'Counter for number of sessions created in Python.')


class SessionInterface(object):
  """Base class for implementations of TensorFlow client sessions."""

  @property
  def graph(self):
    """The underlying TensorFlow graph, to be used in building Operations."""
    raise NotImplementedError('graph')

  @property
  def sess_str(self):
    """The TensorFlow process to which this session will connect."""
    raise NotImplementedError('sess_str')

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """Runs operations in the session. See `BaseSession.run()` for details."""
    raise NotImplementedError('run')

  def partial_run_setup(self, fetches, feeds=None):
    """Sets up the feeds and fetches for partial runs in the session."""
    raise NotImplementedError('partial_run_setup')

  def partial_run(self, handle, fetches, feed_dict=None):
    """Continues the execution with additional feeds and fetches."""
    raise NotImplementedError('partial_run')


def _get_indexed_slices_value_from_fetches(fetched_vals):
  """Rebuilds an `IndexedSlicesValue` from its fetched component values.

  Args:
    fetched_vals: A list `[values, indices]` or `[values, indices,
      dense_shape]` of fetched values, in the order used by the
      `IndexedSlices` fetch expansion.

  Returns:
    An `indexed_slices.IndexedSlicesValue`. Its `dense_shape` is `None` unless
    exactly three values were fetched.
  """
  return indexed_slices.IndexedSlicesValue(
      fetched_vals[0], fetched_vals[1],
      fetched_vals[2] if len(fetched_vals) == 3 else None)


def _get_feeds_for_indexed_slices(feed, feed_val):
  """Pairs the component tensors of an `IndexedSlices` feed with fed values.

  Args:
    feed: An `IndexedSlices` used as a feed key.
    feed_val: A sequence of values ordered as `(values, indices)` or
      `(values, indices, dense_shape)`.

  Returns:
    A list of `(tensor, value)` pairs. `feed.dense_shape` is only paired when
    it is not None. As with `zip`, pairing stops at the shorter sequence.
  """
  return list(
      zip([feed.values, feed.indices] if feed.dense_shape is None else
          [feed.values, feed.indices, feed.dense_shape], feed_val))


# List of extensions supported to convert run arguments into actual fetches and
# feeds.
#
# Each element in the list is a tuple of (Type, fetch_fn, feed_fn1, feed_fn2),
# where the function signatures are:
#   fetch_fn : Type -> (list of Tensors,
#                       lambda: list of fetched np.ndarray -> TypeVal)
#   feed_fn1 : Type, TypeVal -> list of (Tensor, value)
#   feed_fn2 : Type -> list of Tensors
#
# `fetch_fn` describes how to expand fetch into its
# component Tensors and how to contract the fetched results back into
# a single return value.
101# 102# Each feed function describes how to unpack a single fed value and map it to 103# feeds of one or more tensors and their corresponding values: `feed_fn1` is 104# used to feed a run, `feed_fn2` to set up a partial run. 105# 106# TODO(touts): We could reimplement these as specialized _FeedMapper 107# implementations after we refactor the feed handling code to use them. 108# 109# Eventually, this registration could be opened up to support custom Tensor 110# expansions. 111# pylint: disable=g-long-lambda 112_REGISTERED_EXPANSIONS = [ 113 # SparseTensors are fetched as SparseTensorValues. They can be fed 114 # SparseTensorValues or normal tuples. 115 (sparse_tensor.SparseTensor, lambda fetch: ([ 116 fetch.indices, fetch.values, fetch.dense_shape 117 ], lambda fetched_vals: sparse_tensor.SparseTensorValue(*fetched_vals)), 118 lambda feed, feed_val: list( 119 zip([feed.indices, feed.values, feed.dense_shape], feed_val)), 120 lambda feed: [feed.indices, feed.values, feed.dense_shape]), 121 # IndexedSlices are fetched as IndexedSlicesValues. They can be fed 122 # IndexedSlicesValues or normal tuples. 123 (indexed_slices.IndexedSlices, 124 lambda fetch: ([fetch.values, fetch.indices] if fetch.dense_shape is None 125 else [fetch.values, fetch.indices, fetch.dense_shape 126 ], _get_indexed_slices_value_from_fetches), 127 _get_feeds_for_indexed_slices, 128 lambda feed: [feed.values, feed.indices] if feed.dense_shape is None else 129 [feed.values, feed.indices, feed.dense_shape]), 130 # The default catches all other types and performs no expansions. 
131 (object, lambda fetch: ([fetch], lambda fetched_vals: fetched_vals[0]), 132 lambda feed, feed_val: [(feed, feed_val)], lambda feed: [feed]) 133] 134 135# pylint: enable=g-long-lambda 136 137 138def _convert_to_numpy_obj(numpy_dtype, obj): 139 """Explicitly convert obj based on numpy type except for string type.""" 140 return numpy_dtype(obj) if numpy_dtype is not object else str(obj) 141 142 143def register_session_run_conversion_functions( 144 tensor_type, 145 fetch_function, 146 feed_function=None, 147 feed_function_for_partial_run=None): 148 """Register fetch and feed conversion functions for `tf.Session.run()`. 149 150 This function registers a triple of conversion functions for fetching and/or 151 feeding values of user-defined types in a call to tf.Session.run(). 152 153 An example 154 155 ```python 156 class SquaredTensor(object): 157 def __init__(self, tensor): 158 self.sq = tf.square(tensor) 159 #you can define conversion functions as follows: 160 fetch_function = lambda squared_tensor:([squared_tensor.sq], 161 lambda val: val[0]) 162 feed_function = lambda feed, feed_val: [(feed.sq, feed_val)] 163 feed_function_for_partial_run = lambda feed: [feed.sq] 164 #then after invoking this register function, you can use as follows: 165 session.run(squared_tensor1, 166 feed_dict = {squared_tensor2 : some_numpy_array}) 167 ``` 168 169 Args: 170 tensor_type: The type for which you want to register a conversion function. 171 fetch_function: A callable that takes an object of type `tensor_type` and 172 returns a tuple, where the first element is a list of `tf.Tensor` objects, 173 and the second element is a callable that takes a list of ndarrays and 174 returns an object of some value type that corresponds to `tensor_type`. 175 fetch_function describes how to expand fetch into its component Tensors 176 and how to contract the fetched results back into a single return value. 
177 feed_function: A callable that takes feed_key and feed_value as input, and 178 returns a list of tuples (feed_tensor, feed_val), feed_key must have type 179 `tensor_type`, and feed_tensor must have type `tf.Tensor`. Each feed 180 function describes how to unpack a single fed value and map it to feeds of 181 one or more tensors and their corresponding values. 182 feed_function_for_partial_run: A callable for specifying tensor values to 183 feed when setting up a partial run, which takes a `tensor_type` type 184 object as input, and returns a list of Tensors. 185 186 Raises: 187 ValueError: If `tensor_type` has already been registered. 188 """ 189 for conversion_function in _REGISTERED_EXPANSIONS: 190 if issubclass(conversion_function[0], tensor_type): 191 raise ValueError(f'{tensor_type} has already been registered so ignore ' 192 'it.') 193 194 _REGISTERED_EXPANSIONS.insert(0, (tensor_type, fetch_function, feed_function, 195 feed_function_for_partial_run)) 196 197 198def _is_attrs_instance(obj): 199 """Returns True if the given obj is an instance of attrs-decorated class.""" 200 return getattr(obj.__class__, '__attrs_attrs__', None) is not None 201 202 203def _get_attrs_values(obj): 204 """Returns the list of values from an attrs instance.""" 205 attrs = getattr(obj.__class__, '__attrs_attrs__') 206 return [getattr(obj, a.name) for a in attrs] 207 208 209class _FetchMapper(object): 210 """Definition of the interface provided by fetch mappers. 211 212 Fetch mappers are utility classes used by the _FetchHandler to handle 213 arbitrary structures for the `fetch` argument to `Session.run()`. 214 215 The `fetch` argument can be of various shapes: single tensor or op, list of 216 fetches, tuple of fetches, namedtuple of fetches, or dict of fetches. The 217 structures can be arbitrarily nested. 218 219 The low level run() API only wants a list of tensor or op names. 
The various 220 `_FetchMapper` subclasses below take care of handling the different shapes: 221 uniquifying the fetches, and constructing results with the original shape. 222 """ 223 224 def unique_fetches(self): 225 """Return the list of unique tensors or ops needed by this fetch mapper. 226 227 Returns: 228 A list of tensors or ops. 229 """ 230 raise NotImplementedError( 231 'unique_fetches must be implemented by subclasses') 232 233 def build_results(self, values): 234 """Build results that match the original shape of the fetch. 235 236 Args: 237 values: List of values returned by run(). The values correspond exactly to 238 the list tensors or ops returned by unique_fetches(). 239 240 Returns: 241 A struct of the same shape as the original fetch object handled by 242 this fetch mapper. In the returned struct, the original fetches are 243 replaced by their fetched values. 244 """ 245 raise NotImplementedError('build_results must be implemented by subclasses') 246 247 @staticmethod 248 def for_fetch(fetch): 249 """Creates fetch mapper that handles the structure of `fetch`. 250 251 The default graph must be the one from which we want to fetch values when 252 this function is called. 253 254 Args: 255 fetch: An arbitrary fetch structure: singleton, list, tuple, namedtuple, 256 or dict. 257 258 Returns: 259 An instance of a subclass of `_FetchMapper` that handles the shape. 260 """ 261 if fetch is None: 262 raise TypeError(f'Argument `fetch` = {fetch} has invalid type ' 263 f'"{type(fetch).__name__}". Cannot be None') 264 elif isinstance(fetch, (list, tuple)): 265 # NOTE(touts): This is also the code path for namedtuples. 266 return _ListFetchMapper(fetch) 267 elif isinstance(fetch, collections_abc.Mapping): 268 return _DictFetchMapper(fetch) 269 elif _is_attrs_instance(fetch): 270 return _AttrsFetchMapper(fetch) 271 else: 272 # Look for a handler in the registered expansions. 
273 for tensor_type, fetch_fn, _, _ in _REGISTERED_EXPANSIONS: 274 if isinstance(fetch, tensor_type): 275 fetches, contraction_fn = fetch_fn(fetch) 276 return _ElementFetchMapper(fetches, contraction_fn) 277 # Did not find anything. 278 raise TypeError(f'Argument `fetch` = {fetch} has invalid type ' 279 f'"{type(fetch).__name__}"') 280 281 282class _ElementFetchMapper(_FetchMapper): 283 """Fetch mapper for singleton tensors and ops.""" 284 285 def __init__(self, fetches, contraction_fn): 286 """Creates an _ElementFetchMapper. 287 288 This is the fetch mapper used for leaves in the fetch struct. Because of 289 the expansions mechanism, a leaf can actually fetch more than one tensor. 290 291 Also note that the fetches here can be just strings (tensor or op names) or 292 any other object that the graph knows how to convert to a tensor, such as a 293 Variable. So we have to run each fetch through `as_graph_element()` to get 294 the corresponding tensor or op. 295 296 Args: 297 fetches: List of objects, as returned by a fetch_fn defined in 298 _REGISTERED_EXPANSIONS. 299 contraction_fn: Callable as returned by a fetch_fn. 300 """ 301 self._unique_fetches = [] 302 for fetch in fetches: 303 try: 304 self._unique_fetches.append(ops.get_default_graph().as_graph_element( 305 fetch, allow_tensor=True, allow_operation=True)) 306 except TypeError as e: 307 raise TypeError(f'Argument `fetch` = {fetch} has invalid type ' 308 f'"{type(fetch).__name__}" must be a string or Tensor. ' 309 f'({str(e)})') 310 except ValueError as e: 311 raise ValueError(f'Argument `fetch` = {fetch} cannot be interpreted as ' 312 f'a Tensor. ({str(e)})') 313 except KeyError as e: 314 raise ValueError(f'Argument `fetch` = {fetch} cannot be interpreted as ' 315 f'a Tensor. 
({str(e)})') 316 self._contraction_fn = contraction_fn 317 318 def unique_fetches(self): 319 return self._unique_fetches 320 321 def build_results(self, values): 322 if not values: 323 # 'Operation' case 324 return None 325 else: 326 return self._contraction_fn(values) 327 328 329def _uniquify_fetches(fetch_mappers): 330 """Uniquifies fetches from a list of fetch_mappers. 331 332 This is a utility function used by _ListFetchMapper and _DictFetchMapper. It 333 gathers all the unique fetches from a list of mappers and builds a list 334 containing all of them but without duplicates (unique_fetches). 335 336 It also returns a 2-D list of integers (values_indices) indicating at which 337 index in unique_fetches the fetches of the mappers are located. 338 339 This list is as follows: 340 values_indices[mapper_index][mapper_fetch_index] = unique_fetches_index 341 342 Args: 343 fetch_mappers: list of fetch mappers. 344 345 Returns: 346 A list of fetches. 347 A 2-D list of integers. 348 """ 349 unique_fetches = [] 350 value_indices = [] 351 seen_fetches = {} 352 for m in fetch_mappers: 353 m_value_indices = [] 354 for f in m.unique_fetches(): 355 j = seen_fetches.get(id(f)) 356 if j is None: 357 j = len(seen_fetches) 358 seen_fetches[id(f)] = j 359 unique_fetches.append(f) 360 m_value_indices.append(j) 361 value_indices.append(m_value_indices) 362 return unique_fetches, value_indices 363 364 365class _ListFetchMapper(_FetchMapper): 366 """Fetch mapper for lists, tuples, and namedtuples.""" 367 368 def __init__(self, fetches): 369 """Creates a _ListFetchMapper. 370 371 Args: 372 fetches: List, tuple, or namedtuple of fetches. 
373 """ 374 if isinstance(fetches, wrapt.ObjectProxy): 375 self._fetch_type = type(fetches.__wrapped__) 376 else: 377 self._fetch_type = type(fetches) 378 self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches] 379 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers) 380 381 def unique_fetches(self): 382 return self._unique_fetches 383 384 def build_results(self, values): 385 # Create the list of results for each mapper. 386 results = [] 387 for m, vi in zip(self._mappers, self._value_indices): 388 results.append(m.build_results([values[j] for j in vi])) 389 # Return a value of the original type of the fetches. 390 if issubclass(self._fetch_type, list): 391 return results 392 elif self._fetch_type == tuple: 393 return tuple(results) 394 else: 395 # This is the code path for namedtuple. 396 return self._fetch_type(*results) 397 398 399class _DictFetchMapper(_FetchMapper): 400 """Fetch mapper for dicts.""" 401 402 def __init__(self, fetches): 403 """Creates a _DictFetchMapper. 404 405 Args: 406 fetches: Dict of fetches. 
407 """ 408 self._fetch_type = type(fetches) 409 if isinstance(fetches, collections.defaultdict): 410 self._type_ctor = functools.partial(collections.defaultdict, 411 fetches.default_factory) 412 else: 413 self._type_ctor = self._fetch_type 414 415 self._keys = fetches.keys() 416 self._mappers = [ 417 _FetchMapper.for_fetch(fetch) for fetch in fetches.values() 418 ] 419 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers) 420 421 def unique_fetches(self): 422 return self._unique_fetches 423 424 def build_results(self, values): 425 426 def _generator(): 427 for k, m, vi in zip(self._keys, self._mappers, self._value_indices): 428 yield k, m.build_results([values[j] for j in vi]) 429 430 return self._type_ctor(_generator()) 431 432 433class _AttrsFetchMapper(_FetchMapper): 434 """Fetch mapper for attrs decorated classes.""" 435 436 def __init__(self, fetches): 437 """Creates a _AttrsFetchMapper. 438 439 Args: 440 fetches: An instance of an attrs decorated class. 441 """ 442 values = _get_attrs_values(fetches) 443 self._fetch_type = type(fetches) 444 self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in values] 445 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers) 446 447 def unique_fetches(self): 448 return self._unique_fetches 449 450 def build_results(self, values): 451 results = [] 452 for m, vi in zip(self._mappers, self._value_indices): 453 results.append(m.build_results([values[j] for j in vi])) 454 return self._fetch_type(*results) 455 456 457class _FetchHandler(object): 458 """Handler for structured fetches. 459 460 Given a graph, a user-provided structure for fetches, and a feed dict, this 461 class takes care of generating a list of tensor names to fetch and op names 462 to run for a low level `run()` call. 
463 464 Given the results of the low level run call, this class can also rebuild a 465 result structure matching the user-provided structure for fetches, but 466 containing the corresponding results. 467 """ 468 469 # TODO(touts): Make this class also take care of destructuring the feed 470 # dict instead of doing it in the callers. 471 472 def __init__(self, graph, fetches, feeds, feed_handles=None): 473 """Creates a fetch handler. 474 475 Args: 476 graph: Graph of the fetches. Used to check for fetchability and to 477 convert all fetches to tensors or ops as needed. 478 fetches: An arbitrary fetch structure: singleton, list, tuple, namedtuple, 479 or dict. 480 feeds: A feed dict where keys are Tensors. 481 feed_handles: A dict from feed Tensors to TensorHandle objects used as 482 direct feeds. 483 """ 484 with graph.as_default(): 485 self._fetch_mapper = _FetchMapper.for_fetch(fetches) 486 self._fetches = [] 487 self._targets = [] 488 self._feeds = feeds 489 self._feed_handles = feed_handles or {} 490 self._ops = [] 491 self._fetch_handles = {} 492 for fetch in self._fetch_mapper.unique_fetches(): 493 if isinstance(fetch, ops.Operation): 494 self._assert_fetchable(graph, fetch) 495 self._targets.append(fetch) 496 self._ops.append(True) 497 else: 498 self._assert_fetchable(graph, fetch.op) 499 self._fetches.append(fetch) 500 self._ops.append(False) 501 # Remember the fetch if it is for a tensor handle. 502 if (isinstance(fetch, ops.Tensor) and 503 (fetch.op.type == 'GetSessionHandle' or 504 fetch.op.type == 'GetSessionHandleV2')): 505 self._fetch_handles[fetch.ref()] = fetch.op.inputs[0].dtype 506 self._final_fetches = [x for x in self._fetches if x.ref() not in feeds] 507 508 def _assert_fetchable(self, graph, op): 509 if not graph.is_fetchable(op): 510 raise errors.InaccessibleTensorError( 511 f'Operation {op.name} has been marked as not fetchable. Typically ' 512 'this happens when it is defined in another function or code block. 
' 513 'Use return values, explicit Python locals or TensorFlow collections ' 514 'to access it.') 515 516 def fetches(self): 517 """Return the unique names of tensors to fetch. 518 519 Returns: 520 A list of strings. 521 """ 522 return self._final_fetches 523 524 def targets(self): 525 """Return the unique names of ops to run. 526 527 Returns: 528 A list of strings. 529 """ 530 return self._targets 531 532 def build_results(self, session, tensor_values): 533 """Build results matching the original fetch shape. 534 535 `tensor_values` must be a list of the same length as 536 the one returned by `fetches()`, and holding the requested 537 fetch values. 538 539 This method builds a struct with the same shape as the original `fetches` 540 passed to the constructor, in which the fetches are replaced by their 541 fetched value. 542 543 Args: 544 session: The enclosing session. Used for tensor handles. 545 tensor_values: List of values matching the list returned by fetches(). 546 547 Returns: 548 A structure of the same shape as the original `fetches` argument but 549 containing tensors or None (for fetched ops). 550 """ 551 full_values = [] 552 assert len(self._final_fetches) == len(tensor_values) 553 i = 0 554 j = 0 555 for is_op in self._ops: 556 if is_op: 557 full_values.append(None) 558 else: 559 # If the fetch was in the feeds, use the fed value, otherwise 560 # use the returned value. 561 if self._fetches[i].ref() in self._feed_handles: 562 # A fetch had a corresponding direct TensorHandle feed. Call eval() 563 # to obtain the Tensor value from the TensorHandle. 
564 value = self._feed_handles[self._fetches[i].ref()].eval() 565 else: 566 value = self._feeds.get(self._fetches[i].ref()) 567 if value is None: 568 value = tensor_values[j] 569 j += 1 570 dtype = self._fetch_handles.get(self._fetches[i].ref()) 571 if dtype: 572 full_values.append(session_ops.TensorHandle(value, dtype, session)) 573 else: 574 full_values.append(value) 575 i += 1 576 assert j == len(tensor_values) 577 return self._fetch_mapper.build_results(full_values) 578 579 580def _name_list(tensor_list): 581 """Utility function for transitioning to the new session API. 582 583 Args: 584 tensor_list: a list of `Tensor`s. 585 586 Returns: 587 A list of each `Tensor`s name (as byte arrays). 588 """ 589 return [compat.as_bytes(t.name) for t in tensor_list] 590 591 592class _DeviceAttributes(object): 593 """Struct-like object describing a device's attributes. 594 595 Each device has 3 key properties: 596 - name: the fully-qualified TensorFlow path to the device. For 597 example: /job:worker/replica:0/task:3/device:CPU:0 598 - device_type: the type of the device (e.g. CPU, GPU, TPU, etc.) 599 - memory_limit_bytes: the maximum amount of memory available on the device 600 (in bytes). 
601 """ 602 603 def __init__(self, name, device_type, memory_limit_bytes, incarnation): 604 self._name = device.canonical_name(name) 605 self._device_type = device_type 606 self._memory_limit_bytes = memory_limit_bytes 607 self._incarnation = incarnation 608 609 @property 610 def name(self): 611 return self._name 612 613 @property 614 def device_type(self): 615 return self._device_type 616 617 @property 618 def memory_limit_bytes(self): 619 return self._memory_limit_bytes 620 621 @property 622 def incarnation(self): 623 return self._incarnation 624 625 def __repr__(self): 626 return '_DeviceAttributes(%s, %s, %d, %d)' % ( 627 self.name, 628 self.device_type, 629 self.memory_limit_bytes, 630 self.incarnation, 631 ) 632 633 634class BaseSession(SessionInterface): 635 """A class for interacting with a TensorFlow computation. 636 637 The BaseSession enables incremental graph building with inline 638 execution of Operations and evaluation of Tensors. 639 """ 640 641 def __init__(self, target='', graph=None, config=None): 642 """Constructs a new TensorFlow session. 643 644 Args: 645 target: (Optional) The TensorFlow execution engine to connect to. 646 graph: (Optional) The graph to be used. If this argument is None, the 647 default graph will be used. 648 config: (Optional) ConfigProto proto used to configure the session. If no 649 config is specified, the global default will be used. The global default 650 can be configured via the tf.config APIs. 651 652 Raises: 653 tf.errors.OpError: Or one of its subclasses if an error occurs while 654 creating the TensorFlow session. 655 TypeError: If one of the arguments has the wrong type. 
656 """ 657 _python_session_create_counter.get_cell().increase_by(1) 658 if graph is None: 659 self._graph = ops.get_default_graph() 660 else: 661 if not isinstance(graph, ops.Graph): 662 raise TypeError('Argument `graph` must be a tf.Graph, but got ' 663 f'"{type(graph).__name__}"') 664 self._graph = graph 665 666 self._closed = False 667 668 if target is not None: 669 try: 670 self._target = compat.as_bytes(target) 671 except TypeError: 672 if isinstance(target, config_pb2.ConfigProto): 673 raise TypeError('Argument `target` must be a string, but got ' 674 f'"{type(target).__name__}". Did you do ' 675 '"Session(config)" instead of ' 676 '"Session(config=config)"?') 677 raise TypeError('Argument `target` must be a string, but got ' 678 f'"{type(target).__name__}"') 679 else: 680 self._target = None 681 682 self._delete_lock = threading.Lock() 683 self._dead_handles = [] 684 685 if config is None: 686 config = context.context().config 687 688 if not isinstance(config, config_pb2.ConfigProto): 689 raise TypeError('Argument `config` must be a tf.ConfigProto, but got ' 690 f'"{type(config).__name__}"') 691 692 if (mixed_precision_global_state.is_mixed_precision_graph_rewrite_enabled() 693 and config.graph_options.rewrite_options.auto_mixed_precision != 694 rewriter_config_pb2.RewriterConfig.OFF): 695 new_config = config_pb2.ConfigProto() 696 new_config.CopyFrom(config) 697 new_config.graph_options.rewrite_options.auto_mixed_precision = ( 698 rewriter_config_pb2.RewriterConfig.ON) 699 config = new_config 700 elif (config.graph_options.rewrite_options.auto_mixed_precision != 701 rewriter_config_pb2.RewriterConfig.ON): 702 mixed_precision_global_state.set_non_mixed_precision_session_created(True) 703 704 self._config = config 705 self._add_shapes = config.graph_options.infer_shapes 706 707 self._session = None 708 opts = tf_session.TF_NewSessionOptions(target=self._target, config=config) 709 try: 710 # pylint: disable=protected-access 711 with self._graph._c_graph.get() 
as c_graph: 712 self._session = tf_session.TF_NewSessionRef(c_graph, opts) 713 # pylint: enable=protected-access 714 finally: 715 tf_session.TF_DeleteSessionOptions(opts) 716 717 def list_devices(self): 718 """Lists available devices in this session. 719 720 ```python 721 devices = sess.list_devices() 722 for d in devices: 723 print(d.name) 724 ``` 725 726 Where: 727 Each element in the list has the following properties 728 name: A string with the full name of the device. ex: 729 `/job:worker/replica:0/task:3/device:CPU:0` 730 device_type: The type of the device (e.g. `CPU`, `GPU`, `TPU`.) 731 memory_limit: The maximum amount of memory available on the device. 732 Note: depending on the device, it is possible the usable memory could 733 be substantially less. 734 735 Raises: 736 tf.errors.OpError: If it encounters an error (e.g. session is in an 737 invalid state, or network errors occur). 738 739 Returns: 740 A list of devices in the session. 741 """ 742 raw_device_list = tf_session.TF_SessionListDevices(self._session) 743 device_list = [] 744 size = tf_session.TF_DeviceListCount(raw_device_list) 745 for i in range(size): 746 name = tf_session.TF_DeviceListName(raw_device_list, i) 747 device_type = tf_session.TF_DeviceListType(raw_device_list, i) 748 memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i) 749 incarnation = tf_session.TF_DeviceListIncarnation(raw_device_list, i) 750 device_list.append( 751 _DeviceAttributes(name, device_type, memory, incarnation)) 752 tf_session.TF_DeleteDeviceList(raw_device_list) 753 return device_list 754 755 def close(self): 756 """Closes this session. 757 758 Calling this method frees all resources associated with the session. 759 760 Raises: 761 tf.errors.OpError: Or one of its subclasses if an error occurs while 762 closing the TensorFlow session. 
763 """ 764 if self._session and not self._closed: 765 self._closed = True 766 tf_session.TF_CloseSession(self._session) 767 768 def __del__(self): 769 # cleanly ignore all exceptions 770 try: 771 self.close() 772 except Exception: # pylint: disable=broad-except 773 pass 774 if self._session is not None: 775 try: 776 tf_session.TF_DeleteSession(self._session) 777 except (AttributeError, TypeError): 778 # At shutdown, `c_api_util`, `tf_session`, or 779 # `tf_session.TF_DeleteSession` may have been garbage collected, causing 780 # the above method calls to fail. In this case, silently leak since the 781 # program is about to terminate anyway. 782 pass 783 self._session = None 784 785 @property 786 def graph(self): 787 """The graph that was launched in this session.""" 788 return self._graph 789 790 @property 791 def graph_def(self): 792 """A serializable version of the underlying TensorFlow graph. 793 794 Returns: 795 A graph_pb2.GraphDef proto containing nodes for all of the Operations in 796 the underlying TensorFlow graph. 797 """ 798 return self._graph.as_graph_def(add_shapes=self._add_shapes) 799 800 @property 801 def sess_str(self): 802 return self._target 803 804 def as_default(self): 805 """Returns a context manager that makes this object the default session. 806 807 Use with the `with` keyword to specify that calls to 808 `tf.Operation.run` or `tf.Tensor.eval` should be executed in 809 this session. 810 811 ```python 812 c = tf.constant(..) 813 sess = tf.compat.v1.Session() 814 815 with sess.as_default(): 816 assert tf.compat.v1.get_default_session() is sess 817 print(c.eval()) 818 ``` 819 820 To get the current default session, use `tf.compat.v1.get_default_session`. 821 822 *N.B.* The `as_default` context manager *does not* close the 823 session when you exit the context, and you must close the session 824 explicitly. 825 826 ```python 827 c = tf.constant(...) 828 sess = tf.compat.v1.Session() 829 with sess.as_default(): 830 print(c.eval()) 831 # ... 
832 with sess.as_default(): 833 print(c.eval()) 834 835 sess.close() 836 ``` 837 838 Alternatively, you can use `with tf.compat.v1.Session():` to create a 839 session that is automatically closed on exiting the context, 840 including when an uncaught exception is raised. 841 842 *N.B.* The default session is a property of the current thread. If you 843 create a new thread, and wish to use the default session in that 844 thread, you must explicitly add a `with sess.as_default():` in that 845 thread's function. 846 847 *N.B.* Entering a `with sess.as_default():` block does not affect 848 the current default graph. If you are using multiple graphs, and 849 `sess.graph` is different from the value of 850 `tf.compat.v1.get_default_graph`, you must explicitly enter a 851 `with sess.graph.as_default():` block to make `sess.graph` the default 852 graph. 853 854 Returns: 855 A context manager using this session as the default session. 856 """ 857 return ops.default_session(self) 858 859 def run(self, fetches, feed_dict=None, options=None, run_metadata=None): 860 """Runs operations and evaluates tensors in `fetches`. 861 862 This method runs one "step" of TensorFlow computation, by 863 running the necessary graph fragment to execute every `Operation` 864 and evaluate every `Tensor` in `fetches`, substituting the values in 865 `feed_dict` for the corresponding input values. 866 867 The `fetches` argument may be a single graph element, or an arbitrarily 868 nested list, tuple, namedtuple, dict, or OrderedDict containing graph 869 elements at its leaves. A graph element can be one of the following types: 870 871 * A `tf.Operation`. 872 The corresponding fetched value will be `None`. 873 * A `tf.Tensor`. 874 The corresponding fetched value will be a numpy ndarray containing the 875 value of that tensor. 876 * A `tf.sparse.SparseTensor`. 877 The corresponding fetched value will be a 878 `tf.compat.v1.SparseTensorValue` 879 containing the value of that sparse tensor. 
880 * A `get_tensor_handle` op. The corresponding fetched value will be a 881 numpy ndarray containing the handle of that tensor. 882 * A `string` which is the name of a tensor or operation in the graph. 883 884 The value returned by `run()` has the same shape as the `fetches` argument, 885 where the leaves are replaced by the corresponding values returned by 886 TensorFlow. 887 888 Example: 889 890 ```python 891 a = tf.constant([10, 20]) 892 b = tf.constant([1.0, 2.0]) 893 # 'fetches' can be a singleton 894 v = session.run(a) 895 # v is the numpy array [10, 20] 896 # 'fetches' can be a list. 897 v = session.run([a, b]) 898 # v is a Python list with 2 numpy arrays: the 1-D array [10, 20] and the 899 # 1-D array [1.0, 2.0] 900 # 'fetches' can be arbitrary lists, tuples, namedtuple, dicts: 901 MyData = collections.namedtuple('MyData', ['a', 'b']) 902 v = session.run({'k1': MyData(a, b), 'k2': [b, a]}) 903 # v is a dict with 904 # v['k1'] is a MyData namedtuple with 'a' (the numpy array [10, 20]) and 905 # 'b' (the numpy array [1.0, 2.0]) 906 # v['k2'] is a list with the numpy array [1.0, 2.0] and the numpy array 907 # [10, 20]. 908 ``` 909 910 The optional `feed_dict` argument allows the caller to override 911 the value of tensors in the graph. Each key in `feed_dict` can be 912 one of the following types: 913 914 * If the key is a `tf.Tensor`, the 915 value may be a Python scalar, string, list, or numpy ndarray 916 that can be converted to the same `dtype` as that 917 tensor. Additionally, if the key is a 918 `tf.compat.v1.placeholder`, the shape of 919 the value will be checked for compatibility with the placeholder. 920 * If the key is a 921 `tf.sparse.SparseTensor`, 922 the value should be a 923 `tf.compat.v1.SparseTensorValue`. 924 * If the key is a nested tuple of `Tensor`s or `SparseTensor`s, the value 925 should be a nested tuple with the same structure that maps to their 926 corresponding values as above. 
927 928 Each value in `feed_dict` must be convertible to a numpy array of the dtype 929 of the corresponding key. 930 931 The optional `options` argument expects a [`RunOptions`] proto. The options 932 allow controlling the behavior of this particular step (e.g. turning tracing 933 on). 934 935 The optional `run_metadata` argument expects a [`RunMetadata`] proto. When 936 appropriate, the non-Tensor output of this step will be collected there. For 937 example, when users turn on tracing in `options`, the profiled info will be 938 collected into this argument and passed back. 939 940 Args: 941 fetches: A single graph element, a list of graph elements, or a dictionary 942 whose values are graph elements or lists of graph elements (described 943 above). 944 feed_dict: A dictionary that maps graph elements to values (described 945 above). 946 options: A [`RunOptions`] protocol buffer 947 run_metadata: A [`RunMetadata`] protocol buffer 948 949 Returns: 950 Either a single value if `fetches` is a single graph element, or 951 a list of values if `fetches` is a list, or a dictionary with the 952 same keys as `fetches` if that is a dictionary (described above). 953 Order in which `fetches` operations are evaluated inside the call 954 is undefined. 955 956 Raises: 957 RuntimeError: If this `Session` is in an invalid state (e.g. has been 958 closed). 959 TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type. 960 ValueError: If `fetches` or `feed_dict` keys are invalid or refer to a 961 `Tensor` that doesn't exist. 
962 """ 963 options_ptr = tf_session.TF_NewBufferFromString( 964 compat.as_bytes(options.SerializeToString())) if options else None 965 run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None 966 967 try: 968 result = self._run(None, fetches, feed_dict, options_ptr, 969 run_metadata_ptr) 970 if run_metadata: 971 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr) 972 run_metadata.ParseFromString(compat.as_bytes(proto_data)) 973 finally: 974 if run_metadata_ptr: 975 tf_session.TF_DeleteBuffer(run_metadata_ptr) 976 if options: 977 tf_session.TF_DeleteBuffer(options_ptr) 978 return result 979 980 def partial_run(self, handle, fetches, feed_dict=None): 981 """Continues the execution with more feeds and fetches. 982 983 This is EXPERIMENTAL and subject to change. 984 985 To use partial execution, a user first calls `partial_run_setup()` and 986 then a sequence of `partial_run()`. `partial_run_setup` specifies the 987 list of feeds and fetches that will be used in the subsequent 988 `partial_run` calls. 989 990 The optional `feed_dict` argument allows the caller to override 991 the value of tensors in the graph. See run() for more information. 992 993 Below is a simple example: 994 995 ```python 996 a = array_ops.placeholder(dtypes.float32, shape=[]) 997 b = array_ops.placeholder(dtypes.float32, shape=[]) 998 c = array_ops.placeholder(dtypes.float32, shape=[]) 999 r1 = math_ops.add(a, b) 1000 r2 = math_ops.multiply(r1, c) 1001 1002 h = sess.partial_run_setup([r1, r2], [a, b, c]) 1003 res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2}) 1004 res = sess.partial_run(h, r2, feed_dict={c: res}) 1005 ``` 1006 1007 Args: 1008 handle: A handle for a sequence of partial runs. 1009 fetches: A single graph element, a list of graph elements, or a dictionary 1010 whose values are graph elements or lists of graph elements (see 1011 documentation for `run`). 1012 feed_dict: A dictionary that maps graph elements to values (described 1013 above). 
1014 1015 Returns: 1016 Either a single value if `fetches` is a single graph element, or 1017 a list of values if `fetches` is a list, or a dictionary with the 1018 same keys as `fetches` if that is a dictionary 1019 (see documentation for `run`). 1020 1021 Raises: 1022 tf.errors.OpError: Or one of its subclasses on error. 1023 """ 1024 # TODO(touts): Support feeding and fetching the same tensor. 1025 return self._run(handle, fetches, feed_dict, None, None) 1026 1027 def partial_run_setup(self, fetches, feeds=None): 1028 """Sets up a graph with feeds and fetches for partial run. 1029 1030 This is EXPERIMENTAL and subject to change. 1031 1032 Note that contrary to `run`, `feeds` only specifies the graph elements. 1033 The tensors will be supplied by the subsequent `partial_run` calls. 1034 1035 Args: 1036 fetches: A single graph element, or a list of graph elements. 1037 feeds: A single graph element, or a list of graph elements. 1038 1039 Returns: 1040 A handle for partial run. 1041 1042 Raises: 1043 RuntimeError: If this `Session` is in an invalid state (e.g. has been 1044 closed). 1045 TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type. 1046 tf.errors.OpError: Or one of its subclasses if a TensorFlow error happens. 1047 """ 1048 1049 def _feed_fn(feed): 1050 for tensor_type, _, _, feed_fn in _REGISTERED_EXPANSIONS: 1051 if isinstance(feed, tensor_type): 1052 return feed_fn(feed) 1053 raise TypeError(f'Feed argument {feed} has invalid type ' 1054 f'"{type(feed).__name__}"') 1055 1056 # Check session. 1057 if self._closed: 1058 raise RuntimeError('Attempted to use a closed Session.') 1059 if self.graph.version == 0: 1060 raise RuntimeError('The Session graph is empty. Add operations to the ' 1061 'graph before calling run().') 1062 1063 if feeds is None: 1064 feeds = [] 1065 # Create request. 1066 feed_list = [] 1067 1068 # Validate and process feed_list. 
1069 is_list_feed = isinstance(feeds, (list, tuple)) 1070 if not is_list_feed: 1071 feeds = [feeds] 1072 for feed in feeds: 1073 for subfeed in _feed_fn(feed): 1074 try: 1075 subfeed_t = self.graph.as_graph_element( 1076 subfeed, allow_tensor=True, allow_operation=False) 1077 # pylint: disable=protected-access 1078 feed_list.append(subfeed_t._as_tf_output()) 1079 # pylint: enable=protected-access 1080 except Exception as e: 1081 e.message = ('Cannot interpret argument `feed` key as Tensor: ' 1082 f'{e.message}') 1083 e.args = (e.message,) 1084 raise e 1085 1086 # Validate and process fetches. 1087 # TODO(touts): Support feeding and fetching the same tensor. 1088 fetch_handler = _FetchHandler(self._graph, fetches, {}) 1089 1090 # Set up a graph with feeds and fetches for partial run. 1091 def _setup_fn(session, feed_list, fetch_list, target_list): 1092 self._extend_graph() 1093 return tf_session.TF_SessionPRunSetup_wrapper(session, feed_list, 1094 fetch_list, target_list) 1095 1096 # pylint: disable=protected-access 1097 final_fetches = [t._as_tf_output() for t in fetch_handler.fetches()] 1098 final_targets = [op._c_op for op in fetch_handler.targets()] 1099 # pylint: enable=protected-access 1100 1101 return self._do_call(_setup_fn, self._session, feed_list, final_fetches, 1102 final_targets) 1103 1104 def _run(self, handle, fetches, feed_dict, options, run_metadata): 1105 """Perform either run or partial_run, depending the presence of `handle`.""" 1106 1107 def _feed_fn(feed, feed_val): 1108 for tensor_type, _, feed_fn, _ in _REGISTERED_EXPANSIONS: 1109 if isinstance(feed, tensor_type): 1110 return feed_fn(feed, feed_val) 1111 raise TypeError(f'{feed} in argument `feed_dict` has invalid type ' 1112 f'"{type(feed).__name__}"') 1113 1114 # Check session. 1115 if self._closed: 1116 raise RuntimeError('Attempted to use a closed Session.') 1117 if self.graph.version == 0: 1118 raise RuntimeError('The Session graph is empty. 
Add operations to the ' 1119 'graph before calling run().') 1120 1121 # Create request. 1122 feed_dict_tensor = {} 1123 feed_map = {} 1124 1125 # Validate and process feed_dict. 1126 feed_handles = {} 1127 if feed_dict: 1128 feed_dict = nest.flatten_dict_items(feed_dict) 1129 for feed, feed_val in feed_dict.items(): 1130 for subfeed, subfeed_val in _feed_fn(feed, feed_val): 1131 try: 1132 subfeed_t = self.graph.as_graph_element( 1133 subfeed, allow_tensor=True, allow_operation=False) 1134 except Exception as e: 1135 raise TypeError( 1136 f'Cannot interpret feed_dict key as Tensor: {e.args[0]}') 1137 1138 if isinstance(subfeed_val, ops.Tensor): 1139 raise TypeError( 1140 'The value of a feed cannot be a tf.Tensor object. Acceptable ' 1141 'feed values include Python scalars, strings, lists, numpy ' 1142 'ndarrays, or TensorHandles. For reference, the tensor object ' 1143 f'was {str(feed_val)} which was passed to the argument ' 1144 f'`feed_dict` with key {str(feed)}.') 1145 1146 subfeed_dtype = subfeed_t.dtype.as_numpy_dtype 1147 if isinstance(subfeed_val, int) and _convert_to_numpy_obj( 1148 subfeed_dtype, subfeed_val) != subfeed_val: 1149 raise TypeError( 1150 f'Type of feed value {str(subfeed_val)} with type ' + 1151 f'{str(type(subfeed_val))} is not compatible with Tensor type ' 1152 f'{str(subfeed_dtype)}. Try explicitly setting the type of the ' 1153 'feed tensor to a larger type (e.g. 
int64).') 1154 1155 is_tensor_handle_feed = isinstance(subfeed_val, 1156 session_ops.TensorHandle) 1157 if is_tensor_handle_feed: 1158 np_val = subfeed_val.to_numpy_array() 1159 feed_handles[subfeed_t.ref()] = subfeed_val 1160 else: 1161 np_val = np.asarray(subfeed_val, dtype=subfeed_dtype) 1162 1163 if (not is_tensor_handle_feed and 1164 not subfeed_t.get_shape().is_compatible_with(np_val.shape)): 1165 raise ValueError( 1166 f'Cannot feed value of shape {str(np_val.shape)} for Tensor ' 1167 f'{subfeed_t.name}, which has shape ' 1168 f'{str(subfeed_t.get_shape())}') 1169 if not self.graph.is_feedable(subfeed_t): 1170 raise ValueError(f'Tensor {subfeed_t.name} may not be fed.') 1171 1172 feed_dict_tensor[subfeed_t.ref()] = np_val 1173 feed_map[compat.as_bytes(subfeed_t.name)] = (subfeed_t, subfeed_val) 1174 1175 # Create a fetch handler to take care of the structure of fetches. 1176 fetch_handler = _FetchHandler( 1177 self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles) 1178 1179 # Run request and get response. 1180 # We need to keep the returned movers alive for the following _do_run(). 1181 # These movers are no longer needed when _do_run() completes, and 1182 # are deleted when `movers` goes out of scope when this _run() ends. 1183 # TODO(yuanbyu, keveman): Revisit whether we should just treat feeding 1184 # of a handle from a different device as an error. 1185 _ = self._update_with_movers(feed_dict_tensor, feed_map) 1186 final_fetches = fetch_handler.fetches() 1187 final_targets = fetch_handler.targets() 1188 # We only want to really perform the run if fetches or targets are provided, 1189 # or if the call is a partial run that specifies feeds. 
1190 if final_fetches or final_targets or (handle and feed_dict_tensor): 1191 results = self._do_run(handle, final_targets, final_fetches, 1192 feed_dict_tensor, options, run_metadata) 1193 else: 1194 results = [] 1195 return fetch_handler.build_results(self, results) 1196 1197 def make_callable(self, fetches, feed_list=None, accept_options=False): 1198 """Returns a Python callable that runs a particular step. 1199 1200 The returned callable will take `len(feed_list)` arguments whose types 1201 must be compatible feed values for the respective elements of `feed_list`. 1202 For example, if element `i` of `feed_list` is a `tf.Tensor`, the `i`th 1203 argument to the returned callable must be a numpy ndarray (or something 1204 convertible to an ndarray) with matching element type and shape. See 1205 `tf.Session.run` for details of the allowable feed key and value types. 1206 1207 The returned callable will have the same return type as 1208 `tf.Session.run(fetches, ...)`. For example, if `fetches` is a `tf.Tensor`, 1209 the callable will return a numpy ndarray; if `fetches` is a `tf.Operation`, 1210 it will return `None`. 1211 1212 Args: 1213 fetches: A value or list of values to fetch. See `tf.Session.run` for 1214 details of the allowable fetch types. 1215 feed_list: (Optional.) A list of `feed_dict` keys. See `tf.Session.run` 1216 for details of the allowable feed key types. 1217 accept_options: (Optional.) If `True`, the returned `Callable` will be 1218 able to accept `tf.compat.v1.RunOptions` and `tf.compat.v1.RunMetadata` 1219 as optional keyword arguments `options` and `run_metadata`, 1220 respectively, with the same syntax and semantics as `tf.Session.run`, 1221 which is useful for certain use cases (profiling and debugging) but will 1222 result in measurable slowdown of the `Callable`'s 1223 performance. Default: `False`. 1224 1225 Returns: 1226 A function that when called will execute the step defined by 1227 `feed_list` and `fetches` in this session. 
1228 1229 Raises: 1230 TypeError: If `fetches` or `feed_list` cannot be interpreted 1231 as arguments to `tf.Session.run`. 1232 """ 1233 if feed_list is not None: 1234 if not isinstance(feed_list, (list, tuple)): 1235 raise TypeError('Argument `feed_list` must be a list or tuple. ' 1236 f'Received: feed_list={feed_list}') 1237 # Delegate any non-empty feed lists to the existing `run()` logic. 1238 # TODO(mrry): Refactor the feed handling logic from 1239 # `Session._run()` so that we can convert the feeds to a list of 1240 # strings here. 1241 def _generic_run(*feed_args, **kwargs): 1242 feed_dict = { 1243 feed: feed_val for feed, feed_val in zip(feed_list, feed_args) 1244 } 1245 return self.run(fetches, feed_dict=feed_dict, **kwargs) 1246 1247 return _generic_run 1248 1249 # Ensure any changes to the graph are reflected in the runtime. 1250 # Note that we don't need to do this on subsequent calls to the 1251 # returned object, because the arguments to `fetches` must already be 1252 # in the graph. 1253 self._extend_graph() 1254 1255 # Create a fetch handler to take care of the structure of fetches. 
1256 fetch_handler = _FetchHandler(self._graph, fetches, {}) 1257 # pylint: disable=protected-access 1258 fetch_list = [t._as_tf_output() for t in fetch_handler.fetches()] 1259 target_list = [op._c_op for op in fetch_handler.targets()] 1260 1261 # pylint: enable=protected-access 1262 1263 def _callable_template_with_options_and_metadata(fetch_list, 1264 target_list, 1265 fetch_handler, 1266 options=None, 1267 run_metadata=None): 1268 """Template callable that accepts RunOptions and RunMetadata.""" 1269 options_ptr = tf_session.TF_NewBufferFromString( 1270 compat.as_bytes(options.SerializeToString())) if options else None 1271 run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None 1272 try: 1273 results = self._call_tf_sessionrun(options_ptr, {}, fetch_list, 1274 target_list, run_metadata_ptr) 1275 if fetch_handler: 1276 results = fetch_handler.build_results(self, results) 1277 else: 1278 results = results[0] if results else None 1279 if run_metadata: 1280 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr) 1281 run_metadata.ParseFromString(compat.as_bytes(proto_data)) 1282 finally: 1283 if run_metadata_ptr: 1284 tf_session.TF_DeleteBuffer(run_metadata_ptr) 1285 if options: 1286 tf_session.TF_DeleteBuffer(options_ptr) 1287 return results 1288 1289 if accept_options: 1290 return functools.partial(_callable_template_with_options_and_metadata, 1291 fetch_list, target_list, fetch_handler) 1292 elif isinstance(fetches, ops.Operation): 1293 # Special case for fetching a single operation, because the 1294 # function will have no return value. 1295 assert not fetch_list 1296 assert len(target_list) == 1 1297 1298 def _single_operation_run(): 1299 self._call_tf_sessionrun(None, {}, [], target_list, None) 1300 1301 return _single_operation_run 1302 elif isinstance(fetches, ops.Tensor): 1303 # Special case for fetching a single tensor, because the 1304 # function can return the result of `TF_Run()` directly. 
1305 assert len(fetch_list) == 1 1306 assert not target_list 1307 1308 def _single_tensor_run(): 1309 results = self._call_tf_sessionrun(None, {}, fetch_list, [], None) 1310 return results[0] 1311 1312 return _single_tensor_run 1313 else: 1314 # In all other cases, we must use `fetch_handler` to build the 1315 # results for us. 1316 def _fetch_handler_run(): 1317 results = self._call_tf_sessionrun(None, {}, fetch_list, target_list, 1318 None) 1319 return fetch_handler.build_results(self, results) 1320 1321 return _fetch_handler_run 1322 1323 # Captures the name of a node in an error status. The regex below matches 1324 # both the old and the new formats: 1325 # Old format: [[Node: <node_name> = ...]] 1326 # New format: [[{{node <node_name>}} = ...]] 1327 _NODEDEF_NAME_RE = re.compile( 1328 r'\[\[(Node: )?(\{\{node )?([^\} ]*)(\}\})?\s*=*') 1329 1330 def _do_run(self, handle, target_list, fetch_list, feed_dict, options, 1331 run_metadata): 1332 """Runs a step based on the given fetches and feeds. 1333 1334 Args: 1335 handle: a handle for partial_run. None if this is just a call to run(). 1336 target_list: A list of operations to be run, but not fetched. 1337 fetch_list: A list of tensors to be fetched. 1338 feed_dict: A dictionary that maps tensors to numpy ndarrays. 1339 options: A (pointer to a) [`RunOptions`] protocol buffer, or None 1340 run_metadata: A (pointer to a) [`RunMetadata`] protocol buffer, or None 1341 1342 Returns: 1343 A list of numpy ndarrays, corresponding to the elements of 1344 `fetch_list`. If the ith element of `fetch_list` contains the 1345 name of an operation, the first Tensor output of that operation 1346 will be returned for that element. 1347 1348 Raises: 1349 tf.errors.OpError: Or one of its subclasses on error. 
1350 """ 1351 # pylint: disable=protected-access 1352 feeds = dict((t.deref()._as_tf_output(), v) for t, v in feed_dict.items()) 1353 fetches = [t._as_tf_output() for t in fetch_list] 1354 targets = [op._c_op for op in target_list] 1355 1356 # pylint: enable=protected-access 1357 1358 def _run_fn(feed_dict, fetch_list, target_list, options, run_metadata): 1359 # Ensure any changes to the graph are reflected in the runtime. 1360 self._extend_graph() 1361 return self._call_tf_sessionrun(options, feed_dict, fetch_list, 1362 target_list, run_metadata) 1363 1364 def _prun_fn(handle, feed_dict, fetch_list): 1365 if target_list: 1366 raise RuntimeError('partial_run() requires empty `target_list`. ' 1367 f'Received: target_list={target_list} (non-empty)') 1368 return self._call_tf_sessionprun(handle, feed_dict, fetch_list) 1369 1370 if handle is None: 1371 return self._do_call(_run_fn, feeds, fetches, targets, options, 1372 run_metadata) 1373 else: 1374 return self._do_call(_prun_fn, handle, feeds, fetches) 1375 1376 def _do_call(self, fn, *args): 1377 try: 1378 return fn(*args) 1379 except errors.OpError as e: 1380 message = compat.as_text(e.message) 1381 m = BaseSession._NODEDEF_NAME_RE.search(message) 1382 node_def = None 1383 op = None 1384 if m is not None: 1385 node_name = m.group(3) 1386 try: 1387 op = self._graph.get_operation_by_name(node_name) 1388 node_def = op.node_def 1389 except KeyError: 1390 pass 1391 message = error_interpolation.interpolate(message, self._graph) 1392 if 'only supports NHWC tensor format' in message: 1393 message += ('\nA possible workaround: Try disabling Grappler optimizer' 1394 '\nby modifying the config for creating the session eg.' 1395 '\nsession_config.graph_options.rewrite_options.' 
1396 'disable_meta_optimizer = True') 1397 raise type(e)(node_def, op, message) # pylint: disable=no-value-for-parameter 1398 1399 def _extend_graph(self): 1400 with self._graph._session_run_lock(): # pylint: disable=protected-access 1401 tf_session.ExtendSession(self._session) 1402 1403 # The threshold to run garbage collection to delete dead tensors. 1404 _DEAD_HANDLES_THRESHOLD = 10 1405 1406 def _register_dead_handle(self, handle): 1407 # Register a dead handle in the session. Delete the dead tensors when 1408 # the number of dead tensors exceeds certain threshold. 1409 tensors_to_delete = None 1410 with self._delete_lock: 1411 self._dead_handles.append(handle) 1412 if len(self._dead_handles) == BaseSession._DEAD_HANDLES_THRESHOLD: 1413 tensors_to_delete = self._dead_handles 1414 self._dead_handles = [] 1415 # Delete the dead tensors. 1416 if tensors_to_delete: 1417 feeds = {} 1418 fetches = [] 1419 for deleter_key, tensor_handle in enumerate(tensors_to_delete): 1420 holder, deleter = session_ops._get_handle_deleter( 1421 self.graph, deleter_key, tensor_handle) 1422 feeds[holder] = tensor_handle 1423 fetches.append(deleter) 1424 self.run(fetches, feed_dict=feeds) 1425 1426 def _update_with_movers(self, feed_dict, feed_map): 1427 # If a tensor handle that is fed to a device incompatible placeholder, 1428 # we move the tensor to the right device, generate a new tensor handle, 1429 # and update `feed_dict` to use the new handle. 1430 handle_movers = [] 1431 for feed_name, val in feed_map.items(): 1432 mover = session_ops._get_handle_mover(self.graph, *val) 1433 if mover: 1434 handle_movers.append((feed_name, val[1], mover)) 1435 # Transfer a tensor to the right device if needed. 
1436 if not handle_movers: 1437 return [] 1438 else: 1439 feeds = {} 1440 fetches = [] 1441 for _, handle, mover in handle_movers: 1442 feeds[mover[0]] = handle 1443 fetches.append(mover[1]) 1444 handles = self.run(fetches, feed_dict=feeds) 1445 for handle_mover, handle in zip(handle_movers, handles): 1446 np_val = np.array(handle.handle, dtype=np.object_) 1447 feed_name = handle_mover[0] 1448 feed_tensor = feed_map[feed_name][0] 1449 feed_dict[feed_tensor.ref()] = np_val 1450 return handles 1451 1452 def _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list, 1453 run_metadata): 1454 return tf_session.TF_SessionRun_wrapper(self._session, options, feed_dict, 1455 fetch_list, target_list, 1456 run_metadata) 1457 1458 def _call_tf_sessionprun(self, handle, feed_dict, fetch_list): 1459 return tf_session.TF_SessionPRun_wrapper(self._session, handle, feed_dict, 1460 fetch_list) 1461 1462 # pylint: disable=protected-access 1463 class _Callable(object): 1464 """Experimental wrapper for the C++ `Session::MakeCallable()` API.""" 1465 1466 def __init__(self, session, callable_options): 1467 self._session = session 1468 self._handle = None 1469 options_ptr = tf_session.TF_NewBufferFromString( 1470 compat.as_bytes(callable_options.SerializeToString())) 1471 try: 1472 self._handle = tf_session.TF_SessionMakeCallable( 1473 session._session, options_ptr) 1474 finally: 1475 tf_session.TF_DeleteBuffer(options_ptr) 1476 1477 def __call__(self, *args, **kwargs): 1478 run_metadata = kwargs.get('run_metadata', None) 1479 try: 1480 run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None 1481 ret = tf_session.TF_SessionRunCallable(self._session._session, 1482 self._handle, args, 1483 run_metadata_ptr) 1484 if run_metadata: 1485 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr) 1486 run_metadata.ParseFromString(compat.as_bytes(proto_data)) 1487 finally: 1488 if run_metadata_ptr: 1489 tf_session.TF_DeleteBuffer(run_metadata_ptr) 1490 return ret 1491 1492 
def __del__(self): 1493 # NOTE(mrry): It is possible that `self._session.__del__()` could be 1494 # called before this destructor, in which case `self._session._session` 1495 # will be `None`. 1496 if (self._handle is not None and self._session._session is not None and 1497 not self._session._closed): 1498 tf_session.TF_SessionReleaseCallable(self._session._session, 1499 self._handle) 1500 1501 # pylint: enable=protected-access 1502 1503 def _make_callable_from_options(self, callable_options): 1504 """Returns a handle to a "callable" with the given options. 1505 1506 Args: 1507 callable_options: A `CallableOptions` protocol buffer message describing 1508 the computation that will be performed by the callable. 1509 1510 Returns: 1511 A handle to the new callable. 1512 """ 1513 self._extend_graph() 1514 return BaseSession._Callable(self, callable_options) 1515 1516 1517@tf_export(v1=['Session']) 1518class Session(BaseSession): 1519 """A class for running TensorFlow operations. 1520 1521 A `Session` object encapsulates the environment in which `Operation` 1522 objects are executed, and `Tensor` objects are evaluated. For 1523 example: 1524 1525 ```python 1526 tf.compat.v1.disable_eager_execution() # need to disable eager in TF2.x 1527 # Build a graph. 1528 a = tf.constant(5.0) 1529 b = tf.constant(6.0) 1530 c = a * b 1531 1532 # Launch the graph in a session. 1533 sess = tf.compat.v1.Session() 1534 1535 # Evaluate the tensor `c`. 1536 print(sess.run(c)) # prints 30.0 1537 ``` 1538 1539 A session may own resources, such as 1540 `tf.Variable`, `tf.queue.QueueBase`, 1541 and `tf.compat.v1.ReaderBase`. It is important to release 1542 these resources when they are no longer required. To do this, either 1543 invoke the `tf.Session.close` method on the session, or use 1544 the session as a context manager. The following two examples are 1545 equivalent: 1546 1547 ```python 1548 # Using the `close()` method. 1549 sess = tf.compat.v1.Session() 1550 sess.run(...) 
1551 sess.close() 1552 1553 # Using the context manager. 1554 with tf.compat.v1.Session() as sess: 1555 sess.run(...) 1556 ``` 1557 1558 The 1559 [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto) 1560 protocol buffer exposes various configuration options for a 1561 session. For example, to create a session that uses soft constraints 1562 for device placement, and log the resulting placement decisions, 1563 create a session as follows: 1564 1565 ```python 1566 # Launch the graph in a session that allows soft device placement and 1567 # logs the placement decisions. 1568 sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto( 1569 allow_soft_placement=True, 1570 log_device_placement=True)) 1571 ``` 1572 1573 @compatibility(TF2) 1574 `Session` does not work with either eager execution or `tf.function`, and you 1575 should not invoke it directly. To migrate code that uses sessions to TF2, 1576 rewrite the code without it. See the 1577 [migration 1578 guide](https://www.tensorflow.org/guide/migrate#1_replace_v1sessionrun_calls) 1579 on replacing `Session.run` calls. 1580 @end_compatibility 1581 """ 1582 1583 def __init__(self, target='', graph=None, config=None): 1584 """Creates a new TensorFlow session. 1585 1586 If no `graph` argument is specified when constructing the session, 1587 the default graph will be launched in the session. If you are 1588 using more than one graph (created with `tf.Graph()`) in the same 1589 process, you will have to use different sessions for each graph, 1590 but each graph can be used in multiple sessions. In this case, it 1591 is often clearer to pass the graph to be launched explicitly to 1592 the session constructor. 1593 1594 Args: 1595 target: (Optional.) The execution engine to connect to. Defaults to using 1596 an in-process engine. See 1597 [Distributed TensorFlow](https://tensorflow.org/deploy/distributed) for 1598 more examples. 1599 graph: (Optional.) 
The `Graph` to be launched (described above). 1600 config: (Optional.) A 1601 [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto) 1602 protocol buffer with configuration options for the session. 1603 """ 1604 super(Session, self).__init__(target, graph, config=config) 1605 # NOTE(mrry): Create these on first `__enter__` to avoid a reference cycle. 1606 self._default_graph_context_manager = None 1607 self._default_session_context_manager = None 1608 1609 def __enter__(self): 1610 if self._default_graph_context_manager is None: 1611 self._default_graph_context_manager = self.graph.as_default() 1612 else: 1613 raise RuntimeError('Session context managers are not re-entrant. ' 1614 'Use `Session.as_default()` if you want to enter ' 1615 'a session multiple times.') 1616 if self._default_session_context_manager is None: 1617 self._default_session_context_manager = self.as_default() 1618 self._default_graph_context_manager.__enter__() 1619 return self._default_session_context_manager.__enter__() 1620 1621 def __exit__(self, exec_type, exec_value, exec_tb): 1622 if exec_type is errors.OpError: 1623 logging.error('Session closing due to OpError: %s', (exec_value,)) 1624 try: 1625 self._default_session_context_manager.__exit__(exec_type, exec_value, 1626 exec_tb) 1627 except RuntimeError as error: 1628 if error == exec_value: 1629 # NOTE(skyewm): for some reason, in Python3, 1630 # _default_session_context_manager.__exit__ will re-raise the "not 1631 # re-entrant" exception raised in __enter__ above (note that if we're 1632 # here, we're in the outer session context manager, since __exit__ is 1633 # not called when __enter__ raises an exception). We still want to 1634 # continue cleaning up this context manager before the exception is 1635 # further propagated, so we ignore it here (note that it'll continue 1636 # being propagated after this method completes). 
1637 pass 1638 else: 1639 raise 1640 self._default_graph_context_manager.__exit__(exec_type, exec_value, exec_tb) 1641 1642 self._default_session_context_manager = None 1643 self._default_graph_context_manager = None 1644 1645 # If we are closing due to an exception, set a time limit on our Close() to 1646 # avoid blocking forever. 1647 # TODO(b/120204635) remove this when deadlock is fixed. 1648 if exec_type: 1649 close_thread = threading.Thread( 1650 name='SessionCloseThread', target=self.close) 1651 close_thread.daemon = True 1652 close_thread.start() 1653 close_thread.join(30.0) 1654 if close_thread.is_alive(): 1655 logging.error( 1656 'Session failed to close after 30 seconds. Continuing after this ' 1657 'point may leave your program in an undefined state.') 1658 else: 1659 self.close() 1660 1661 @staticmethod 1662 def reset(target, containers=None, config=None): 1663 """Resets resource containers on `target`, and close all connected sessions. 1664 1665 A resource container is distributed across all workers in the 1666 same cluster as `target`. When a resource container on `target` 1667 is reset, resources associated with that container will be cleared. 1668 In particular, all Variables in the container will become undefined: 1669 they lose their values and shapes. 1670 1671 NOTE: 1672 (i) reset() is currently only implemented for distributed sessions. 1673 (ii) Any sessions on the master named by `target` will be closed. 1674 1675 If no resource containers are provided, all containers are reset. 1676 1677 Args: 1678 target: The execution engine to connect to. 1679 containers: A list of resource container name strings, or `None` if all of 1680 all the containers are to be reset. 1681 config: (Optional.) Protocol buffer with configuration options. 1682 1683 Raises: 1684 tf.errors.OpError: Or one of its subclasses if an error occurs while 1685 resetting containers. 
1686 """ 1687 if target is not None: 1688 target = compat.as_bytes(target) 1689 if containers is not None: 1690 containers = [compat.as_bytes(c) for c in containers] 1691 else: 1692 containers = [] 1693 tf_session.TF_Reset(target, containers, config) 1694 1695 1696@tf_export(v1=['InteractiveSession']) 1697class InteractiveSession(BaseSession): 1698 """A TensorFlow `Session` for use in interactive contexts, such as a shell. 1699 1700 The only difference with a regular `Session` is that an `InteractiveSession` 1701 installs itself as the default session on construction. 1702 The methods `tf.Tensor.eval` 1703 and `tf.Operation.run` 1704 will use that session to run ops. 1705 1706 This is convenient in interactive shells and [IPython 1707 notebooks](http://ipython.org), as it avoids having to pass an explicit 1708 `Session` object to run ops. 1709 1710 For example: 1711 1712 ```python 1713 sess = tf.compat.v1.InteractiveSession() 1714 a = tf.constant(5.0) 1715 b = tf.constant(6.0) 1716 c = a * b 1717 # We can just use 'c.eval()' without passing 'sess' 1718 print(c.eval()) 1719 sess.close() 1720 ``` 1721 1722 Note that a regular session installs itself as the default session when it 1723 is created in a `with` statement. The common usage in non-interactive 1724 programs is to follow that pattern: 1725 1726 ```python 1727 a = tf.constant(5.0) 1728 b = tf.constant(6.0) 1729 c = a * b 1730 with tf.compat.v1.Session(): 1731 # We can also use 'c.eval()' here. 1732 print(c.eval()) 1733 ``` 1734 """ 1735 1736 _count_lock = threading.Lock() 1737 _active_session_count = 0 # GUARDED_BY(_count_lock) 1738 1739 def __init__(self, target='', graph=None, config=None): 1740 """Creates a new interactive TensorFlow session. 1741 1742 If no `graph` argument is specified when constructing the session, 1743 the default graph will be launched in the session. 
If you are 1744 using more than one graph (created with `tf.Graph()`) in the same 1745 process, you will have to use different sessions for each graph, 1746 but each graph can be used in multiple sessions. In this case, it 1747 is often clearer to pass the graph to be launched explicitly to 1748 the session constructor. 1749 1750 Args: 1751 target: (Optional.) The execution engine to connect to. Defaults to using 1752 an in-process engine. 1753 graph: (Optional.) The `Graph` to be launched (described above). 1754 config: (Optional) `ConfigProto` proto used to configure the session. 1755 """ 1756 if not config: 1757 # If config is not provided, choose some reasonable defaults for 1758 # interactive use: 1759 # 1760 # - Grow GPU memory as needed at the cost of fragmentation. 1761 gpu_options = config_pb2.GPUOptions(allow_growth=True) 1762 config = config_pb2.ConfigProto(gpu_options=gpu_options) 1763 # Interactive sessions always place pruned graphs. 1764 config.graph_options.place_pruned_graph = True 1765 1766 super(InteractiveSession, self).__init__(target, graph, config) 1767 with InteractiveSession._count_lock: 1768 if InteractiveSession._active_session_count > 0: 1769 warnings.warn('An interactive session is already active. This can ' 1770 'cause out-of-memory errors in some cases. You must ' 1771 'explicitly call `InteractiveSession.close()` to release ' 1772 'resources held by the other session(s).') 1773 InteractiveSession._active_session_count += 1 1774 # NOTE(mrry): We do not use `Session._closed` here because it has unhelpful 1775 # semantics (in particular, it is not set to true if `Session.close()` is 1776 # called on a session that has not been "opened" by running a step) and we 1777 # cannot change those semantics without breaking existing code. 
1778 self._explicitly_closed = False 1779 1780 self._default_session = self.as_default() 1781 self._default_session.enforce_nesting = False 1782 self._default_session.__enter__() 1783 self._explicit_graph = graph 1784 if self._explicit_graph is not None: 1785 self._default_graph = graph.as_default() 1786 self._default_graph.enforce_nesting = False 1787 self._default_graph.__enter__() 1788 1789 def close(self): 1790 """Closes an `InteractiveSession`.""" 1791 super(InteractiveSession, self).close() 1792 with InteractiveSession._count_lock: 1793 if not self._explicitly_closed: 1794 InteractiveSession._active_session_count -= 1 1795 self._explicitly_closed = True 1796 else: 1797 return 1798 if self._explicit_graph is not None: 1799 self._default_graph.__exit__(None, None, None) 1800 self._default_graph = None 1801 self._default_session.__exit__(None, None, None) 1802 self._default_session = None 1803