1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tensorflow.python.client.session.Session.""" 16import collections 17import os 18import random 19import sys 20import threading 21import time 22import warnings 23 24import numpy as np 25import six 26 27from tensorflow.core.framework import attr_value_pb2 28from tensorflow.core.lib.core import error_codes_pb2 29from tensorflow.core.protobuf import config_pb2 30from tensorflow.python.client import session 31from tensorflow.python.eager import context 32from tensorflow.python.eager import def_function 33from tensorflow.python.framework import config 34from tensorflow.python.framework import constant_op 35from tensorflow.python.framework import device as framework_device_lib 36from tensorflow.python.framework import dtypes 37from tensorflow.python.framework import errors 38from tensorflow.python.framework import function 39from tensorflow.python.framework import importer 40from tensorflow.python.framework import indexed_slices 41from tensorflow.python.framework import ops 42from tensorflow.python.framework import sparse_tensor 43from tensorflow.python.framework import tensor_util 44from tensorflow.python.framework import test_util 45from tensorflow.python.framework import versions 46from tensorflow.python.ops import array_ops 47from tensorflow.python.ops import control_flow_ops 48from tensorflow.python.ops import data_flow_ops 49from tensorflow.python.ops import gen_control_flow_ops 50# Import gradients to resolve circular imports 51from tensorflow.python.ops import gradients # pylint: disable=unused-import 52from tensorflow.python.ops import gradients_impl 53from tensorflow.python.ops import math_ops 54# Import resource_variable_ops for the variables-to-tensor implicit conversion. 55from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import 56from tensorflow.python.ops import state_ops 57from tensorflow.python.ops import variables 58from tensorflow.python.platform import googletest 59from tensorflow.python.training import server_lib 60from tensorflow.python.util import compat 61 62try: 63 import attr # pylint:disable=g-import-not-at-top 64except ImportError: 65 attr = None 66 67try: 68 from frozendict import frozendict # pylint:disable=g-import-not-at-top 69except ImportError: 70 frozendict = dict # pylint:disable=invalid-name 71 72defaultdict = collections.defaultdict # pylint:disable=invalid-name 73 74 75@test_util.with_eager_op_as_function 76class SessionTest(test_util.TensorFlowTestCase): 77 78 def setUp(self): 79 super(SessionTest, self).setUp() 80 warnings.simplefilter('always') 81 82 def testUseExistingGraph(self): 83 with ops.Graph().as_default() as g, ops.device('/cpu:0'): 84 a = constant_op.constant(6.0, shape=[1, 1]) 85 b = constant_op.constant(7.0, shape=[1, 1]) 86 c = math_ops.matmul(a, b, name='matmul') 87 with session.Session(graph=g): 88 result = c.eval() 89 self.assertAllEqual(result, [[42.0]]) 90 91 def testUseDefaultGraph(self): 92 with ops.Graph().as_default(), ops.device('/cpu:0'): 93 a = constant_op.constant(6.0, shape=[1, 1]) 94 b = constant_op.constant(7.0, shape=[1, 1]) 95 c = math_ops.matmul(a, b, name='matmul') 96 with session.Session(): 97 result = c.eval() 98 self.assertAllEqual(result, [[42.0]]) 99 100 def testCreate(self): 101 with session.Session(): 102 inp = constant_op.constant(10.0, shape=[2, 3], name='W1') 103 copy = array_ops.identity(inp) 104 # Test with feed. 105 # TODO(mrry): Investigate why order='F' didn't work. 106 arr = np.asarray([[0, 1, 2], [3, 4, 5]], dtype=np.float32, order='C') 107 copy_val = copy.eval({'W1:0': arr}) 108 self.assertAllEqual(arr, copy_val) 109 # Test without feed. 110 copy_val = copy.eval() 111 self.assertAllEqual( 112 np.asarray( 113 [[10.0, 10.0, 10.0], [10.0, 10.0, 10.0]], dtype=np.float32), 114 copy_val) 115 116 def testManyCPUs(self): 117 with session.Session( 118 config=config_pb2.ConfigProto(device_count={ 119 'CPU': 2, 'GPU': 0 120 })) as sess: 121 inp = constant_op.constant(10.0, name='W1') 122 self.assertAllEqual(inp, 10.0) 123 124 num_cpu_devices = 0 125 num_gpu_devices = 0 126 for device in sess.list_devices(): 127 device_type = framework_device_lib.DeviceSpec.from_string( 128 device.name).device_type 129 if device_type == 'CPU': 130 num_cpu_devices += 1 131 elif device_type == 'GPU': 132 num_gpu_devices += 1 133 self.assertEqual(2, num_cpu_devices) 134 self.assertEqual(0, num_gpu_devices) 135 136 def testPerSessionThreads(self): 137 with session.Session( 138 config=config_pb2.ConfigProto(use_per_session_threads=True)): 139 inp = constant_op.constant(10.0, name='W1') 140 self.assertAllEqual(inp, 10.0) 141 142 def testSessionInterOpThreadPool(self): 143 config_pb = config_pb2.ConfigProto() 144 pool = config_pb.session_inter_op_thread_pool.add() 145 with session.Session(config=config_pb) as s: 146 inp = constant_op.constant(10.0, name='W1') 147 results = s.run([inp]) 148 self.assertAllEqual([10.0], results) 149 150 pool = config_pb.session_inter_op_thread_pool.add() 151 pool.num_threads = 1 152 with session.Session(config=config_pb) as s: 153 inp = constant_op.constant(20.0, name='W2') 154 results = s.run([inp]) 155 self.assertAllEqual([20.0], results) 156 157 pool = config_pb.session_inter_op_thread_pool.add() 158 pool.num_threads = 1 159 pool.global_name = 't1' 160 run_options = config_pb2.RunOptions() 161 run_options.inter_op_thread_pool = ( 162 len(config_pb.session_inter_op_thread_pool) - 1) 163 with session.Session(config=config_pb) as s: 164 inp = constant_op.constant(30.0, name='W2') 165 results = s.run([inp], options=run_options) 166 self.assertAllEqual([30.0], results) 167 168 def testErrorsReported(self): 169 with session.Session() as s: 170 constant_op.constant(10.0, name='W1') 171 with self.assertRaises(ValueError): 172 s.run('foo:0') 173 174 def testErrorPayload(self): 175 with session.Session(): 176 a = array_ops.placeholder(dtypes.float32) 177 with self.assertRaisesOpError(lambda e: e.op == a.op): 178 a.eval() 179 180 def testErrorCodeWithNoNodeDef(self): 181 with session.Session() as s: 182 a = array_ops.placeholder(dtypes.float32, shape=[]) 183 b = array_ops.placeholder(dtypes.float32, shape=[]) 184 r1 = math_ops.add(a, b) 185 186 def exc_predicate(e): 187 return (e.op is None and e.node_def is None and 188 e.error_code == error_codes_pb2.INVALID_ARGUMENT) 189 190 with self.assertRaisesOpError(exc_predicate): 191 # Run with a bogus handle. 192 s.partial_run('foo', r1, feed_dict={a: 1, b: 2}) 193 194 def testErrorBasedOn(self): 195 with session.Session() as sess: 196 a = constant_op.constant(0.0, shape=[2, 3]) 197 # NOTE(mrry): The original_op is nonsense, but used here to test that the 198 # errors are reported correctly. 199 with sess.graph._original_op(a.op): 200 b = array_ops.identity(a, name='id') 201 with sess.graph._original_op(b.op): 202 c = array_ops.placeholder(dtypes.float32) 203 204 def exc_predicate(e): 205 return (e.op == c.op and e.op._original_op == b.op and 206 e.op._original_op._original_op == a.op) 207 208 with self.assertRaisesOpError(exc_predicate): 209 c.eval() 210 211 def testFetchNone(self): 212 with session.Session() as s: 213 a = constant_op.constant(1.0) 214 with self.assertRaises(TypeError): 215 s.run(None) 216 with self.assertRaises(TypeError): 217 s.run([None]) 218 with self.assertRaises(TypeError): 219 s.run({'b': None}) 220 with self.assertRaises(TypeError): 221 s.run({'a': a, 'b': None}) 222 223 def testFetchSingleton(self): 224 with session.Session() as sess: 225 a = constant_op.constant(42.0) 226 res = sess.run(a) 227 self.assertEqual(42.0, res) 228 res = sess.run(a.op) # An op, not a tensor. 229 self.assertIsNone(res) 230 tensor_runner = sess.make_callable(a) 231 res = tensor_runner() 232 self.assertEqual(42.0, res) 233 op_runner = sess.make_callable(a.op) 234 res = op_runner() 235 self.assertIsNone(res) 236 237 def testFetchSingletonByName(self): 238 with session.Session() as sess: 239 a = constant_op.constant(42.0) 240 res = sess.run(a.name) 241 self.assertEqual(42.0, res) 242 res = sess.run(a.op) # An op, not a tensor. 243 self.assertIsNone(res) 244 245 def testFetchList(self): 246 with session.Session() as sess: 247 a = constant_op.constant(42.0) 248 b = control_flow_ops.no_op() # An op, not a tensor. 249 c = constant_op.constant(44.0) 250 v = variables.Variable([54.0]) 251 assign = v.assign([63.0]) 252 res = sess.run([a, b, c, a.name, assign.op]) 253 self.assertIsInstance(res, list) 254 self.assertEqual([42.0, None, 44.0, 42.0, None], res) 255 list_runner = sess.make_callable([a, b, c, a.name, assign.op]) 256 res = list_runner() 257 self.assertIsInstance(res, list) 258 self.assertEqual([42.0, None, 44.0, 42.0, None], res) 259 260 def testFetchTuple(self): 261 with session.Session() as sess: 262 a = constant_op.constant(42.0) 263 b = control_flow_ops.no_op() # An op, not a tensor. 264 c = constant_op.constant(44.0) 265 res = sess.run((a, b, c, a.name)) 266 self.assertIsInstance(res, tuple) 267 self.assertEqual((42.0, None, 44.0, 42.0), res) 268 tuple_runner = sess.make_callable((a, b, c, a.name)) 269 res = tuple_runner() 270 self.assertIsInstance(res, tuple) 271 self.assertEqual((42.0, None, 44.0, 42.0), res) 272 273 def testFetchNamedTuple(self): 274 # pylint: disable=invalid-name 275 ABC = collections.namedtuple('ABC', ['a', 'b', 'c']) 276 # pylint: enable=invalid-name 277 with session.Session() as sess: 278 a = constant_op.constant(42.0) 279 b = control_flow_ops.no_op() # An op, not a tensor. 280 c = constant_op.constant(44.0) 281 res = sess.run(ABC(a, b, c)) 282 self.assertIsInstance(res, ABC) 283 self.assertEqual(42.0, res.a) 284 self.assertIsNone(res.b) 285 self.assertEqual(44.0, res.c) 286 namedtuple_runner = sess.make_callable(ABC(a, b, c)) 287 res = namedtuple_runner() 288 self.assertIsInstance(res, ABC) 289 self.assertEqual(42.0, res.a) 290 self.assertIsNone(res.b) 291 self.assertEqual(44.0, res.c) 292 293 def testFetchDict(self): 294 with session.Session() as sess: 295 a = constant_op.constant(42.0) 296 b = control_flow_ops.no_op() # An op, not a tensor. 297 c = constant_op.constant(44.0) 298 res = sess.run({'a': a, 'b': b, 'c': c}) 299 self.assertIsInstance(res, dict) 300 self.assertEqual(42.0, res['a']) 301 self.assertIsNone(res['b']) 302 self.assertEqual(44.0, res['c']) 303 304 def testFetchOrderedDict(self): 305 with session.Session() as sess: 306 a = constant_op.constant(42.0) 307 b = control_flow_ops.no_op() # An op, not a tensor. 308 c = constant_op.constant(44.0) 309 res = sess.run(collections.OrderedDict([(3, a), (2, b), (1, c)])) 310 self.assertIsInstance(res, collections.OrderedDict) 311 self.assertEqual([3, 2, 1], list(res.keys())) 312 self.assertEqual(42.0, res[3]) 313 self.assertIsNone(res[2]) 314 self.assertEqual(44.0, res[1]) 315 316 @test_util.run_v1_only('b/120545219') 317 def testFetchAttrs(self): 318 if attr is None: 319 self.skipTest('attr module is unavailable.') 320 321 @attr.s 322 class SampleAttr(object): 323 field1 = attr.ib() 324 field2 = attr.ib() 325 326 val1 = np.array([1.2, 3.4, 5.6]) 327 val2 = np.array([[1, 2], [4, 3]]) 328 val3 = np.array([10, 20, 30]) 329 330 t1 = constant_op.constant(val1) 331 t2 = constant_op.constant(val2) 332 333 sample = SampleAttr(t1, t2) 334 with session.Session() as sess: 335 result = sess.run(sample) 336 self.assertIsInstance(result, SampleAttr) 337 self.assertAllEqual(val1, result.field1) 338 self.assertAllEqual(val2, result.field2) 339 340 result = sess.run(sample, feed_dict={sample.field1: val3}) 341 self.assertIsInstance(result, SampleAttr) 342 self.assertAllEqual(val3, result.field1) 343 self.assertAllEqual(val2, result.field2) 344 345 @test_util.run_v1_only('b/120545219') 346 def testFetchNestedAttrs(self): 347 if attr is None: 348 self.skipTest('attr module is unavailable.') 349 350 @attr.s 351 class SampleAttr(object): 352 field0 = attr.ib() 353 field1 = attr.ib() 354 355 v1 = 10 356 v2 = 20 357 v3 = np.float32(1.2) 358 v4 = np.float32(3.4) 359 v5 = np.float64(100.001) 360 v6 = np.float64(-23.451) 361 arr1 = np.array([1.2, 6.7, 3.4]) 362 arr2 = np.array([7, 11, 3]) 363 sample = SampleAttr( 364 SampleAttr( 365 SampleAttr(constant_op.constant(v1), constant_op.constant(v2)), 366 SampleAttr(constant_op.constant(arr1), constant_op.constant(arr2))), 367 {'A': SampleAttr(constant_op.constant(v3), constant_op.constant(v4)), 368 'B': [SampleAttr(constant_op.constant(v5), constant_op.constant(v6))]}) 369 370 with session.Session() as sess: 371 result = sess.run(sample) 372 self.assertIsInstance(result, SampleAttr) 373 self.assertIsInstance(result.field0, SampleAttr) 374 self.assertIsInstance(result.field0.field0, SampleAttr) 375 self.assertIsInstance(result.field0.field1, SampleAttr) 376 self.assertIsInstance(result.field0.field1.field0, np.ndarray) 377 self.assertAllEqual(arr1, result.field0.field1.field0) 378 self.assertIsInstance(result.field0.field1.field1, np.ndarray) 379 self.assertAllEqual(arr2, result.field0.field1.field1) 380 self.assertIsInstance(result.field1, dict) 381 self.assertIn('A', result.field1) 382 self.assertIn('B', result.field1) 383 self.assertIsInstance(result.field1['A'], SampleAttr) 384 self.assertAllEqual( 385 [v3, v4], 386 [result.field1['A'].field0, result.field1['A'].field1]) 387 self.assertIsInstance(result.field1['B'], list) 388 self.assertEqual(1, len(result.field1['B'])) 389 self.assertIsInstance(result.field1['B'][0], SampleAttr) 390 self.assertAllEqual( 391 [v5, v6], 392 [result.field1['B'][0].field0, result.field1['B'][0].field1]) 393 394 def testFetchNestingEmptyOneLevel(self): 395 with session.Session() as sess: 396 a_val = 11.0 397 a = constant_op.constant(a_val) 398 399 res = sess.run([[], tuple(), {}]) 400 self.assertIsInstance(res, list) 401 self.assertEqual(3, len(res)) 402 self.assertIsInstance(res[0], list) 403 self.assertEqual(0, len(res[0])) 404 self.assertIsInstance(res[1], tuple) 405 self.assertEqual(0, len(res[1])) 406 self.assertIsInstance(res[2], dict) 407 self.assertEqual(0, len(res[2])) 408 409 res = sess.run([[], tuple(), {}, a]) 410 self.assertIsInstance(res, list) 411 self.assertEqual(4, len(res)) 412 self.assertIsInstance(res[0], list) 413 self.assertEqual(0, len(res[0])) 414 self.assertIsInstance(res[1], tuple) 415 self.assertEqual(0, len(res[1])) 416 self.assertIsInstance(res[2], dict) 417 self.assertEqual(0, len(res[2])) 418 self.assertEqual(a_val, res[3]) 419 420 def testFetchNestingOneLevel(self): 421 with session.Session() as sess: 422 # pylint: disable=invalid-name 423 ABC = collections.namedtuple('ABC', ['a', 'b', 'c']) 424 DEFGHI = collections.namedtuple('DEFGHI', ['d', 'e', 'f', 'g', 'h', 'i']) 425 # pylint: enable=invalid-name 426 a_val = 42.0 427 b_val = None 428 c_val = 44.0 429 a = constant_op.constant(a_val) 430 b = control_flow_ops.no_op() # An op, not a tensor. 431 c = constant_op.constant(c_val) 432 test_dct = {'a': a.name, 'c': c, 'b': b} 433 test_dct_types = [dict, frozendict, defaultdict] 434 # List of lists, tuples, namedtuple, dict, frozendict, and defaultdict 435 res = sess.run([ 436 [a, b, c], 437 (a, b, c), 438 ABC(a=a, b=b, c=c), 439 dict(test_dct), 440 frozendict(test_dct), 441 defaultdict(str, test_dct), 442 ]) 443 self.assertIsInstance(res, list) 444 self.assertEqual(6, len(res)) 445 self.assertIsInstance(res[0], list) 446 self.assertEqual(3, len(res[0])) 447 self.assertEqual(a_val, res[0][0]) 448 self.assertEqual(b_val, res[0][1]) 449 self.assertEqual(c_val, res[0][2]) 450 self.assertIsInstance(res[1], tuple) 451 self.assertEqual(3, len(res[1])) 452 self.assertEqual(a_val, res[1][0]) 453 self.assertEqual(b_val, res[1][1]) 454 self.assertEqual(c_val, res[1][2]) 455 self.assertIsInstance(res[2], ABC) 456 self.assertEqual(a_val, res[2].a) 457 self.assertEqual(b_val, res[2].b) 458 self.assertEqual(c_val, res[2].c) 459 for expected_type, r in zip(test_dct_types, res[3:]): 460 self.assertIsInstance(r, expected_type) 461 self.assertEqual(3, len(r)) 462 self.assertEqual(a_val, r['a']) 463 self.assertEqual(b_val, r['b']) 464 self.assertEqual(c_val, r['c']) 465 self.assertEqual(res[5].default_factory, str) 466 # Tuple of lists, tuples, namedtuple, dict, frozendict, and defaultdict 467 res = sess.run(([a, b, c], (a.name, b, c), ABC(a=a, b=b, 468 c=c), dict(test_dct), 469 frozendict(test_dct), defaultdict(str, test_dct))) 470 self.assertIsInstance(res, tuple) 471 self.assertEqual(6, len(res)) 472 self.assertIsInstance(res[0], list) 473 self.assertEqual(3, len(res[0])) 474 self.assertEqual(a_val, res[0][0]) 475 self.assertEqual(b_val, res[0][1]) 476 self.assertEqual(c_val, res[0][2]) 477 self.assertIsInstance(res[1], tuple) 478 self.assertEqual(3, len(res[1])) 479 self.assertEqual(a_val, res[1][0]) 480 self.assertEqual(b_val, res[1][1]) 481 self.assertEqual(c_val, res[1][2]) 482 self.assertIsInstance(res[2], ABC) 483 self.assertEqual(a_val, res[2].a) 484 self.assertEqual(b_val, res[2].b) 485 self.assertEqual(c_val, res[2].c) 486 for expected_type, r in zip(test_dct_types, res[3:]): 487 self.assertIsInstance(r, expected_type) 488 self.assertEqual(3, len(r)) 489 self.assertEqual(a_val, r['a']) 490 self.assertEqual(b_val, r['b']) 491 self.assertEqual(c_val, r['c']) 492 self.assertEqual(res[5].default_factory, str) 493 494 # Namedtuple of lists, tuples, namedtuples, dict, frozendict, defaultdict 495 res = sess.run( 496 DEFGHI( 497 d=[a, b, c], 498 e=(a, b, c), 499 f=ABC(a=a.name, b=b, c=c), 500 g=dict(test_dct), 501 h=frozendict(test_dct), 502 i=defaultdict(str, test_dct))) 503 self.assertIsInstance(res, DEFGHI) 504 self.assertIsInstance(res.d, list) 505 self.assertEqual(3, len(res.d)) 506 self.assertEqual(a_val, res.d[0]) 507 self.assertEqual(b_val, res.d[1]) 508 self.assertEqual(c_val, res.d[2]) 509 self.assertIsInstance(res.e, tuple) 510 self.assertEqual(3, len(res.e)) 511 self.assertEqual(a_val, res.e[0]) 512 self.assertEqual(b_val, res.e[1]) 513 self.assertEqual(c_val, res.e[2]) 514 self.assertIsInstance(res.f, ABC) 515 self.assertEqual(a_val, res.f.a) 516 self.assertEqual(b_val, res.f.b) 517 self.assertEqual(c_val, res.f.c) 518 self.assertIsInstance(res.g, dict) 519 self.assertEqual(3, len(res.g)) 520 self.assertEqual(a_val, res.g['a']) 521 self.assertEqual(b_val, res.g['b']) 522 self.assertEqual(c_val, res.g['c']) 523 self.assertIsInstance(res.h, frozendict) 524 self.assertEqual(3, len(res.h)) 525 self.assertEqual(a_val, res.h['a']) 526 self.assertEqual(b_val, res.h['b']) 527 self.assertEqual(c_val, res.h['c']) 528 self.assertIsInstance(res.i, defaultdict) 529 self.assertEqual(3, len(res.i)) 530 self.assertEqual(a_val, res.i['a']) 531 self.assertEqual(b_val, res.i['b']) 532 self.assertEqual(c_val, res.i['c']) 533 self.assertEqual(res.i.default_factory, str) 534 # Dict of lists, tuples, namedtuples, dict, frozendict, defaultdict 535 res = sess.run({ 536 'd': [a, b, c], 537 'e': (a, b, c), 538 'f': ABC(a=a, b=b, c=c), 539 'g': dict(test_dct), 540 'h': frozendict(test_dct), 541 'i': defaultdict(str, test_dct), 542 }) 543 self.assertIsInstance(res, dict) 544 self.assertEqual(6, len(res)) 545 self.assertIsInstance(res['d'], list) 546 self.assertEqual(3, len(res['d'])) 547 self.assertEqual(a_val, res['d'][0]) 548 self.assertEqual(b_val, res['d'][1]) 549 self.assertEqual(c_val, res['d'][2]) 550 self.assertIsInstance(res['e'], tuple) 551 self.assertEqual(3, len(res['e'])) 552 self.assertEqual(a_val, res['e'][0]) 553 self.assertEqual(b_val, res['e'][1]) 554 self.assertEqual(c_val, res['e'][2]) 555 self.assertIsInstance(res['f'], ABC) 556 self.assertEqual(a_val, res['f'].a) 557 self.assertEqual(b_val, res['f'].b) 558 self.assertEqual(c_val, res['f'].c) 559 for expected_type, r_key in zip(test_dct_types, ('g', 'h', 'i')): 560 r = res[r_key] 561 self.assertIsInstance(r, expected_type) 562 self.assertEqual(3, len(r)) 563 self.assertEqual(a_val, r['a']) 564 self.assertEqual(b_val, r['b']) 565 self.assertEqual(c_val, r['c']) 566 self.assertEqual(res['i'].default_factory, str) 567 568 def testFetchTensorObject(self): 569 with session.Session() as s: 570 a = constant_op.constant(1.0, shape=[1, 2]) 571 b = constant_op.constant(2.0, shape=[2, 3]) 572 c = math_ops.matmul(a, b) 573 results_with_list = s.run([c]) 574 self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_list[0]) 575 results_with_single = s.run(c) 576 self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_single) 577 results_with_get = c.eval() 578 self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_get) 579 a_val, b_val = s.run([a, b]) # Test multiple fetches. 580 self.assertAllEqual([[1.0, 1.0]], a_val) 581 self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], b_val) 582 results_with_dict = s.run({'a': [a], 'b': b, 'z': [a, b]}) 583 self.assertAllEqual([[1.0, 1.0]], results_with_dict['a'][0]) 584 self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], 585 results_with_dict['b']) 586 self.assertAllEqual(results_with_dict['a'][0], results_with_dict['z'][0]) 587 self.assertAllEqual(results_with_dict['b'], results_with_dict['z'][1]) 588 589 # Test nested structures 590 results_with_nested_list = s.run([[[a, b], b], a, [a, b]]) 591 self.assertAllEqual([[1.0, 1.0]], results_with_nested_list[0][0][0]) 592 self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], 593 results_with_nested_list[0][0][1]) 594 self.assertAllEqual(results_with_nested_list[0][0][0], 595 results_with_nested_list[1]) 596 self.assertAllEqual(results_with_nested_list[1], 597 results_with_nested_list[2][0]) 598 self.assertAllEqual(results_with_nested_list[0][0][1], 599 results_with_nested_list[0][1]) 600 self.assertAllEqual(results_with_nested_list[0][1], 601 results_with_nested_list[2][1]) 602 603 def testFetchScalar(self): 604 with session.Session() as s: 605 for scalar in np.int32, np.int64, np.float16, np.float32, np.float64: 606 x = scalar(7) 607 y = scalar(8) 608 tf_x = constant_op.constant(x, shape=[]) 609 tf_y = constant_op.constant(y) 610 tf_xy = math_ops.add(tf_x, tf_y) 611 # Single fetch 612 xy = s.run(tf_xy) 613 self.assertEqual(scalar, type(xy)) 614 self.assertEqual(x + y, xy) 615 # List fetch 616 xy, = s.run([tf_xy]) 617 self.assertEqual(scalar, type(xy)) 618 self.assertEqual(x + y, xy) 619 # Dict fetch 620 xy = s.run({'xy': tf_xy})['xy'] 621 self.assertEqual(scalar, type(xy)) 622 self.assertEqual(x + y, xy) 623 # Nested list fetch 624 xy = s.run([[[tf_xy]], tf_xy, [tf_xy]]) 625 self.assertAllEqual(xy, [[[x + y]], x + y, [x + y]]) 626 self.assertEqual(scalar, type(xy[0][0][0])) 627 self.assertEqual(scalar, type(xy[1])) 628 self.assertEqual(scalar, type(xy[2][0])) 629 630 def testFetchOperationObject(self): 631 with session.Session() as s: 632 a = constant_op.constant(1.0, shape=[1, 2]) 633 v = variables.Variable(a, name='testFetchOperationObject_v') 634 s.run(v.initializer) 635 v_val = s.run(v) 636 self.assertAllEqual([[1.0, 1.0]], v_val) 637 638 def testFetchSparseTensor(self): 639 with session.Session() as s: 640 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 641 values = np.array([1.0, 2.0]).astype(np.float32) 642 shape = np.array([7, 9, 2]).astype(np.int64) 643 sp = sparse_tensor.SparseTensor( 644 constant_op.constant(indices), constant_op.constant(values), 645 constant_op.constant(shape)) 646 # Single fetch, use as tuple 647 sp_out = s.run(sp) 648 indices_out, values_out, shape_out = sp_out 649 self.assertAllEqual(indices_out, indices) 650 self.assertAllEqual(values_out, values) 651 self.assertAllEqual(shape_out, shape) 652 # Single fetch, use as SparseTensorValue 653 sp_out = s.run(sp) 654 self.assertAllEqual(sp_out.indices, indices) 655 self.assertAllEqual(sp_out.values, values) 656 self.assertAllEqual(sp_out.dense_shape, shape) 657 # Tuple fetch, use as tuple 658 indices_out, values_out, shape_out = s.run(sp) 659 self.assertAllEqual(indices_out, indices) 660 self.assertAllEqual(values_out, values) 661 self.assertAllEqual(shape_out, shape) 662 # List fetch, use as tuple 663 (indices_out, values_out, shape_out), = s.run([sp]) 664 self.assertAllEqual(indices_out, indices) 665 self.assertAllEqual(values_out, values) 666 self.assertAllEqual(shape_out, shape) 667 # List fetch, use as SparseTensorValue 668 sp_out, = s.run([sp]) 669 self.assertAllEqual(sp_out.indices, indices) 670 self.assertAllEqual(sp_out.values, values) 671 self.assertAllEqual(sp_out.dense_shape, shape) 672 # Dict fetch (single value), use as tuple 673 indices_out, values_out, shape_out = s.run({'sp': sp})['sp'] 674 self.assertAllEqual(indices_out, indices) 675 self.assertAllEqual(values_out, values) 676 self.assertAllEqual(shape_out, shape) 677 # Dict fetch (list value), use as tuple 678 (indices_out, values_out, shape_out), = s.run({'sp': [sp]})['sp'] 679 self.assertAllEqual(indices_out, indices) 680 self.assertAllEqual(values_out, values) 681 self.assertAllEqual(shape_out, shape) 682 # Dict fetch, use as SparseTensorValue 683 sp_out = s.run({'sp': sp})['sp'] 684 self.assertAllEqual(sp_out.indices, indices) 685 self.assertAllEqual(sp_out.values, values) 686 self.assertAllEqual(sp_out.dense_shape, shape) 687 # Nested list fetch use as tuple 688 sp_out = s.run([[[sp]], sp]) 689 indices_out, values_out, shape_out = sp_out[0][0][0] 690 self.assertAllEqual(indices_out, indices) 691 self.assertAllEqual(values_out, values) 692 self.assertAllEqual(shape_out, shape) 693 indices_out, values_out, shape_out = sp_out[1] 694 self.assertAllEqual(indices_out, indices) 695 self.assertAllEqual(values_out, values) 696 self.assertAllEqual(shape_out, shape) 697 # Nested list fetch, use as SparseTensorValue 698 sp_out = s.run([[[sp]], sp]) 699 self.assertAllEqual(sp_out[0][0][0].indices, indices) 700 self.assertAllEqual(sp_out[0][0][0].values, values) 701 self.assertAllEqual(sp_out[0][0][0].dense_shape, shape) 702 self.assertAllEqual(sp_out[1].indices, indices) 703 self.assertAllEqual(sp_out[1].values, values) 704 self.assertAllEqual(sp_out[1].dense_shape, shape) 705 706 def testFeedSparseTensor(self): 707 with session.Session() as s: 708 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 709 values = np.array([1.0, 2.0]).astype(np.float32) 710 shape = np.array([7, 9, 2]).astype(np.int64) 711 sp = sparse_tensor.SparseTensor( 712 array_ops.placeholder(dtype=np.int64, shape=(2, 3)), 713 array_ops.placeholder(dtype=np.float32, shape=(2,)), 714 array_ops.placeholder(dtype=np.int64, shape=(3,)), 715 ) 716 sp_indices = array_ops.identity(sp.indices) 717 sp_values = array_ops.identity(sp.values) 718 sp_shape = array_ops.identity(sp.dense_shape) 719 sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape) 720 # Feed with tuple 721 indices_out, values_out, shape_out = s.run( 722 [sp_indices, sp_values, sp_shape], { 723 sp: (indices, values, shape) 724 }) 725 self.assertAllEqual(indices_out, indices) 726 self.assertAllEqual(values_out, values) 727 self.assertAllEqual(shape_out, shape) 728 # Feed with tuple, fetch sp directly 729 sp_out = s.run(sp, {sp: (indices, values, shape)}) 730 self.assertAllEqual(sp_out.indices, indices) 731 self.assertAllEqual(sp_out.values, values) 732 self.assertAllEqual(sp_out.dense_shape, shape) 733 # Feed with SparseTensorValue 734 indices_out, values_out, shape_out = s.run( 735 [sp_indices, sp_values, sp_shape], { 736 sp: sparse_tensor.SparseTensorValue(indices, values, shape) 737 }) 738 self.assertAllEqual(indices_out, indices) 739 self.assertAllEqual(values_out, values) 740 self.assertAllEqual(shape_out, shape) 741 # Feed with SparseTensorValue, fetch SparseTensorValue 742 sp2_out = s.run(sp2, { 743 sp: sparse_tensor.SparseTensorValue(indices, values, shape) 744 }) 745 self.assertAllEqual(sp2_out.indices, indices) 746 self.assertAllEqual(sp2_out.values, values) 747 self.assertAllEqual(sp2_out.dense_shape, shape) 748 # Feed SparseTensorValue and fetch sp directly. 749 sp_out = s.run(sp, { 750 sp: sparse_tensor.SparseTensorValue(indices, values, shape) 751 }) 752 self.assertAllEqual(sp_out.indices, indices) 753 self.assertAllEqual(sp_out.values, values) 754 self.assertAllEqual(sp_out.dense_shape, shape) 755 756 def testFeedSparsePlaceholder(self): 757 with session.Session() as s: 758 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 759 values = np.array([1.0, 2.0]).astype(np.float32) 760 shape = np.array([7, 9, 2]).astype(np.int64) 761 sp = array_ops.sparse_placeholder(dtype=np.float32, name='placeholder1') 762 sp_indices = array_ops.identity(sp.indices) 763 sp_values = array_ops.identity(sp.values) 764 sp_shape = array_ops.identity(sp.dense_shape) 765 sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape) 766 # Feed with tuple 767 indices_out, values_out, shape_out = s.run( 768 [sp_indices, sp_values, sp_shape], { 769 sp: (indices, values, shape) 770 }) 771 self.assertAllEqual(indices_out, indices) 772 self.assertAllEqual(values_out, values) 773 self.assertAllEqual(shape_out, shape) 774 # Feed with SparseTensorValue 775 indices_out, values_out, shape_out = s.run( 776 [sp_indices, sp_values, sp_shape], { 777 sp: sparse_tensor.SparseTensorValue(indices, values, shape) 778 }) 779 self.assertAllEqual(indices_out, indices) 780 self.assertAllEqual(values_out, values) 781 self.assertAllEqual(shape_out, shape) 782 # Feed with SparseTensorValue, fetch SparseTensorValue 783 sp2_out = s.run(sp2, { 784 sp: sparse_tensor.SparseTensorValue(indices, values, shape) 785 }) 786 self.assertAllEqual(sp2_out.indices, indices) 787 self.assertAllEqual(sp2_out.values, values) 788 self.assertAllEqual(sp2_out.dense_shape, shape) 789 790 def testFeedSparsePlaceholderPartialShape(self): 791 with session.Session() as s: 792 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 793 values = np.array([1.0, 2.0]).astype(np.float32) 794 shape = np.array([7, 9, 2]).astype(np.int64) 795 sp = array_ops.sparse_placeholder( 796 shape=[None, 9, 2], dtype=np.float32, name='placeholder1') 797 sp_indices = array_ops.identity(sp.indices) 798 sp_values = array_ops.identity(sp.values) 799 sp_shape = array_ops.identity(sp.dense_shape) 800 sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape) 801 # Feed with tuple 802 indices_out, values_out, shape_out = s.run( 803 [sp_indices, sp_values, sp_shape], { 804 sp: (indices, values, shape) 805 }) 806 self.assertAllEqual(indices_out, indices) 807 self.assertAllEqual(values_out, values) 808 self.assertAllEqual(shape_out, shape) 809 # Feed with SparseTensorValue 810 indices_out, values_out, shape_out = s.run( 811 [sp_indices, sp_values, sp_shape], { 812 sp: sparse_tensor.SparseTensorValue(indices, values, shape) 813 }) 814 self.assertAllEqual(indices_out, indices) 815 self.assertAllEqual(values_out, values) 816 self.assertAllEqual(shape_out, shape) 817 # Feed with SparseTensorValue, fetch SparseTensorValue 818 sp2_out = s.run(sp2, { 819 sp: sparse_tensor.SparseTensorValue(indices, values, shape) 820 }) 821 self.assertAllEqual(sp2_out.indices, indices) 822 self.assertAllEqual(sp2_out.values, values) 823 self.assertAllEqual(sp2_out.dense_shape, shape) 824 825 def testFeedSparsePlaceholderConstantShape(self): 826 with session.Session() as s: 827 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 828 values = np.array([1.0, 2.0]).astype(np.float32) 829 shape = np.array([7, 9, 2]).astype(np.int64) 830 sp = array_ops.sparse_placeholder( 831 dtype=np.float32, shape=shape, name='placeholder1') 832 self.assertAllEqual(sp.dense_shape.eval(session=s), shape) 833 self.assertAllEqual(tensor_util.constant_value(sp.shape), shape) 834 sp_indices = array_ops.identity(sp.indices) 835 sp_values = array_ops.identity(sp.values) 836 sp_shape = array_ops.identity(sp.dense_shape) 837 # Feed with tuple 838 indices_out, values_out, shape_out = s.run( 839 [sp_indices, sp_values, sp_shape], { 840 sp: (indices, values) 841 }) 842 self.assertAllEqual(indices_out, indices) 843 self.assertAllEqual(values_out, values) 844 self.assertAllEqual(shape_out, shape) 845 846 def testFetchIndexedSlices(self): 847 with session.Session() as s: 848 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 849 values = np.array([1.0, 2.0]).astype(np.float32) 850 dense_shape = np.array([7, 9, 2]).astype(np.int64) 851 ind = indexed_slices.IndexedSlices( 852 constant_op.constant(values), constant_op.constant(indices), 853 constant_op.constant(dense_shape)) 854 # Single fetch, use as tuple 855 ind_out = s.run(ind) 856 values_out, indices_out, dense_shape_out = ind_out 857 self.assertAllEqual(values_out, values) 858 self.assertAllEqual(indices_out, indices) 859 self.assertAllEqual(dense_shape_out, dense_shape) 860 # Single fetch, use as IndexedSlicesValue 861 ind_out = s.run(ind) 862 self.assertAllEqual(ind_out.values, values) 863 self.assertAllEqual(ind_out.indices, indices) 864 self.assertAllEqual(ind_out.dense_shape, dense_shape) 865 # Tuple fetch, use as tuple 866 values_out, indices_out, dense_shape_out = s.run(ind) 867 self.assertAllEqual(values_out, values) 868 self.assertAllEqual(indices_out, indices) 869 self.assertAllEqual(dense_shape_out, dense_shape) 870 # List fetch, use as tuple 871 (values_out, indices_out, dense_shape_out), = s.run([ind]) 872 self.assertAllEqual(values_out, values) 873 self.assertAllEqual(indices_out, indices) 874 self.assertAllEqual(dense_shape_out, dense_shape) 875 # List fetch, use as IndexedSlicesValue 876 ind_out, = s.run([ind]) 877 self.assertAllEqual(ind_out.values, values) 878 self.assertAllEqual(ind_out.indices, indices) 879 self.assertAllEqual(ind_out.dense_shape, dense_shape) 880 881 def testFeedIndexedSlices(self): 882 with session.Session() as s: 883 values = np.array([1.0, 2.0]).astype(np.float32) 884 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 885 dense_shape = np.array([7, 9, 2]).astype(np.int64) 886 ind = indexed_slices.IndexedSlices( 887 array_ops.placeholder(dtype=np.float32, shape=(2,)), 888 array_ops.placeholder(dtype=np.int64, shape=(2, 3)), 889 array_ops.placeholder(dtype=np.int64, shape=(3,)), 890 ) 891 ind_values = array_ops.identity(ind.values) 892 ind_indices = array_ops.identity(ind.indices) 893 ind_dense_shape = array_ops.identity(ind.dense_shape) 894 ind2 = indexed_slices.IndexedSlices(ind_values, ind_indices, 895 ind_dense_shape) 896 # Feed with tuple 897 values_out, indices_out, dense_shape_out = s.run( 898 [ind_values, ind_indices, ind_dense_shape], { 899 ind: (values, indices, dense_shape) 900 }) 901 self.assertAllEqual(values_out, values) 902 self.assertAllEqual(indices_out, indices) 903 self.assertAllEqual(dense_shape_out, dense_shape) 904 # Feed with IndexedSlicesValue 905 values_out, indices_out, dense_shape_out = s.run([ 906 ind_values, ind_indices, ind_dense_shape 907 ], {ind: indexed_slices.IndexedSlicesValue(values, indices, dense_shape)}) 908 self.assertAllEqual(values_out, values) 909 self.assertAllEqual(indices_out, indices) 910 self.assertAllEqual(dense_shape_out, dense_shape) 911 # Feed with IndexedSlicesValue, fetch IndexedSlicesValue 912 ind2_out = s.run(ind2, { 913 ind: indexed_slices.IndexedSlicesValue(values, indices, dense_shape) 914 }) 915 self.assertAllEqual(ind2_out.values, values) 916 self.assertAllEqual(ind2_out.indices, indices) 917 self.assertAllEqual(ind2_out.dense_shape, dense_shape) 918 919 def testFetchIndexedSlicesWithoutDenseShape(self): 920 with session.Session() as s: 921 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 922 values = np.array([1.0, 2.0]).astype(np.float32) 923 dense_shape = None 924 ind = indexed_slices.IndexedSlices( 925 constant_op.constant(values), constant_op.constant(indices), None) 926 # Single fetch, use as tuple 927 ind_out = s.run(ind) 928 values_out, indices_out, dense_shape_out = ind_out 929 self.assertAllEqual(values_out, values) 930 self.assertAllEqual(indices_out, indices) 931 self.assertAllEqual(dense_shape_out, dense_shape) 932 # Single fetch, use as IndexedSlicesValue 933 ind_out = s.run(ind) 934 self.assertAllEqual(ind_out.values, values) 935 self.assertAllEqual(ind_out.indices, indices) 936 self.assertAllEqual(ind_out.dense_shape, dense_shape) 937 # Tuple fetch, use as tuple 938 values_out, indices_out, dense_shape_out = s.run(ind) 939 self.assertAllEqual(values_out, values) 940 self.assertAllEqual(indices_out, indices) 941 self.assertAllEqual(dense_shape_out, dense_shape) 942 # List fetch, use as tuple 943 (values_out, indices_out, dense_shape_out), = s.run([ind]) 944 self.assertAllEqual(values_out, values) 945 self.assertAllEqual(indices_out, indices) 946 self.assertAllEqual(dense_shape_out, dense_shape) 947 # List fetch, use as IndexedSlicesValue 948 ind_out, = s.run([ind]) 949 self.assertAllEqual(ind_out.values, values) 950 self.assertAllEqual(ind_out.indices, indices) 951 self.assertAllEqual(ind_out.dense_shape, dense_shape) 952 953 def testFeedIndexedSlicesWithoutDenseShape(self): 954 with session.Session() as s: 955 values = np.array([1.0, 2.0]).astype(np.float32) 956 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 957 dense_shape = None 958 ind = indexed_slices.IndexedSlices( 959 array_ops.placeholder(dtype=np.float32, shape=(2,)), 960 array_ops.placeholder(dtype=np.int64, shape=(2, 3)), None) 961 ind_values = array_ops.identity(ind.values) 962 ind_indices = array_ops.identity(ind.indices) 963 ind2 = indexed_slices.IndexedSlices(ind_values, ind_indices) 964 # Feed with tuple 965 values_out, indices_out = s.run([ind_values, ind_indices], { 966 ind: (values, indices) 967 }) 968 self.assertAllEqual(values_out, values) 969 self.assertAllEqual(indices_out, indices) 970 # Feed with IndexedSlicesValue 971 values_out, indices_out = s.run([ind_values, ind_indices], { 972 ind: indexed_slices.IndexedSlicesValue(values, indices, dense_shape) 973 }) 974 self.assertAllEqual(values_out, values) 975 self.assertAllEqual(indices_out, indices) 976 # Feed with IndexedSlicesValue, fetch IndexedSlicesValue 977 ind2_out = s.run(ind2, { 978 ind: indexed_slices.IndexedSlicesValue(values, indices, dense_shape) 979 }) 980 self.assertAllEqual(ind2_out.values, values) 981 self.assertAllEqual(ind2_out.indices, indices) 982 self.assertAllEqual(ind2_out.dense_shape, dense_shape) 983 984 def testExtendWithStatelessOperations(self): 985 with session.Session() as s: 986 a = constant_op.constant(1.0, shape=[1, 2]) 987 b = constant_op.constant(2.0, shape=[2, 3]) 988 c = math_ops.matmul(a, b) 989 c_val = s.run(c) 990 self.assertAllEqual([[4.0, 4.0, 4.0]], c_val) 991 d = constant_op.constant([1.0, 2.0, 3.0], shape=[3, 1]) 992 e = math_ops.matmul(c, d) 993 # Extend will happen here. 994 e_val = s.run(e) 995 self.assertAllEqual([[24.0]], e_val) 996 997 def testExtendWithStatefulOperations(self): 998 with session.Session() as s: 999 a = constant_op.constant(1.0, shape=[1, 2]) 1000 b = constant_op.constant(2.0, shape=[2, 3]) 1001 c = math_ops.matmul(a, b) 1002 v = variables.Variable(c, name='testExtendWithStatefulOperations_v') 1003 v.initializer.run() 1004 v_val = v.eval() 1005 self.assertAllEqual([[4.0, 4.0, 4.0]], v_val) 1006 d = constant_op.constant(3.0, shape=[2, 3]) 1007 e = math_ops.matmul(a, d) 1008 assign_e_to_v = state_ops.assign(v, e) 1009 # Extend will happen here. 1010 e_val = e.eval() 1011 self.assertAllEqual([[6.0, 6.0, 6.0]], e_val) 1012 v_val = v.eval() 1013 self.assertAllEqual([[4.0, 4.0, 4.0]], v_val) 1014 s.run(assign_e_to_v) 1015 v_val = v.eval() 1016 self.assertAllEqual([[6.0, 6.0, 6.0]], v_val) 1017 1018 def testExtendWithGroupBy(self): 1019 with session.Session() as s: 1020 a = constant_op.constant(1.0, shape=[1, 2]) 1021 p = variables.Variable(a, name='testExtendWithGroupBy_p') 1022 a_val = a.eval() # Force an Extend after this op. 1023 self.assertAllEqual([[1.0, 1.0]], a_val) 1024 1025 b = constant_op.constant(2.0, shape=[1, 2]) 1026 q = variables.Variable(b, name='testExtendWithGroupBy_q') 1027 # Extend will happen here. 1028 init = control_flow_ops.group(p.initializer, q.initializer) 1029 s.run(init) 1030 p_val, q_val = s.run([p, q]) 1031 1032 self.assertAllEqual([[1.0, 1.0]], p_val) 1033 self.assertAllEqual([[2.0, 2.0]], q_val) 1034 1035 def testTensorGetMethod(self): 1036 with session.Session(): 1037 a = constant_op.constant(1.0, shape=[1, 2]) 1038 b = constant_op.constant(2.0, shape=[2, 3]) 1039 c = math_ops.matmul(a, b) 1040 1041 c_val = c.eval() 1042 self.assertAllEqual([[4.0, 4.0, 4.0]], c_val) 1043 1044 fed_c_val = c.eval(feed_dict={a.name: [[4.0, 4.0]]}) 1045 self.assertAllEqual([[16.0, 16.0, 16.0]], fed_c_val) 1046 1047 @test_util.run_v1_only('b/120545219') 1048 def testOperationRunMethod(self): 1049 with session.Session(): 1050 a = constant_op.constant(1.0, shape=[1, 2]) 1051 b = constant_op.constant(2.0, shape=[1, 2], name='b') 1052 v = variables.VariableV1(a, a.dtype) 1053 assign_a_to_v = state_ops.assign(v, a) 1054 1055 assign_a_to_v.eval() 1056 1057 v_val = v.eval() 1058 self.assertAllEqual([[1.0, 1.0]], v_val) 1059 1060 assign_b_to_v = state_ops.assign(v, b) 1061 1062 assign_b_to_v.eval() 1063 v_val = v.eval() 1064 self.assertAllEqual([[2.0, 2.0]], v_val) 1065 1066 assign_b_to_v.eval(feed_dict={'b:0': [[3.0, 3.0]]}) 1067 v_val = v.eval() 1068 self.assertAllEqual([[3.0, 3.0]], v_val) 1069 1070 def testDefaultGraph(self): 1071 with session.Session() as s: 1072 self.assertEqual(ops.get_default_graph(), s.graph) 1073 a = constant_op.constant(1.0, shape=[1, 2]) 1074 b = constant_op.constant(2.0, shape=[2, 3]) 1075 self.assertEqual(ops.get_default_graph(), a.graph) 1076 self.assertEqual(ops.get_default_graph(), b.graph) 1077 c = math_ops.matmul(a, b) 1078 v = variables.Variable(c, name='testDefaultGraph_v') 1079 v.initializer.run() 1080 v_val = v.eval() 1081 self.assertAllEqual([[4.0, 4.0, 4.0]], v_val) 1082 d = constant_op.constant(3.0, shape=[2, 3]) 1083 e = math_ops.matmul(a, d) 1084 assign_e_to_v = state_ops.assign(v, e) 1085 e_val = e.eval() 1086 self.assertAllEqual([[6.0, 6.0, 6.0]], e_val) 1087 v_val = v.eval() 1088 self.assertAllEqual([[4.0, 4.0, 4.0]], v_val) 1089 s.run(assign_e_to_v) 1090 v_val = v.eval() 1091 self.assertAllEqual([[6.0, 6.0, 6.0]], v_val) 1092 self.assertEqual(ops.get_default_graph(), s.graph) 1093 1094 def _testDefaultGraphInThread(self, constructed_event, continue_event, i): 1095 with session.Session() as s: 1096 self.assertEqual(ops.get_default_graph(), s.graph) 1097 a = constant_op.constant(1.0, shape=[1, 2]) 1098 b = constant_op.constant(2.0, shape=[2, 3]) 1099 c = math_ops.matmul(a, b) 1100 v = variables.Variable(c, name='var_%d' % i) 1101 1102 # Block here until all threads have constructed their graph. 1103 constructed_event.set() 1104 continue_event.wait() 1105 1106 assign_c_to_v = state_ops.assign(v, c) 1107 v.initializer.run() 1108 assign_c_to_v.eval() 1109 v_val = v.eval() 1110 self.assertAllEqual([[4.0, 4.0, 4.0]], v_val) 1111 d = constant_op.constant(3.0, shape=[2, 3]) 1112 e = math_ops.matmul(a, d) 1113 assign_e_to_v = state_ops.assign(v, e) 1114 e_val = e.eval() 1115 self.assertAllEqual([[6.0, 6.0, 6.0]], e_val) 1116 v_val = v.eval() 1117 self.assertAllEqual([[4.0, 4.0, 4.0]], v_val) 1118 s.run(assign_e_to_v) 1119 v_val = v.eval() 1120 self.assertAllEqual([[6.0, 6.0, 6.0]], v_val) 1121 self.assertEqual(ops.get_default_graph(), s.graph) 1122 1123 def testDefaultGraphWithThreads(self): 1124 # Fork ten threads that use their thread-local default graph. 1125 threads = [] 1126 constructed_events = [threading.Event() for _ in range(10)] 1127 continue_event = threading.Event() 1128 for i, constructed_event in enumerate(constructed_events): 1129 t = self.checkedThread( 1130 target=self._testDefaultGraphInThread, 1131 args=(constructed_event, continue_event, i)) 1132 threads.append(t) 1133 for t in threads: 1134 t.start() 1135 for constructed_event in constructed_events: 1136 constructed_event.wait() 1137 continue_event.set() 1138 for t in threads: 1139 t.join() 1140 1141 def testParallelRun(self): 1142 with session.Session() as sess: 1143 c = constant_op.constant(5.0) 1144 ev = threading.Event() 1145 1146 def run_step(): 1147 ev.wait() 1148 val = c.eval(session=sess) 1149 self.assertEqual(val, 5.0) 1150 1151 threads = [self.checkedThread(target=run_step) for _ in range(100)] 1152 for t in threads: 1153 t.start() 1154 ev.set() 1155 for t in threads: 1156 t.join() 1157 1158 @staticmethod 1159 def _build_graph(): 1160 time.sleep(random.random() * 0.1) 1161 # Do some graph construction. Try to exercise non-trivial paths. 1162 graph = ops.get_default_graph() 1163 gdef = None 1164 for _ in range(10): 1165 x = array_ops.placeholder(dtype=dtypes.float32) 1166 with ops.colocate_with(x): 1167 y = array_ops.placeholder(dtype=dtypes.float32) 1168 with ops.device('/cpu:0'): 1169 z = control_flow_ops.while_loop( 1170 lambda x, y: x < 10, lambda x, y: (x + 1, x * y), [x, y]) 1171 with graph._attr_scope({'_a': attr_value_pb2.AttrValue(b=False)}): 1172 gradients_impl.gradients(z, [x, y]) 1173 if gdef is None: 1174 gdef = graph.as_graph_def() 1175 else: 1176 importer.import_graph_def(gdef, name='import') 1177 1178 @test_util.run_v1_only('b/120545219') 1179 def testParallelRunAndSingleBuild(self): 1180 with session.Session() as sess: 1181 c = constant_op.constant(5.0) 1182 stop = threading.Event() 1183 1184 def run_loop(): 1185 while not stop.is_set(): 1186 time.sleep(random.random() * 0.1) 1187 self.assertEqual(sess.run(c), 5.0) 1188 1189 threads = [self.checkedThread(target=run_loop) for _ in range(10)] 1190 for t in threads: 1191 t.start() 1192 1193 SessionTest._build_graph() 1194 1195 stop.set() 1196 for t in threads: 1197 t.join() 1198 1199 @test_util.run_v1_only('b/120545219') 1200 def testParallelRunAndParallelBuild(self): 1201 with session.Session() as sess: 1202 c = constant_op.constant(5.0) 1203 stop = threading.Event() 1204 1205 def run_loop(): 1206 while not stop.is_set(): 1207 time.sleep(random.random() * 0.1) 1208 self.assertEqual(sess.run(c), 5.0) 1209 1210 run_threads = [self.checkedThread(target=run_loop) for _ in range(10)] 1211 for t in run_threads: 1212 t.start() 1213 1214 build_threads = [self.checkedThread(target=SessionTest._build_graph) 1215 for _ in range(10)] 1216 for t in build_threads: 1217 t.start() 1218 for t in build_threads: 1219 t.join() 1220 1221 # Let the run_threads run until the build threads are finished. 1222 stop.set() 1223 for t in run_threads: 1224 t.join() 1225 1226 def testRunFeedDict(self): 1227 with session.Session() as s: 1228 x = array_ops.zeros([2]) 1229 1230 y = s.run(2 * x, feed_dict={x: np.ones(2).astype(np.float32)}) 1231 self.assertAllEqual(y, 2 * np.ones(2)) 1232 1233 y = s.run(2 * x, feed_dict={x.name: np.ones(2).astype(np.float32)}) 1234 self.assertAllEqual(y, 2 * np.ones(2)) 1235 1236 y = s.run(2 * x, feed_dict={x: [1, 1]}) 1237 assert (y == 2 * np.ones(2)).all() 1238 1239 # Test nested tuple keys 1240 z = (((array_ops.zeros([2]),),), array_ops.zeros([2]), 1241 (array_ops.zeros([2]),)) 1242 result = [z[0][0][0] * 2, z[1] * 2, z[2][0] * 2] 1243 values = (((np.array([1, 1]),),), np.array([2, 2]), (np.array([3, 3]),)) 1244 result_value = s.run(result, feed_dict={z: values}) 1245 self.assertAllEqual(result_value[0], 2 * np.ones(2)) 1246 self.assertAllEqual(result_value[1], 2 * np.array([2, 2])) 1247 self.assertAllEqual(result_value[2], 2 * np.array([3, 3])) 1248 1249 def testGraphDef(self): 1250 with session.Session() as sess: 1251 self.assertProtoEquals('versions { producer: %d min_consumer: %d }' % 1252 (versions.GRAPH_DEF_VERSION, 1253 versions.GRAPH_DEF_VERSION_MIN_CONSUMER), 1254 sess.graph_def) 1255 c = constant_op.constant(5.0, name='c') 1256 self.assertEqual(len(sess.graph_def.node), 1) 1257 d = constant_op.constant(6.0, name='d') 1258 self.assertEqual(len(sess.graph_def.node), 2) 1259 self.assertAllEqual(c, 5.0) 1260 self.assertAllEqual(d, 6.0) 1261 e = constant_op.constant(7.0, name='e') 1262 self.assertEqual(len(sess.graph_def.node), 3) 1263 self.assertAllEqual(e, 7.0) 1264 1265 def testUseAfterClose(self): 1266 with session.Session() as sess: 1267 c = constant_op.constant(5.0) 1268 self.assertAllEqual(sess.run(c), 5.0) 1269 with self.assertRaisesWithPredicateMatch( 1270 RuntimeError, lambda e: 'Attempted to use a closed Session.' in str(e)): 1271 sess.run(c) 1272 1273 def testUseAfterCloseConcurrent(self): 1274 with session.Session() as sess: 1275 c = constant_op.constant(5.0) 1276 self.assertAllEqual(sess.run(c), 5.0) 1277 1278 def update_thread(): 1279 with self.assertRaisesWithPredicateMatch( 1280 RuntimeError, 1281 lambda e: 'Attempted to use a closed Session.' in str(e)): 1282 while True: 1283 sess.run(c) 1284 1285 t = threading.Thread(target=update_thread) 1286 t.start() 1287 time.sleep(0.1) 1288 sess.close() 1289 t.join() 1290 1291 def testUseEmptyGraph(self): 1292 with session.Session() as sess: 1293 with self.assertRaisesRegex(RuntimeError, 'The Session graph is empty.'): 1294 sess.run([]) 1295 with self.assertRaisesRegex(RuntimeError, 'The Session graph is empty.'): 1296 sess.run(()) 1297 with self.assertRaisesRegex(RuntimeError, 'The Session graph is empty.'): 1298 sess.run({}) 1299 1300 @test_util.run_v1_only('b/120545219') 1301 def testNotEntered(self): 1302 # pylint: disable=protected-access 1303 self.assertIsNone(ops._default_session_stack.get_default()) 1304 # pylint: enable=protected-access 1305 with ops.device('/cpu:0'): 1306 sess = session.Session() 1307 c_1 = constant_op.constant(5.0) 1308 with sess.graph.as_default(): 1309 c_2 = constant_op.constant(5.0) 1310 self.assertEqual(c_1.graph, c_2.graph) 1311 self.assertEqual(sess.run(c_2), 5.0) 1312 with self.assertRaisesWithPredicateMatch( 1313 ValueError, lambda e: 'No default session is registered.' in str(e)): 1314 c_2.eval() 1315 1316 @test_util.run_v1_only('b/120545219') 1317 def testInteractive(self): 1318 with ops.device('/cpu:0'): 1319 sess = session.InteractiveSession() 1320 a = constant_op.constant(1.0, shape=[1, 2]) 1321 b = constant_op.constant(2.0, shape=[2, 3]) 1322 c = math_ops.matmul(a, b) 1323 self.assertAllEqual([[4.0, 4.0, 4.0]], c) 1324 d = constant_op.constant([1.0, 2.0, 3.0], shape=[3, 1]) 1325 e = math_ops.matmul(c, d) 1326 self.assertAllEqual([[24.0]], e) 1327 sess.close() 1328 1329 @test_util.run_v1_only('b/120545219') 1330 def testMultipleInteractiveSessionsWarning(self): 1331 # Reinitialize the global state to ensure that the expected warnings will 1332 # be emitted. 1333 session.InteractiveSession._active_session_count = 0 # pylint: disable=protected-access 1334 1335 sess = session.InteractiveSession() 1336 sess.run(constant_op.constant(4.0)) # Run so that the session is "opened". 1337 sess.close() 1338 # Opening and closing interactive sessions serially should not warn. 1339 with warnings.catch_warnings(record=True) as w: 1340 sess = session.InteractiveSession() 1341 sess.close() 1342 self.assertEqual(0, len(w)) 1343 1344 with warnings.catch_warnings(record=True) as w: 1345 sess = session.InteractiveSession() 1346 self.assertEqual(0, len(w)) 1347 with warnings.catch_warnings(record=True) as w: 1348 sess2 = session.InteractiveSession() 1349 self.assertEqual(1, len(w)) 1350 self.assertIn('An interactive session is already active. This can cause ' 1351 'out-of-memory errors in some cases. You must explicitly ' 1352 'call `InteractiveSession.close()` to release resources ' 1353 'held by the other session(s).', str(w[0].message)) 1354 sess2.close() 1355 sess.close() 1356 1357 @test_util.run_v1_only('b/120545219') 1358 def testInteractivePlacePrunedGraph(self): 1359 sess = session.InteractiveSession() 1360 1361 # Build a graph that has a bad op in it (no kernel). 1362 # 1363 # This test currently does not link in any GPU kernels, 1364 # which is why placing this is invalid. If at some point 1365 # GPU kernels are added to this test, some other different 1366 # op / device combo should be chosen. 1367 with ops.device('/device:GPU:0'): 1368 a = constant_op.constant(1.0, shape=[1, 2]) 1369 1370 b = constant_op.constant(1.0, shape=[1, 2]) 1371 1372 # Only run the valid op, this should work. 1373 b.eval() 1374 1375 with self.assertRaises(errors.InvalidArgumentError): 1376 a.eval() 1377 sess.close() 1378 1379 @test_util.run_v1_only('b/120545219') 1380 def testDefaultSessionPlacePrunedGraph(self): 1381 sess = session.Session() 1382 1383 # Build a graph that has a bad op in it (no kernel). 1384 # 1385 # This test currently does not link in any GPU kernels, 1386 # which is why placing this is invalid. If at some point 1387 # GPU kernels are added to this test, some other different 1388 # op / device combo should be chosen. 1389 with ops.device('/device:GPU:0'): 1390 _ = constant_op.constant(1.0, shape=[1, 2]) 1391 1392 b = constant_op.constant(1.0, shape=[1, 2]) 1393 1394 with self.assertRaises(errors.InvalidArgumentError): 1395 # Even though we don't run the bad op, we place the entire 1396 # graph, which should fail with a non-interactive session. 1397 sess.run(b) 1398 1399 sess.close() 1400 1401 def testSharedGraph(self): 1402 with ops.Graph().as_default() as g, ops.device('/cpu:0'): 1403 a = constant_op.constant(1.0, shape=[1, 2]) 1404 b = constant_op.constant(2.0, shape=[2, 3]) 1405 c = math_ops.matmul(a, b) 1406 1407 with session.Session(graph=g) as sess1: 1408 with session.Session(graph=g) as sess2: 1409 self.assertAllEqual(sess1.run(c), sess2.run(c)) 1410 1411 def testDuplicatedInputs(self): 1412 with session.Session() as sess: 1413 a = constant_op.constant(1.0, shape=[1, 2]) 1414 b = constant_op.constant(2.0, shape=[1, 3]) 1415 a_val, b_val, a2_val = sess.run([a, b, a]) 1416 self.assertAllEqual(a_val, [[1.0, 1.0]]) 1417 self.assertAllEqual(b_val, [[2.0, 2.0, 2.0]]) 1418 self.assertAllEqual(a2_val, [[1.0, 1.0]]) 1419 1420 def testFeedAndFetch(self): 1421 with session.Session() as sess: 1422 for dtype in [ 1423 dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32, 1424 dtypes.uint8, dtypes.int16, dtypes.int8, dtypes.int64, dtypes.bool, 1425 dtypes.complex64, dtypes.complex128 1426 ]: 1427 for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]: 1428 np_dtype = dtype.as_numpy_dtype 1429 1430 feed_t = array_ops.placeholder(dtype=dtype, shape=shape) 1431 out_t = array_ops.identity(feed_t) 1432 1433 np_array = np.random.randint(-10, 10, shape) 1434 1435 if dtype == dtypes.bool: 1436 np_array = np_array > 0 1437 elif dtype == dtypes.complex64: 1438 np_array = np.sqrt(np_array.astype(np_dtype)) 1439 elif dtype == dtypes.complex64: 1440 np_array = np.sqrt(np_array.astype(np_dtype)) 1441 else: 1442 np_array = np_array.astype(np_dtype) 1443 1444 self.assertAllEqual(np_array, 1445 sess.run(out_t, feed_dict={ 1446 feed_t: np_array 1447 })) 1448 # Check that we can also get the feed back. 1449 self.assertAllEqual(np_array, 1450 sess.run(feed_t, feed_dict={ 1451 feed_t: np_array 1452 })) 1453 # Also check that we can get both back. 1454 out_v, feed_v = sess.run( 1455 [out_t, feed_t], feed_dict={ 1456 feed_t: np_array 1457 }) 1458 self.assertAllEqual(np_array, out_v) 1459 self.assertAllEqual(np_array, feed_v) 1460 1461 feed_fetch_runner = sess.make_callable([out_t, feed_t], [feed_t]) 1462 out_v, feed_v = feed_fetch_runner(np_array) 1463 self.assertAllEqual(np_array, out_v) 1464 self.assertAllEqual(np_array, feed_v) 1465 1466 def testMakeCallableOnTensorWithRunOptions(self): 1467 with session.Session() as sess: 1468 a = constant_op.constant(42.0) 1469 tensor_runner = sess.make_callable(a, accept_options=True) 1470 run_options = config_pb2.RunOptions( 1471 trace_level=config_pb2.RunOptions.FULL_TRACE) 1472 run_metadata = config_pb2.RunMetadata() 1473 self.assertEqual(0, len(run_metadata.step_stats.dev_stats)) 1474 res = tensor_runner(options=run_options, run_metadata=run_metadata) 1475 self.assertEqual(42.0, res) 1476 self.assertGreater(len(run_metadata.step_stats.dev_stats), 0) 1477 1478 def testMakeCallableOnOperationWithRunOptions(self): 1479 with session.Session() as sess: 1480 a = variables.Variable(42.0) 1481 b = state_ops.assign_add(a, 1.0) 1482 sess.run(a.initializer) 1483 tensor_runner = sess.make_callable(b.op, accept_options=True) 1484 run_options = config_pb2.RunOptions( 1485 trace_level=config_pb2.RunOptions.FULL_TRACE) 1486 run_metadata = config_pb2.RunMetadata() 1487 self.assertEqual(0, len(run_metadata.step_stats.dev_stats)) 1488 tensor_runner(options=run_options, run_metadata=run_metadata) 1489 self.assertEqual(43.0, sess.run(a)) 1490 self.assertGreater(len(run_metadata.step_stats.dev_stats), 0) 1491 1492 def testMakeCallableWithFeedListAndRunOptions(self): 1493 with session.Session() as sess: 1494 ph = array_ops.placeholder(dtypes.float32) 1495 a = math_ops.add(ph, 1.0) 1496 tensor_runner = sess.make_callable( 1497 a, feed_list=[ph.name], accept_options=True) 1498 run_options = config_pb2.RunOptions( 1499 trace_level=config_pb2.RunOptions.FULL_TRACE) 1500 run_metadata = config_pb2.RunMetadata() 1501 self.assertEqual(0, len(run_metadata.step_stats.dev_stats)) 1502 self.assertAllClose(42.0, 1503 tensor_runner( 1504 41.0, 1505 options=run_options, 1506 run_metadata=run_metadata)) 1507 self.assertGreater(len(run_metadata.step_stats.dev_stats), 0) 1508 1509 def testOptimizedMakeCallable(self): 1510 with session.Session() as sess: 1511 ph = array_ops.placeholder(dtypes.float32) 1512 a = math_ops.add(ph, 1.0) 1513 callable_opts = config_pb2.CallableOptions() 1514 callable_opts.feed.append(ph.name) 1515 callable_opts.fetch.append(a.name) 1516 for _ in range(3): 1517 callable_fn = sess._make_callable_from_options(callable_opts) 1518 for _ in range(5): 1519 self.assertEqual([2.0], callable_fn(np.array(1.0, dtype=np.float32))) 1520 1521 def testOptimizedMakeCallableWithRunMetadata(self): 1522 with session.Session() as sess: 1523 ph = array_ops.placeholder(dtypes.float32) 1524 a = math_ops.add(ph, 1.0) 1525 callable_opts = config_pb2.CallableOptions() 1526 callable_opts.feed.append(ph.name) 1527 callable_opts.fetch.append(a.name) 1528 callable_opts.run_options.trace_level = config_pb2.RunOptions.FULL_TRACE 1529 callable_fn = sess._make_callable_from_options(callable_opts) 1530 run_metadata = config_pb2.RunMetadata() 1531 self.assertEqual([2.0], callable_fn(np.array(1.0, dtype=np.float32), 1532 run_metadata=run_metadata)) 1533 self.assertGreater(len(run_metadata.step_stats.dev_stats), 0) 1534 1535 def testFeedError(self): 1536 with session.Session() as sess: 1537 feed_t = array_ops.placeholder(dtype=dtypes.float32) 1538 out_t = array_ops.identity(feed_t) 1539 feed_val = constant_op.constant(5.0) 1540 with self.assertRaisesRegex(TypeError, 'cannot be a tf.Tensor object'): 1541 sess.run(out_t, feed_dict={feed_t: feed_val}) 1542 with self.assertRaisesRegex(TypeError, 'cannot be a tf.Tensor object'): 1543 out_t.eval(feed_dict={feed_t: feed_val}) 1544 with self.assertRaisesRegex(TypeError, 'cannot be a tf.Tensor object'): 1545 out_t.op.run(feed_dict={feed_t: feed_val}) 1546 1547 def testFeedPrecisionLossError(self): 1548 with session.Session() as sess: 1549 largest_int64 = np.iinfo(np.int64).max 1550 1551 feed_int_implicit_int32 = constant_op.constant(1) 1552 feed_int_explicit_int32 = constant_op.constant(1, dtype=dtypes.int32) 1553 1554 out_t = constant_op.constant(1.0) 1555 1556 with self.assertRaisesRegex(TypeError, 1557 'is not compatible with Tensor type'): 1558 sess.run(out_t, feed_dict={feed_int_implicit_int32: largest_int64}) 1559 with self.assertRaisesRegex(TypeError, 1560 'is not compatible with Tensor type'): 1561 sess.run(out_t, feed_dict={feed_int_explicit_int32: largest_int64}) 1562 1563 def testStringFetch(self): 1564 with session.Session(): 1565 for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]: 1566 size = 1 1567 for s in shape: 1568 size *= s 1569 c_list = np.array([compat.as_bytes(str(i)) for i in range(size)], 1570 dtype=np.object_).reshape(shape) if size > 0 else [] 1571 c = constant_op.constant(c_list) 1572 self.assertAllEqual(c, c_list) 1573 1574 def testStringFeed(self): 1575 with session.Session() as sess: 1576 for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]: 1577 size = 1 1578 for s in shape: 1579 size *= s 1580 c_list = np.array([compat.as_bytes(str(i)) for i in range(size)], 1581 dtype=np.object_).reshape(shape) 1582 feed_t = array_ops.placeholder(dtype=dtypes.string, shape=shape) 1583 c = array_ops.identity(feed_t) 1584 self.assertAllEqual(sess.run(c, feed_dict={feed_t: c_list}), c_list) 1585 self.assertAllEqual( 1586 sess.run(feed_t, feed_dict={ 1587 feed_t: c_list 1588 }), c_list) 1589 c_v, feed_v = sess.run([c, feed_t], feed_dict={feed_t: c_list}) 1590 self.assertAllEqual(c_v, c_list) 1591 self.assertAllEqual(feed_v, c_list) 1592 1593 def testStringFeedWithNullCharacters(self): 1594 with session.Session(): 1595 c_list = [b'\n\x01\x00', b'\n\x00\x01'] 1596 feed_t = array_ops.placeholder(dtype=dtypes.string, shape=[2]) 1597 c = array_ops.identity(feed_t) 1598 out = c.eval(feed_dict={feed_t: c_list}) 1599 self.assertEqual(c_list[0], out[0]) 1600 self.assertEqual(c_list[1], out[1]) 1601 1602 def testStringFeedWithUnicode(self): 1603 with session.Session(): 1604 c_list = [ 1605 u'\n\x01\x00', u'\n\x00\x01', u'\u26a3 unicode', 1606 u'\U0001f60e deal with it' 1607 ] 1608 feed_t = array_ops.placeholder(dtype=dtypes.string, shape=[len(c_list)]) 1609 c = array_ops.identity(feed_t) 1610 1611 out = c.eval(feed_dict={feed_t: c_list}) 1612 for i in range(len(c_list)): 1613 self.assertEqual(c_list[i], out[i].decode('utf-8')) 1614 1615 out = c.eval(feed_dict={feed_t: np.array(c_list, dtype=np.object_)}) 1616 for i in range(len(c_list)): 1617 self.assertEqual(c_list[i], out[i].decode('utf-8')) 1618 1619 def testInvalidTargetFails(self): 1620 with self.assertRaisesRegex( 1621 errors.NotFoundError, 1622 'No session factory registered for the given session options'): 1623 session.Session('INVALID_TARGET') 1624 1625 def testFetchByNameDifferentStringTypes(self): 1626 with session.Session() as sess: 1627 c = constant_op.constant(42.0, name='c') 1628 d = constant_op.constant(43.0, name=u'd') 1629 e = constant_op.constant(44.0, name=b'e') 1630 f = constant_op.constant(45.0, name=r'f') 1631 1632 self.assertIsInstance(c.name, six.text_type) 1633 self.assertIsInstance(d.name, six.text_type) 1634 self.assertIsInstance(e.name, six.text_type) 1635 self.assertIsInstance(f.name, six.text_type) 1636 1637 self.assertEqual(42.0, sess.run('c:0')) 1638 self.assertEqual(42.0, sess.run(u'c:0')) 1639 self.assertEqual(42.0, sess.run(b'c:0')) 1640 self.assertEqual(42.0, sess.run(r'c:0')) 1641 1642 self.assertEqual(43.0, sess.run('d:0')) 1643 self.assertEqual(43.0, sess.run(u'd:0')) 1644 self.assertEqual(43.0, sess.run(b'd:0')) 1645 self.assertEqual(43.0, sess.run(r'd:0')) 1646 1647 self.assertEqual(44.0, sess.run('e:0')) 1648 self.assertEqual(44.0, sess.run(u'e:0')) 1649 self.assertEqual(44.0, sess.run(b'e:0')) 1650 self.assertEqual(44.0, sess.run(r'e:0')) 1651 1652 self.assertEqual(45.0, sess.run('f:0')) 1653 self.assertEqual(45.0, sess.run(u'f:0')) 1654 self.assertEqual(45.0, sess.run(b'f:0')) 1655 self.assertEqual(45.0, sess.run(r'f:0')) 1656 1657 def testIncorrectGraph(self): 1658 with ops.Graph().as_default() as g_1: 1659 c_1 = constant_op.constant(1.0, name='c') 1660 1661 with ops.Graph().as_default() as g_2: 1662 c_2 = constant_op.constant(2.0, name='c') 1663 1664 self.assertEqual('c', c_1.op.name) 1665 self.assertEqual('c', c_2.op.name) 1666 1667 with session.Session(graph=g_1) as sess_1: 1668 self.assertEqual(1.0, sess_1.run(c_1)) 1669 with self.assertRaises(ValueError): 1670 sess_1.run(c_2) 1671 with self.assertRaises(ValueError): 1672 sess_1.run(c_2.op) 1673 1674 with session.Session(graph=g_2) as sess_2: 1675 with self.assertRaises(ValueError): 1676 sess_2.run(c_1) 1677 with self.assertRaises(ValueError): 1678 sess_2.run(c_1.op) 1679 self.assertEqual(2.0, sess_2.run(c_2)) 1680 1681 def testFeedDictKeyException(self): 1682 with session.Session() as sess: 1683 a = constant_op.constant(1.0, dtypes.float32, name='a') 1684 with self.assertRaisesRegex(TypeError, 'Cannot interpret feed_dict'): 1685 sess.run(a, feed_dict={'a': [2.0]}) 1686 1687 def testPerStepTrace(self): 1688 run_options = config_pb2.RunOptions( 1689 trace_level=config_pb2.RunOptions.SOFTWARE_TRACE) 1690 run_metadata = config_pb2.RunMetadata() 1691 1692 with ops.device('/cpu:0'): 1693 with session.Session() as sess: 1694 sess.run(constant_op.constant(1.0)) 1695 self.assertFalse(run_metadata.HasField('step_stats')) 1696 1697 sess.run(constant_op.constant(1.0), run_metadata=run_metadata) 1698 self.assertFalse(run_metadata.HasField('step_stats')) 1699 1700 sess.run( 1701 constant_op.constant(1.0), 1702 options=run_options, 1703 run_metadata=run_metadata) 1704 1705 self.assertTrue(run_metadata.HasField('step_stats')) 1706 self.assertEqual(len(run_metadata.step_stats.dev_stats), 1) 1707 1708 def testRunOptionsRunMetadata(self): 1709 run_options = config_pb2.RunOptions( 1710 trace_level=config_pb2.RunOptions.SOFTWARE_TRACE) 1711 run_metadata = config_pb2.RunMetadata() 1712 1713 with ops.device('/cpu:0'): 1714 with session.Session() as sess: 1715 # all combinations are valid 1716 sess.run(constant_op.constant(1.0), options=None, run_metadata=None) 1717 sess.run( 1718 constant_op.constant(1.0), options=None, run_metadata=run_metadata) 1719 self.assertFalse(run_metadata.HasField('step_stats')) 1720 1721 sess.run( 1722 constant_op.constant(1.0), options=run_options, run_metadata=None) 1723 self.assertFalse(run_metadata.HasField('step_stats')) 1724 1725 sess.run( 1726 constant_op.constant(1.0), 1727 options=run_options, 1728 run_metadata=run_metadata) 1729 1730 self.assertTrue(run_metadata.HasField('step_stats')) 1731 self.assertEqual(len(run_metadata.step_stats.dev_stats), 1) 1732 1733 def testFeedShapeCompatibility(self): 1734 with session.Session() as sess: 1735 some_tensor = constant_op.constant([2.0, 2.0, 2.0, 2.0]) 1736 new_shape = constant_op.constant([2, 2]) 1737 reshaped_tensor = array_ops.reshape(some_tensor, new_shape) 1738 1739 with self.assertRaisesRegex(ValueError, 'Cannot feed value of shape'): 1740 sess.run(reshaped_tensor, feed_dict={some_tensor: [1.0, 2.0, 3.0]}) 1741 1742 with self.assertRaisesRegex( 1743 errors.InvalidArgumentError, 1744 'Input to reshape is a tensor with 4 values, ' 1745 'but the requested shape has 21'): 1746 sess.run(reshaped_tensor, feed_dict={new_shape: [3, 7]}) 1747 1748 def testInferShapesFalse(self): 1749 with ops.Graph().as_default(), ops.device('/cpu:0'): 1750 a = constant_op.constant([[1, 2]]) 1751 sess = session.Session() 1752 self.assertNotIn('_output_shapes', sess.graph_def.node[0].attr) 1753 # Avoid lint error regarding 'unused' var a. 1754 self.assertEqual(a, a) 1755 1756 def testInferShapesTrue(self): 1757 config_pb = config_pb2.ConfigProto( 1758 graph_options=config_pb2.GraphOptions(infer_shapes=True)) 1759 with ops.Graph().as_default(), ops.device('/cpu:0'): 1760 a = constant_op.constant([[1, 2]]) 1761 sess = session.Session(config=config_pb) 1762 self.assertIn('_output_shapes', sess.graph_def.node[0].attr) 1763 # Avoid lint error regarding 'unused' var a. 1764 self.assertEqual(a, a) 1765 1766 def testBuildCostModel(self): 1767 run_options = config_pb2.RunOptions() 1768 config_pb = config_pb2.ConfigProto( 1769 allow_soft_placement=True, 1770 graph_options=config_pb2.GraphOptions(build_cost_model=100)) 1771 with session.Session(config=config_pb) as sess: 1772 with ops.device('/device:GPU:0'): 1773 a = array_ops.placeholder(dtypes.float32, shape=[]) 1774 b = math_ops.add(a, a) 1775 c = array_ops.identity(b) 1776 d = math_ops.multiply(c, c) 1777 for step in range(120): 1778 run_metadata = config_pb2.RunMetadata() 1779 sess.run( 1780 d, 1781 feed_dict={a: 1.0}, 1782 options=run_options, 1783 run_metadata=run_metadata) 1784 if step == 99: 1785 self.assertTrue(run_metadata.HasField('cost_graph')) 1786 else: 1787 self.assertFalse(run_metadata.HasField('cost_graph')) 1788 1789 def runTestOutputPartitionGraphs(self, sess): 1790 run_options = config_pb2.RunOptions(output_partition_graphs=True) 1791 a = constant_op.constant(1) 1792 run_metadata = config_pb2.RunMetadata() 1793 sess.run(a, options=run_options, run_metadata=run_metadata) 1794 self.assertGreater(len(run_metadata.partition_graphs), 0) 1795 sess.run(a, run_metadata=run_metadata) 1796 self.assertEqual(len(run_metadata.partition_graphs), 0) 1797 1798 @test_util.run_v1_only('b/120545219') 1799 def testOutputPartitionGraphsDirect(self): 1800 self.runTestOutputPartitionGraphs(session.Session()) 1801 1802 @test_util.run_v1_only('b/120545219') 1803 def testOutputPartitionGraphsDistributed(self): 1804 server = server_lib.Server.create_local_server() 1805 self.runTestOutputPartitionGraphs(session.Session(server.target)) 1806 1807 def testNonInteractiveSessionNesting(self): 1808 sess1 = session.Session() 1809 sess1_controller = sess1.as_default() 1810 sess1_controller.__enter__() 1811 1812 sess2 = session.Session() 1813 sess2_controller = sess2.as_default() 1814 sess2_controller.__enter__() 1815 1816 with self.assertRaisesRegex(AssertionError, 'Nesting violated'): 1817 sess1_controller.__exit__(None, None, None) 1818 1819 ops._default_session_stack.reset() 1820 1821 def testInteractiveSessionNesting(self): 1822 sess1 = session.InteractiveSession() 1823 sess2 = session.InteractiveSession() 1824 del sess1 1825 del sess2 1826 1827 @test_util.run_v1_only('b/120545219') 1828 def testAsDefault(self): 1829 c = constant_op.constant(37) 1830 sess = session.Session() 1831 with sess.as_default(): 1832 self.assertEqual(37, c.eval()) 1833 1834 # Ensure that the session remains valid even when it is not captured. 1835 with session.Session().as_default(): 1836 self.assertEqual(37, c.eval()) 1837 1838 def testReentry(self): 1839 sess = session.Session() 1840 with self.assertRaisesRegex(RuntimeError, 'not re-entrant'): 1841 with sess: 1842 with sess: 1843 pass 1844 1845 def testInvalidArgument(self): 1846 with self.assertRaisesRegex(TypeError, 1847 'Argument `target` must be a string'): 1848 session.Session(37) 1849 with self.assertRaisesRegex(TypeError, 1850 'Argument `config` must be a tf.ConfigProto'): 1851 session.Session(config=37) 1852 with self.assertRaisesRegex(TypeError, 1853 'Argument `graph` must be a tf.Graph'): 1854 session.Session(graph=37) 1855 1856 @test_util.run_v1_only('b/120545219') 1857 def testTimeoutWithShortOperations(self): 1858 num_epochs = 5 1859 q = data_flow_ops.FIFOQueue(capacity=50, dtypes=[dtypes.int32], shapes=[()]) 1860 enqueue_op = q.enqueue_many(constant_op.constant([1, 2])) 1861 1862 # Use a 10-second timeout, which should be longer than any 1863 # non-blocking enqueue_many op. 1864 config_pb = config_pb2.ConfigProto(operation_timeout_in_ms=10000) 1865 with session.Session(config=config_pb) as sess: 1866 for _ in range(num_epochs): 1867 sess.run(enqueue_op) 1868 self.assertEqual(sess.run(q.size()), num_epochs * 2) 1869 1870 @test_util.run_v1_only('b/120545219') 1871 def testRegisterFetchAndFeedConversionFunctions(self): 1872 1873 class SquaredTensor(object): 1874 1875 def __init__(self, tensor): 1876 self.sq = math_ops.square(tensor) 1877 1878 fetch_fn = lambda squared_tensor: ([squared_tensor.sq], lambda val: val[0]) 1879 feed_fn1 = lambda feed, feed_val: [(feed.sq, feed_val)] 1880 feed_fn2 = lambda feed: [feed.sq] 1881 1882 session.register_session_run_conversion_functions(SquaredTensor, fetch_fn, 1883 feed_fn1, feed_fn2) 1884 with self.assertRaises(ValueError): 1885 session.register_session_run_conversion_functions(SquaredTensor, fetch_fn, 1886 feed_fn1, feed_fn2) 1887 with self.cached_session() as sess: 1888 np1 = np.array([1.0, 1.5, 2.0, 2.5]) 1889 np2 = np.array([3.0, 3.5, 4.0, 4.5]) 1890 squared_tensor = SquaredTensor(np2) 1891 squared_eval = sess.run(squared_tensor) 1892 self.assertAllClose(np2 * np2, squared_eval) 1893 squared_eval = sess.run( 1894 squared_tensor, feed_dict={ 1895 squared_tensor: np1 * np1 1896 }) 1897 self.assertAllClose(np1 * np1, squared_eval) 1898 partial_run = sess.partial_run_setup([squared_tensor], []) 1899 squared_eval = sess.partial_run(partial_run, squared_tensor) 1900 self.assertAllClose(np2 * np2, squared_eval) 1901 1902 def testDefaultLogDevicePlacement(self): 1903 1904 class CaptureStderr(str): 1905 """Class to capture stderr from C++ shared library.""" 1906 1907 def __enter__(self): 1908 self._esc = compat.as_str('\b') 1909 self._output = compat.as_str('') 1910 self._stderr = sys.stderr 1911 self._fd = self._stderr.fileno() 1912 self._out_pipe, in_pipe = os.pipe() 1913 # Save the original io stream. 1914 self._dup_fd = os.dup(self._fd) 1915 # Replace the original io stream with in pipe. 1916 os.dup2(in_pipe, self._fd) 1917 return self 1918 1919 def __exit__(self, *args): 1920 self._stderr.write(self._esc) 1921 self._stderr.flush() 1922 self.read() 1923 os.close(self._out_pipe) 1924 # Restore the original io stream. 1925 os.dup2(self._dup_fd, self._fd) 1926 1927 def read(self): 1928 while True: 1929 data = os.read(self._out_pipe, 1) 1930 if not data or compat.as_str(data) == self._esc: 1931 break 1932 self._output += compat.as_str(data) 1933 1934 def __str__(self): 1935 return self._output 1936 1937 context.set_log_device_placement(True) 1938 if context.executing_eagerly(): 1939 with CaptureStderr() as log: 1940 a = constant_op.constant(1) 1941 b = constant_op.constant(2) 1942 c = a + b 1943 # Ensure if the same kernel with the same arguments is executed then its 1944 # execution is logged. 1945 d = a + b 1946 else: 1947 # Passing the config to the server, but not the session should still 1948 # result in logging device placement. 1949 config_pb = config_pb2.ConfigProto(log_device_placement=True) 1950 server = server_lib.Server.create_local_server(config=config_pb) 1951 a = constant_op.constant(1) 1952 b = constant_op.constant(2) 1953 c = a + b 1954 d = a + b 1955 with session.Session(server.target) as sess: 1956 with CaptureStderr() as log: 1957 c, d = sess.run([c, d]) 1958 1959 self.assertEqual(c, 3) 1960 self.assertEqual(d, 3) 1961 1962 # Ensure that we did log device placement. 1963 # We have three modes of execution at the moment: 1964 # (1) TF1 Graph (2) TF2 eager (3) TF2 eager with function wrapping. 1965 # The codepaths taken by each are slightly different in all resulting in 1966 # slightly different logging messages. 1967 log_msg = ('Executing op AddV2' 1968 if ops.executing_eagerly_outside_functions() else 'AddV2') 1969 add_executions = [l for l in str(log).splitlines() if log_msg in l] 1970 self.assertEqual(len(add_executions), 2) 1971 1972 @def_function.function 1973 def fn(a, b): 1974 c = a + b 1975 # These two AddV2 cannot use the same argument in tf.function since an 1976 # optimization pass will remove duplicate ops and only run it once. 1977 d = a + c 1978 return c, d 1979 1980 with CaptureStderr() as log: 1981 c, d = self.evaluate(fn(constant_op.constant(1), constant_op.constant(2))) 1982 self.assertEqual(c, 3) 1983 self.assertEqual(d, 4) 1984 # Ensure that we did log device placement. 1985 add_executions = [l for l in str(log).splitlines() if 'AddV2' in l] 1986 self.assertEqual(len(add_executions), 2) 1987 1988 @test_util.run_v1_only('b/120545219') 1989 def testLocalMasterSessionTimeout(self): 1990 # Test that the timeout passed in a config to the session works correctly. 1991 config_pb = config_pb2.ConfigProto(operation_timeout_in_ms=1000) 1992 server = server_lib.Server.create_local_server() 1993 q = data_flow_ops.FIFOQueue(1, dtypes.float32) 1994 dequeued_t = q.dequeue() 1995 1996 with session.Session(server.target, config=config_pb) as sess: 1997 # Intentionally do not run any enqueue_ops so that dequeue will block 1998 # until operation_timeout_in_ms. 1999 with self.assertRaises(errors.DeadlineExceededError): 2000 sess.run(dequeued_t) 2001 2002 @test_util.run_v1_only('b/120545219') 2003 def testDefaultServerTimeout(self): 2004 # Test that the default server config timeout gets used when no Session 2005 # config is provided. 2006 config_pb = config_pb2.ConfigProto(operation_timeout_in_ms=1000) 2007 server = server_lib.Server.create_local_server(config=config_pb) 2008 q = data_flow_ops.FIFOQueue(1, dtypes.float32) 2009 dequeued_t = q.dequeue() 2010 2011 with session.Session(server.target) as sess: 2012 # Intentionally do not run any enqueue_ops so that dequeue will block 2013 # until operation_timeout_in_ms. 2014 with self.assertRaises(errors.DeadlineExceededError): 2015 sess.run(dequeued_t) 2016 2017 def runTestBuildGraphError(self, sess): 2018 # Ensure that errors from building the graph get propagated. 2019 data = array_ops.placeholder(dtypes.float32, shape=[]) 2020 # pylint: disable=protected-access 2021 enter_1 = gen_control_flow_ops.enter(data, 'foo_1', False) 2022 enter_2 = gen_control_flow_ops.enter(data, 'foo_2', False) 2023 # pylint: enable=protected-access 2024 res = math_ops.add(enter_1, enter_2) 2025 with self.assertRaisesOpError('has inputs from different frames'): 2026 sess.run(res, feed_dict={data: 1.0}) 2027 2028 @test_util.run_v1_only('b/120545219') 2029 def testBuildGraphErrorDirect(self): 2030 self.runTestBuildGraphError(session.Session()) 2031 2032 @test_util.run_v1_only('b/120545219') 2033 def testBuildGraphErrorDist(self): 2034 server = server_lib.Server.create_local_server() 2035 self.runTestBuildGraphError(session.Session(server.target)) 2036 2037 def testDeviceAttributes(self): 2038 attrs = session._DeviceAttributes( 2039 '/job:worker/replica:0/task:3/device:CPU:2', 'TYPE', 1337, 1000000) 2040 self.assertEqual(1337, attrs.memory_limit_bytes) 2041 self.assertEqual('/job:worker/replica:0/task:3/device:CPU:2', attrs.name) 2042 self.assertEqual('TYPE', attrs.device_type) 2043 self.assertEqual(1000000, attrs.incarnation) 2044 str_repr = '%s' % attrs 2045 self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr) 2046 2047 def testDeviceAttributesCanonicalization(self): 2048 attrs = session._DeviceAttributes('/job:worker/replica:0/task:3/cpu:1', 2049 'TYPE', 1337, 1000000) 2050 self.assertEqual(1337, attrs.memory_limit_bytes) 2051 self.assertEqual('/job:worker/replica:0/task:3/device:CPU:1', attrs.name) 2052 self.assertEqual('TYPE', attrs.device_type) 2053 self.assertEqual(1000000, attrs.incarnation) 2054 str_repr = '%s' % attrs 2055 self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr) 2056 2057 def runTestAddFunctionToSession(self, target=''): 2058 """Add a function to a session after the graph has already been run.""" 2059 2060 @function.Defun(dtypes.float32) 2061 def foo(x): 2062 return x + 1 2063 2064 x = constant_op.constant(1.0) 2065 with session.Session(target=target) as sess: 2066 sess.run(x) 2067 f = foo(x) 2068 result = sess.run(f) 2069 self.assertEqual(result, 2.0) 2070 2071 @test_util.run_v1_only('b/120545219') 2072 def testAddFunctionToSession(self): 2073 self.runTestAddFunctionToSession() 2074 2075 @test_util.run_v1_only('b/120545219') 2076 def testAddFunctionToGrpcSession(self): 2077 server = server_lib.Server.create_local_server() 2078 self.runTestAddFunctionToSession(server.target) 2079 2080 def testOpenAndCloseGrpcSession(self): 2081 server = server_lib.Server.create_local_server() 2082 with session.Session(server.target): 2083 pass 2084 2085 def testOpenAndCloseSession(self): 2086 with session.Session(): 2087 pass 2088 2089 @test_util.run_v1_only('b/120545219') 2090 def testAutoConvertAndCheckData(self): 2091 with self.cached_session() as sess: 2092 a = array_ops.placeholder(dtype=dtypes.string) 2093 with self.assertRaisesRegex( 2094 TypeError, r'Type of feed value 1 with type <(\w+) \'int\'> is not'): 2095 sess.run(a, feed_dict={a: 1}) 2096 2097 @test_util.run_v1_only('b/120545219') 2098 def testOptimizerOptions(self): 2099 config.set_optimizer_experimental_options({'min_graph_nodes': -1}) 2100 2101 with ops.Graph().as_default(): 2102 sess = session.Session() 2103 self.assertEqual( 2104 sess._config.graph_options.rewrite_options.min_graph_nodes, -1) 2105 2106 2107if __name__ == '__main__': 2108 googletest.main() 2109