xref: /aosp_15_r20/external/tensorflow/tensorflow/python/client/session_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.python.client.session.Session."""
16import collections
17import os
18import random
19import sys
20import threading
21import time
22import warnings
23
24import numpy as np
25import six
26
27from tensorflow.core.framework import attr_value_pb2
28from tensorflow.core.lib.core import error_codes_pb2
29from tensorflow.core.protobuf import config_pb2
30from tensorflow.python.client import session
31from tensorflow.python.eager import context
32from tensorflow.python.eager import def_function
33from tensorflow.python.framework import config
34from tensorflow.python.framework import constant_op
35from tensorflow.python.framework import device as framework_device_lib
36from tensorflow.python.framework import dtypes
37from tensorflow.python.framework import errors
38from tensorflow.python.framework import function
39from tensorflow.python.framework import importer
40from tensorflow.python.framework import indexed_slices
41from tensorflow.python.framework import ops
42from tensorflow.python.framework import sparse_tensor
43from tensorflow.python.framework import tensor_util
44from tensorflow.python.framework import test_util
45from tensorflow.python.framework import versions
46from tensorflow.python.ops import array_ops
47from tensorflow.python.ops import control_flow_ops
48from tensorflow.python.ops import data_flow_ops
49from tensorflow.python.ops import gen_control_flow_ops
50# Import gradients to resolve circular imports
51from tensorflow.python.ops import gradients  # pylint: disable=unused-import
52from tensorflow.python.ops import gradients_impl
53from tensorflow.python.ops import math_ops
54# Import resource_variable_ops for the variables-to-tensor implicit conversion.
55from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unused-import
56from tensorflow.python.ops import state_ops
57from tensorflow.python.ops import variables
58from tensorflow.python.platform import googletest
59from tensorflow.python.training import server_lib
60from tensorflow.python.util import compat
61
62try:
63  import attr  # pylint:disable=g-import-not-at-top
64except ImportError:
65  attr = None
66
67try:
68  from frozendict import frozendict  # pylint:disable=g-import-not-at-top
69except ImportError:
70  frozendict = dict  # pylint:disable=invalid-name
71
72defaultdict = collections.defaultdict  # pylint:disable=invalid-name
73
74
75@test_util.with_eager_op_as_function
76class SessionTest(test_util.TensorFlowTestCase):
77
78  def setUp(self):
79    super(SessionTest, self).setUp()
80    warnings.simplefilter('always')
81
82  def testUseExistingGraph(self):
83    with ops.Graph().as_default() as g, ops.device('/cpu:0'):
84      a = constant_op.constant(6.0, shape=[1, 1])
85      b = constant_op.constant(7.0, shape=[1, 1])
86      c = math_ops.matmul(a, b, name='matmul')
87    with session.Session(graph=g):
88      result = c.eval()
89      self.assertAllEqual(result, [[42.0]])
90
91  def testUseDefaultGraph(self):
92    with ops.Graph().as_default(), ops.device('/cpu:0'):
93      a = constant_op.constant(6.0, shape=[1, 1])
94      b = constant_op.constant(7.0, shape=[1, 1])
95      c = math_ops.matmul(a, b, name='matmul')
96      with session.Session():
97        result = c.eval()
98        self.assertAllEqual(result, [[42.0]])
99
100  def testCreate(self):
101    with session.Session():
102      inp = constant_op.constant(10.0, shape=[2, 3], name='W1')
103      copy = array_ops.identity(inp)
104      # Test with feed.
105      # TODO(mrry): Investigate why order='F' didn't work.
106      arr = np.asarray([[0, 1, 2], [3, 4, 5]], dtype=np.float32, order='C')
107      copy_val = copy.eval({'W1:0': arr})
108      self.assertAllEqual(arr, copy_val)
109      # Test without feed.
110      copy_val = copy.eval()
111      self.assertAllEqual(
112          np.asarray(
113              [[10.0, 10.0, 10.0], [10.0, 10.0, 10.0]], dtype=np.float32),
114          copy_val)
115
116  def testManyCPUs(self):
117    with session.Session(
118        config=config_pb2.ConfigProto(device_count={
119            'CPU': 2, 'GPU': 0
120        })) as sess:
121      inp = constant_op.constant(10.0, name='W1')
122      self.assertAllEqual(inp, 10.0)
123
124      num_cpu_devices = 0
125      num_gpu_devices = 0
126      for device in sess.list_devices():
127        device_type = framework_device_lib.DeviceSpec.from_string(
128            device.name).device_type
129        if device_type == 'CPU':
130          num_cpu_devices += 1
131        elif device_type == 'GPU':
132          num_gpu_devices += 1
133      self.assertEqual(2, num_cpu_devices)
134      self.assertEqual(0, num_gpu_devices)
135
136  def testPerSessionThreads(self):
137    with session.Session(
138        config=config_pb2.ConfigProto(use_per_session_threads=True)):
139      inp = constant_op.constant(10.0, name='W1')
140      self.assertAllEqual(inp, 10.0)
141
142  def testSessionInterOpThreadPool(self):
143    config_pb = config_pb2.ConfigProto()
144    pool = config_pb.session_inter_op_thread_pool.add()
145    with session.Session(config=config_pb) as s:
146      inp = constant_op.constant(10.0, name='W1')
147      results = s.run([inp])
148      self.assertAllEqual([10.0], results)
149
150    pool = config_pb.session_inter_op_thread_pool.add()
151    pool.num_threads = 1
152    with session.Session(config=config_pb) as s:
153      inp = constant_op.constant(20.0, name='W2')
154      results = s.run([inp])
155      self.assertAllEqual([20.0], results)
156
157    pool = config_pb.session_inter_op_thread_pool.add()
158    pool.num_threads = 1
159    pool.global_name = 't1'
160    run_options = config_pb2.RunOptions()
161    run_options.inter_op_thread_pool = (
162        len(config_pb.session_inter_op_thread_pool) - 1)
163    with session.Session(config=config_pb) as s:
164      inp = constant_op.constant(30.0, name='W2')
165      results = s.run([inp], options=run_options)
166      self.assertAllEqual([30.0], results)
167
168  def testErrorsReported(self):
169    with session.Session() as s:
170      constant_op.constant(10.0, name='W1')
171      with self.assertRaises(ValueError):
172        s.run('foo:0')
173
174  def testErrorPayload(self):
175    with session.Session():
176      a = array_ops.placeholder(dtypes.float32)
177      with self.assertRaisesOpError(lambda e: e.op == a.op):
178        a.eval()
179
180  def testErrorCodeWithNoNodeDef(self):
181    with session.Session() as s:
182      a = array_ops.placeholder(dtypes.float32, shape=[])
183      b = array_ops.placeholder(dtypes.float32, shape=[])
184      r1 = math_ops.add(a, b)
185
186      def exc_predicate(e):
187        return (e.op is None and e.node_def is None and
188                e.error_code == error_codes_pb2.INVALID_ARGUMENT)
189
190      with self.assertRaisesOpError(exc_predicate):
191        # Run with a bogus handle.
192        s.partial_run('foo', r1, feed_dict={a: 1, b: 2})
193
194  def testErrorBasedOn(self):
195    with session.Session() as sess:
196      a = constant_op.constant(0.0, shape=[2, 3])
197      # NOTE(mrry): The original_op is nonsense, but used here to test that the
198      #   errors are reported correctly.
199      with sess.graph._original_op(a.op):
200        b = array_ops.identity(a, name='id')
201      with sess.graph._original_op(b.op):
202        c = array_ops.placeholder(dtypes.float32)
203
204      def exc_predicate(e):
205        return (e.op == c.op and e.op._original_op == b.op and
206                e.op._original_op._original_op == a.op)
207
208      with self.assertRaisesOpError(exc_predicate):
209        c.eval()
210
211  def testFetchNone(self):
212    with session.Session() as s:
213      a = constant_op.constant(1.0)
214      with self.assertRaises(TypeError):
215        s.run(None)
216      with self.assertRaises(TypeError):
217        s.run([None])
218      with self.assertRaises(TypeError):
219        s.run({'b': None})
220      with self.assertRaises(TypeError):
221        s.run({'a': a, 'b': None})
222
223  def testFetchSingleton(self):
224    with session.Session() as sess:
225      a = constant_op.constant(42.0)
226      res = sess.run(a)
227      self.assertEqual(42.0, res)
228      res = sess.run(a.op)  # An op, not a tensor.
229      self.assertIsNone(res)
230      tensor_runner = sess.make_callable(a)
231      res = tensor_runner()
232      self.assertEqual(42.0, res)
233      op_runner = sess.make_callable(a.op)
234      res = op_runner()
235      self.assertIsNone(res)
236
237  def testFetchSingletonByName(self):
238    with session.Session() as sess:
239      a = constant_op.constant(42.0)
240      res = sess.run(a.name)
241      self.assertEqual(42.0, res)
242      res = sess.run(a.op)  # An op, not a tensor.
243      self.assertIsNone(res)
244
245  def testFetchList(self):
246    with session.Session() as sess:
247      a = constant_op.constant(42.0)
248      b = control_flow_ops.no_op()  # An op, not a tensor.
249      c = constant_op.constant(44.0)
250      v = variables.Variable([54.0])
251      assign = v.assign([63.0])
252      res = sess.run([a, b, c, a.name, assign.op])
253      self.assertIsInstance(res, list)
254      self.assertEqual([42.0, None, 44.0, 42.0, None], res)
255      list_runner = sess.make_callable([a, b, c, a.name, assign.op])
256      res = list_runner()
257      self.assertIsInstance(res, list)
258      self.assertEqual([42.0, None, 44.0, 42.0, None], res)
259
260  def testFetchTuple(self):
261    with session.Session() as sess:
262      a = constant_op.constant(42.0)
263      b = control_flow_ops.no_op()  # An op, not a tensor.
264      c = constant_op.constant(44.0)
265      res = sess.run((a, b, c, a.name))
266      self.assertIsInstance(res, tuple)
267      self.assertEqual((42.0, None, 44.0, 42.0), res)
268      tuple_runner = sess.make_callable((a, b, c, a.name))
269      res = tuple_runner()
270      self.assertIsInstance(res, tuple)
271      self.assertEqual((42.0, None, 44.0, 42.0), res)
272
273  def testFetchNamedTuple(self):
274    # pylint: disable=invalid-name
275    ABC = collections.namedtuple('ABC', ['a', 'b', 'c'])
276    # pylint: enable=invalid-name
277    with session.Session() as sess:
278      a = constant_op.constant(42.0)
279      b = control_flow_ops.no_op()  # An op, not a tensor.
280      c = constant_op.constant(44.0)
281      res = sess.run(ABC(a, b, c))
282      self.assertIsInstance(res, ABC)
283      self.assertEqual(42.0, res.a)
284      self.assertIsNone(res.b)
285      self.assertEqual(44.0, res.c)
286      namedtuple_runner = sess.make_callable(ABC(a, b, c))
287      res = namedtuple_runner()
288      self.assertIsInstance(res, ABC)
289      self.assertEqual(42.0, res.a)
290      self.assertIsNone(res.b)
291      self.assertEqual(44.0, res.c)
292
293  def testFetchDict(self):
294    with session.Session() as sess:
295      a = constant_op.constant(42.0)
296      b = control_flow_ops.no_op()  # An op, not a tensor.
297      c = constant_op.constant(44.0)
298      res = sess.run({'a': a, 'b': b, 'c': c})
299      self.assertIsInstance(res, dict)
300      self.assertEqual(42.0, res['a'])
301      self.assertIsNone(res['b'])
302      self.assertEqual(44.0, res['c'])
303
304  def testFetchOrderedDict(self):
305    with session.Session() as sess:
306      a = constant_op.constant(42.0)
307      b = control_flow_ops.no_op()  # An op, not a tensor.
308      c = constant_op.constant(44.0)
309      res = sess.run(collections.OrderedDict([(3, a), (2, b), (1, c)]))
310      self.assertIsInstance(res, collections.OrderedDict)
311      self.assertEqual([3, 2, 1], list(res.keys()))
312      self.assertEqual(42.0, res[3])
313      self.assertIsNone(res[2])
314      self.assertEqual(44.0, res[1])
315
316  @test_util.run_v1_only('b/120545219')
317  def testFetchAttrs(self):
318    if attr is None:
319      self.skipTest('attr module is unavailable.')
320
321    @attr.s
322    class SampleAttr(object):
323      field1 = attr.ib()
324      field2 = attr.ib()
325
326    val1 = np.array([1.2, 3.4, 5.6])
327    val2 = np.array([[1, 2], [4, 3]])
328    val3 = np.array([10, 20, 30])
329
330    t1 = constant_op.constant(val1)
331    t2 = constant_op.constant(val2)
332
333    sample = SampleAttr(t1, t2)
334    with session.Session() as sess:
335      result = sess.run(sample)
336      self.assertIsInstance(result, SampleAttr)
337      self.assertAllEqual(val1, result.field1)
338      self.assertAllEqual(val2, result.field2)
339
340      result = sess.run(sample, feed_dict={sample.field1: val3})
341      self.assertIsInstance(result, SampleAttr)
342      self.assertAllEqual(val3, result.field1)
343      self.assertAllEqual(val2, result.field2)
344
345  @test_util.run_v1_only('b/120545219')
346  def testFetchNestedAttrs(self):
347    if attr is None:
348      self.skipTest('attr module is unavailable.')
349
350    @attr.s
351    class SampleAttr(object):
352      field0 = attr.ib()
353      field1 = attr.ib()
354
355    v1 = 10
356    v2 = 20
357    v3 = np.float32(1.2)
358    v4 = np.float32(3.4)
359    v5 = np.float64(100.001)
360    v6 = np.float64(-23.451)
361    arr1 = np.array([1.2, 6.7, 3.4])
362    arr2 = np.array([7, 11, 3])
363    sample = SampleAttr(
364        SampleAttr(
365            SampleAttr(constant_op.constant(v1), constant_op.constant(v2)),
366            SampleAttr(constant_op.constant(arr1), constant_op.constant(arr2))),
367        {'A': SampleAttr(constant_op.constant(v3), constant_op.constant(v4)),
368         'B': [SampleAttr(constant_op.constant(v5), constant_op.constant(v6))]})
369
370    with session.Session() as sess:
371      result = sess.run(sample)
372      self.assertIsInstance(result, SampleAttr)
373      self.assertIsInstance(result.field0, SampleAttr)
374      self.assertIsInstance(result.field0.field0, SampleAttr)
375      self.assertIsInstance(result.field0.field1, SampleAttr)
376      self.assertIsInstance(result.field0.field1.field0, np.ndarray)
377      self.assertAllEqual(arr1, result.field0.field1.field0)
378      self.assertIsInstance(result.field0.field1.field1, np.ndarray)
379      self.assertAllEqual(arr2, result.field0.field1.field1)
380      self.assertIsInstance(result.field1, dict)
381      self.assertIn('A', result.field1)
382      self.assertIn('B', result.field1)
383      self.assertIsInstance(result.field1['A'], SampleAttr)
384      self.assertAllEqual(
385          [v3, v4],
386          [result.field1['A'].field0, result.field1['A'].field1])
387      self.assertIsInstance(result.field1['B'], list)
388      self.assertEqual(1, len(result.field1['B']))
389      self.assertIsInstance(result.field1['B'][0], SampleAttr)
390      self.assertAllEqual(
391          [v5, v6],
392          [result.field1['B'][0].field0, result.field1['B'][0].field1])
393
394  def testFetchNestingEmptyOneLevel(self):
395    with session.Session() as sess:
396      a_val = 11.0
397      a = constant_op.constant(a_val)
398
399      res = sess.run([[], tuple(), {}])
400      self.assertIsInstance(res, list)
401      self.assertEqual(3, len(res))
402      self.assertIsInstance(res[0], list)
403      self.assertEqual(0, len(res[0]))
404      self.assertIsInstance(res[1], tuple)
405      self.assertEqual(0, len(res[1]))
406      self.assertIsInstance(res[2], dict)
407      self.assertEqual(0, len(res[2]))
408
409      res = sess.run([[], tuple(), {}, a])
410      self.assertIsInstance(res, list)
411      self.assertEqual(4, len(res))
412      self.assertIsInstance(res[0], list)
413      self.assertEqual(0, len(res[0]))
414      self.assertIsInstance(res[1], tuple)
415      self.assertEqual(0, len(res[1]))
416      self.assertIsInstance(res[2], dict)
417      self.assertEqual(0, len(res[2]))
418      self.assertEqual(a_val, res[3])
419
420  def testFetchNestingOneLevel(self):
421    with session.Session() as sess:
422      # pylint: disable=invalid-name
423      ABC = collections.namedtuple('ABC', ['a', 'b', 'c'])
424      DEFGHI = collections.namedtuple('DEFGHI', ['d', 'e', 'f', 'g', 'h', 'i'])
425      # pylint: enable=invalid-name
426      a_val = 42.0
427      b_val = None
428      c_val = 44.0
429      a = constant_op.constant(a_val)
430      b = control_flow_ops.no_op()  # An op, not a tensor.
431      c = constant_op.constant(c_val)
432      test_dct = {'a': a.name, 'c': c, 'b': b}
433      test_dct_types = [dict, frozendict, defaultdict]
434      # List of lists, tuples, namedtuple, dict, frozendict, and defaultdict
435      res = sess.run([
436          [a, b, c],
437          (a, b, c),
438          ABC(a=a, b=b, c=c),
439          dict(test_dct),
440          frozendict(test_dct),
441          defaultdict(str, test_dct),
442      ])
443      self.assertIsInstance(res, list)
444      self.assertEqual(6, len(res))
445      self.assertIsInstance(res[0], list)
446      self.assertEqual(3, len(res[0]))
447      self.assertEqual(a_val, res[0][0])
448      self.assertEqual(b_val, res[0][1])
449      self.assertEqual(c_val, res[0][2])
450      self.assertIsInstance(res[1], tuple)
451      self.assertEqual(3, len(res[1]))
452      self.assertEqual(a_val, res[1][0])
453      self.assertEqual(b_val, res[1][1])
454      self.assertEqual(c_val, res[1][2])
455      self.assertIsInstance(res[2], ABC)
456      self.assertEqual(a_val, res[2].a)
457      self.assertEqual(b_val, res[2].b)
458      self.assertEqual(c_val, res[2].c)
459      for expected_type, r in zip(test_dct_types, res[3:]):
460        self.assertIsInstance(r, expected_type)
461        self.assertEqual(3, len(r))
462        self.assertEqual(a_val, r['a'])
463        self.assertEqual(b_val, r['b'])
464        self.assertEqual(c_val, r['c'])
465      self.assertEqual(res[5].default_factory, str)
466      # Tuple of lists, tuples, namedtuple, dict, frozendict, and defaultdict
467      res = sess.run(([a, b, c], (a.name, b, c), ABC(a=a, b=b,
468                                                     c=c), dict(test_dct),
469                      frozendict(test_dct), defaultdict(str, test_dct)))
470      self.assertIsInstance(res, tuple)
471      self.assertEqual(6, len(res))
472      self.assertIsInstance(res[0], list)
473      self.assertEqual(3, len(res[0]))
474      self.assertEqual(a_val, res[0][0])
475      self.assertEqual(b_val, res[0][1])
476      self.assertEqual(c_val, res[0][2])
477      self.assertIsInstance(res[1], tuple)
478      self.assertEqual(3, len(res[1]))
479      self.assertEqual(a_val, res[1][0])
480      self.assertEqual(b_val, res[1][1])
481      self.assertEqual(c_val, res[1][2])
482      self.assertIsInstance(res[2], ABC)
483      self.assertEqual(a_val, res[2].a)
484      self.assertEqual(b_val, res[2].b)
485      self.assertEqual(c_val, res[2].c)
486      for expected_type, r in zip(test_dct_types, res[3:]):
487        self.assertIsInstance(r, expected_type)
488        self.assertEqual(3, len(r))
489        self.assertEqual(a_val, r['a'])
490        self.assertEqual(b_val, r['b'])
491        self.assertEqual(c_val, r['c'])
492      self.assertEqual(res[5].default_factory, str)
493
494      # Namedtuple of lists, tuples, namedtuples, dict, frozendict, defaultdict
495      res = sess.run(
496          DEFGHI(
497              d=[a, b, c],
498              e=(a, b, c),
499              f=ABC(a=a.name, b=b, c=c),
500              g=dict(test_dct),
501              h=frozendict(test_dct),
502              i=defaultdict(str, test_dct)))
503      self.assertIsInstance(res, DEFGHI)
504      self.assertIsInstance(res.d, list)
505      self.assertEqual(3, len(res.d))
506      self.assertEqual(a_val, res.d[0])
507      self.assertEqual(b_val, res.d[1])
508      self.assertEqual(c_val, res.d[2])
509      self.assertIsInstance(res.e, tuple)
510      self.assertEqual(3, len(res.e))
511      self.assertEqual(a_val, res.e[0])
512      self.assertEqual(b_val, res.e[1])
513      self.assertEqual(c_val, res.e[2])
514      self.assertIsInstance(res.f, ABC)
515      self.assertEqual(a_val, res.f.a)
516      self.assertEqual(b_val, res.f.b)
517      self.assertEqual(c_val, res.f.c)
518      self.assertIsInstance(res.g, dict)
519      self.assertEqual(3, len(res.g))
520      self.assertEqual(a_val, res.g['a'])
521      self.assertEqual(b_val, res.g['b'])
522      self.assertEqual(c_val, res.g['c'])
523      self.assertIsInstance(res.h, frozendict)
524      self.assertEqual(3, len(res.h))
525      self.assertEqual(a_val, res.h['a'])
526      self.assertEqual(b_val, res.h['b'])
527      self.assertEqual(c_val, res.h['c'])
528      self.assertIsInstance(res.i, defaultdict)
529      self.assertEqual(3, len(res.i))
530      self.assertEqual(a_val, res.i['a'])
531      self.assertEqual(b_val, res.i['b'])
532      self.assertEqual(c_val, res.i['c'])
533      self.assertEqual(res.i.default_factory, str)
534      # Dict of lists, tuples, namedtuples, dict, frozendict, defaultdict
535      res = sess.run({
536          'd': [a, b, c],
537          'e': (a, b, c),
538          'f': ABC(a=a, b=b, c=c),
539          'g': dict(test_dct),
540          'h': frozendict(test_dct),
541          'i': defaultdict(str, test_dct),
542      })
543      self.assertIsInstance(res, dict)
544      self.assertEqual(6, len(res))
545      self.assertIsInstance(res['d'], list)
546      self.assertEqual(3, len(res['d']))
547      self.assertEqual(a_val, res['d'][0])
548      self.assertEqual(b_val, res['d'][1])
549      self.assertEqual(c_val, res['d'][2])
550      self.assertIsInstance(res['e'], tuple)
551      self.assertEqual(3, len(res['e']))
552      self.assertEqual(a_val, res['e'][0])
553      self.assertEqual(b_val, res['e'][1])
554      self.assertEqual(c_val, res['e'][2])
555      self.assertIsInstance(res['f'], ABC)
556      self.assertEqual(a_val, res['f'].a)
557      self.assertEqual(b_val, res['f'].b)
558      self.assertEqual(c_val, res['f'].c)
559      for expected_type, r_key in zip(test_dct_types, ('g', 'h', 'i')):
560        r = res[r_key]
561        self.assertIsInstance(r, expected_type)
562        self.assertEqual(3, len(r))
563        self.assertEqual(a_val, r['a'])
564        self.assertEqual(b_val, r['b'])
565        self.assertEqual(c_val, r['c'])
566      self.assertEqual(res['i'].default_factory, str)
567
568  def testFetchTensorObject(self):
569    with session.Session() as s:
570      a = constant_op.constant(1.0, shape=[1, 2])
571      b = constant_op.constant(2.0, shape=[2, 3])
572      c = math_ops.matmul(a, b)
573      results_with_list = s.run([c])
574      self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_list[0])
575      results_with_single = s.run(c)
576      self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_single)
577      results_with_get = c.eval()
578      self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_get)
579      a_val, b_val = s.run([a, b])  # Test multiple fetches.
580      self.assertAllEqual([[1.0, 1.0]], a_val)
581      self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], b_val)
582      results_with_dict = s.run({'a': [a], 'b': b, 'z': [a, b]})
583      self.assertAllEqual([[1.0, 1.0]], results_with_dict['a'][0])
584      self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]],
585                          results_with_dict['b'])
586      self.assertAllEqual(results_with_dict['a'][0], results_with_dict['z'][0])
587      self.assertAllEqual(results_with_dict['b'], results_with_dict['z'][1])
588
589      # Test nested structures
590      results_with_nested_list = s.run([[[a, b], b], a, [a, b]])
591      self.assertAllEqual([[1.0, 1.0]], results_with_nested_list[0][0][0])
592      self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]],
593                          results_with_nested_list[0][0][1])
594      self.assertAllEqual(results_with_nested_list[0][0][0],
595                          results_with_nested_list[1])
596      self.assertAllEqual(results_with_nested_list[1],
597                          results_with_nested_list[2][0])
598      self.assertAllEqual(results_with_nested_list[0][0][1],
599                          results_with_nested_list[0][1])
600      self.assertAllEqual(results_with_nested_list[0][1],
601                          results_with_nested_list[2][1])
602
603  def testFetchScalar(self):
604    with session.Session() as s:
605      for scalar in np.int32, np.int64, np.float16, np.float32, np.float64:
606        x = scalar(7)
607        y = scalar(8)
608        tf_x = constant_op.constant(x, shape=[])
609        tf_y = constant_op.constant(y)
610        tf_xy = math_ops.add(tf_x, tf_y)
611        # Single fetch
612        xy = s.run(tf_xy)
613        self.assertEqual(scalar, type(xy))
614        self.assertEqual(x + y, xy)
615        # List fetch
616        xy, = s.run([tf_xy])
617        self.assertEqual(scalar, type(xy))
618        self.assertEqual(x + y, xy)
619        # Dict fetch
620        xy = s.run({'xy': tf_xy})['xy']
621        self.assertEqual(scalar, type(xy))
622        self.assertEqual(x + y, xy)
623        # Nested list fetch
624        xy = s.run([[[tf_xy]], tf_xy, [tf_xy]])
625        self.assertAllEqual(xy, [[[x + y]], x + y, [x + y]])
626        self.assertEqual(scalar, type(xy[0][0][0]))
627        self.assertEqual(scalar, type(xy[1]))
628        self.assertEqual(scalar, type(xy[2][0]))
629
630  def testFetchOperationObject(self):
631    with session.Session() as s:
632      a = constant_op.constant(1.0, shape=[1, 2])
633      v = variables.Variable(a, name='testFetchOperationObject_v')
634      s.run(v.initializer)
635      v_val = s.run(v)
636      self.assertAllEqual([[1.0, 1.0]], v_val)
637
638  def testFetchSparseTensor(self):
639    with session.Session() as s:
640      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
641      values = np.array([1.0, 2.0]).astype(np.float32)
642      shape = np.array([7, 9, 2]).astype(np.int64)
643      sp = sparse_tensor.SparseTensor(
644          constant_op.constant(indices), constant_op.constant(values),
645          constant_op.constant(shape))
646      # Single fetch, use as tuple
647      sp_out = s.run(sp)
648      indices_out, values_out, shape_out = sp_out
649      self.assertAllEqual(indices_out, indices)
650      self.assertAllEqual(values_out, values)
651      self.assertAllEqual(shape_out, shape)
652      # Single fetch, use as SparseTensorValue
653      sp_out = s.run(sp)
654      self.assertAllEqual(sp_out.indices, indices)
655      self.assertAllEqual(sp_out.values, values)
656      self.assertAllEqual(sp_out.dense_shape, shape)
657      # Tuple fetch, use as tuple
658      indices_out, values_out, shape_out = s.run(sp)
659      self.assertAllEqual(indices_out, indices)
660      self.assertAllEqual(values_out, values)
661      self.assertAllEqual(shape_out, shape)
662      # List fetch, use as tuple
663      (indices_out, values_out, shape_out), = s.run([sp])
664      self.assertAllEqual(indices_out, indices)
665      self.assertAllEqual(values_out, values)
666      self.assertAllEqual(shape_out, shape)
667      # List fetch, use as SparseTensorValue
668      sp_out, = s.run([sp])
669      self.assertAllEqual(sp_out.indices, indices)
670      self.assertAllEqual(sp_out.values, values)
671      self.assertAllEqual(sp_out.dense_shape, shape)
672      # Dict fetch (single value), use as tuple
673      indices_out, values_out, shape_out = s.run({'sp': sp})['sp']
674      self.assertAllEqual(indices_out, indices)
675      self.assertAllEqual(values_out, values)
676      self.assertAllEqual(shape_out, shape)
677      # Dict fetch (list value), use as tuple
678      (indices_out, values_out, shape_out), = s.run({'sp': [sp]})['sp']
679      self.assertAllEqual(indices_out, indices)
680      self.assertAllEqual(values_out, values)
681      self.assertAllEqual(shape_out, shape)
682      # Dict fetch, use as SparseTensorValue
683      sp_out = s.run({'sp': sp})['sp']
684      self.assertAllEqual(sp_out.indices, indices)
685      self.assertAllEqual(sp_out.values, values)
686      self.assertAllEqual(sp_out.dense_shape, shape)
687      # Nested list fetch use as tuple
688      sp_out = s.run([[[sp]], sp])
689      indices_out, values_out, shape_out = sp_out[0][0][0]
690      self.assertAllEqual(indices_out, indices)
691      self.assertAllEqual(values_out, values)
692      self.assertAllEqual(shape_out, shape)
693      indices_out, values_out, shape_out = sp_out[1]
694      self.assertAllEqual(indices_out, indices)
695      self.assertAllEqual(values_out, values)
696      self.assertAllEqual(shape_out, shape)
697      # Nested list fetch, use as SparseTensorValue
698      sp_out = s.run([[[sp]], sp])
699      self.assertAllEqual(sp_out[0][0][0].indices, indices)
700      self.assertAllEqual(sp_out[0][0][0].values, values)
701      self.assertAllEqual(sp_out[0][0][0].dense_shape, shape)
702      self.assertAllEqual(sp_out[1].indices, indices)
703      self.assertAllEqual(sp_out[1].values, values)
704      self.assertAllEqual(sp_out[1].dense_shape, shape)
705
706  def testFeedSparseTensor(self):
707    with session.Session() as s:
708      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
709      values = np.array([1.0, 2.0]).astype(np.float32)
710      shape = np.array([7, 9, 2]).astype(np.int64)
711      sp = sparse_tensor.SparseTensor(
712          array_ops.placeholder(dtype=np.int64, shape=(2, 3)),
713          array_ops.placeholder(dtype=np.float32, shape=(2,)),
714          array_ops.placeholder(dtype=np.int64, shape=(3,)),
715      )
716      sp_indices = array_ops.identity(sp.indices)
717      sp_values = array_ops.identity(sp.values)
718      sp_shape = array_ops.identity(sp.dense_shape)
719      sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape)
720      # Feed with tuple
721      indices_out, values_out, shape_out = s.run(
722          [sp_indices, sp_values, sp_shape], {
723              sp: (indices, values, shape)
724          })
725      self.assertAllEqual(indices_out, indices)
726      self.assertAllEqual(values_out, values)
727      self.assertAllEqual(shape_out, shape)
728      # Feed with tuple, fetch sp directly
729      sp_out = s.run(sp, {sp: (indices, values, shape)})
730      self.assertAllEqual(sp_out.indices, indices)
731      self.assertAllEqual(sp_out.values, values)
732      self.assertAllEqual(sp_out.dense_shape, shape)
733      # Feed with SparseTensorValue
734      indices_out, values_out, shape_out = s.run(
735          [sp_indices, sp_values, sp_shape], {
736              sp: sparse_tensor.SparseTensorValue(indices, values, shape)
737          })
738      self.assertAllEqual(indices_out, indices)
739      self.assertAllEqual(values_out, values)
740      self.assertAllEqual(shape_out, shape)
741      # Feed with SparseTensorValue, fetch SparseTensorValue
742      sp2_out = s.run(sp2, {
743          sp: sparse_tensor.SparseTensorValue(indices, values, shape)
744      })
745      self.assertAllEqual(sp2_out.indices, indices)
746      self.assertAllEqual(sp2_out.values, values)
747      self.assertAllEqual(sp2_out.dense_shape, shape)
748      # Feed SparseTensorValue and fetch sp directly.
749      sp_out = s.run(sp, {
750          sp: sparse_tensor.SparseTensorValue(indices, values, shape)
751      })
752      self.assertAllEqual(sp_out.indices, indices)
753      self.assertAllEqual(sp_out.values, values)
754      self.assertAllEqual(sp_out.dense_shape, shape)
755
756  def testFeedSparsePlaceholder(self):
757    with session.Session() as s:
758      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
759      values = np.array([1.0, 2.0]).astype(np.float32)
760      shape = np.array([7, 9, 2]).astype(np.int64)
761      sp = array_ops.sparse_placeholder(dtype=np.float32, name='placeholder1')
762      sp_indices = array_ops.identity(sp.indices)
763      sp_values = array_ops.identity(sp.values)
764      sp_shape = array_ops.identity(sp.dense_shape)
765      sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape)
766      # Feed with tuple
767      indices_out, values_out, shape_out = s.run(
768          [sp_indices, sp_values, sp_shape], {
769              sp: (indices, values, shape)
770          })
771      self.assertAllEqual(indices_out, indices)
772      self.assertAllEqual(values_out, values)
773      self.assertAllEqual(shape_out, shape)
774      # Feed with SparseTensorValue
775      indices_out, values_out, shape_out = s.run(
776          [sp_indices, sp_values, sp_shape], {
777              sp: sparse_tensor.SparseTensorValue(indices, values, shape)
778          })
779      self.assertAllEqual(indices_out, indices)
780      self.assertAllEqual(values_out, values)
781      self.assertAllEqual(shape_out, shape)
782      # Feed with SparseTensorValue, fetch SparseTensorValue
783      sp2_out = s.run(sp2, {
784          sp: sparse_tensor.SparseTensorValue(indices, values, shape)
785      })
786      self.assertAllEqual(sp2_out.indices, indices)
787      self.assertAllEqual(sp2_out.values, values)
788      self.assertAllEqual(sp2_out.dense_shape, shape)
789
790  def testFeedSparsePlaceholderPartialShape(self):
791    with session.Session() as s:
792      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
793      values = np.array([1.0, 2.0]).astype(np.float32)
794      shape = np.array([7, 9, 2]).astype(np.int64)
795      sp = array_ops.sparse_placeholder(
796          shape=[None, 9, 2], dtype=np.float32, name='placeholder1')
797      sp_indices = array_ops.identity(sp.indices)
798      sp_values = array_ops.identity(sp.values)
799      sp_shape = array_ops.identity(sp.dense_shape)
800      sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape)
801      # Feed with tuple
802      indices_out, values_out, shape_out = s.run(
803          [sp_indices, sp_values, sp_shape], {
804              sp: (indices, values, shape)
805          })
806      self.assertAllEqual(indices_out, indices)
807      self.assertAllEqual(values_out, values)
808      self.assertAllEqual(shape_out, shape)
809      # Feed with SparseTensorValue
810      indices_out, values_out, shape_out = s.run(
811          [sp_indices, sp_values, sp_shape], {
812              sp: sparse_tensor.SparseTensorValue(indices, values, shape)
813          })
814      self.assertAllEqual(indices_out, indices)
815      self.assertAllEqual(values_out, values)
816      self.assertAllEqual(shape_out, shape)
817      # Feed with SparseTensorValue, fetch SparseTensorValue
818      sp2_out = s.run(sp2, {
819          sp: sparse_tensor.SparseTensorValue(indices, values, shape)
820      })
821      self.assertAllEqual(sp2_out.indices, indices)
822      self.assertAllEqual(sp2_out.values, values)
823      self.assertAllEqual(sp2_out.dense_shape, shape)
824
825  def testFeedSparsePlaceholderConstantShape(self):
826    with session.Session() as s:
827      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
828      values = np.array([1.0, 2.0]).astype(np.float32)
829      shape = np.array([7, 9, 2]).astype(np.int64)
830      sp = array_ops.sparse_placeholder(
831          dtype=np.float32, shape=shape, name='placeholder1')
832      self.assertAllEqual(sp.dense_shape.eval(session=s), shape)
833      self.assertAllEqual(tensor_util.constant_value(sp.shape), shape)
834      sp_indices = array_ops.identity(sp.indices)
835      sp_values = array_ops.identity(sp.values)
836      sp_shape = array_ops.identity(sp.dense_shape)
837      # Feed with tuple
838      indices_out, values_out, shape_out = s.run(
839          [sp_indices, sp_values, sp_shape], {
840              sp: (indices, values)
841          })
842      self.assertAllEqual(indices_out, indices)
843      self.assertAllEqual(values_out, values)
844      self.assertAllEqual(shape_out, shape)
845
846  def testFetchIndexedSlices(self):
847    with session.Session() as s:
848      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
849      values = np.array([1.0, 2.0]).astype(np.float32)
850      dense_shape = np.array([7, 9, 2]).astype(np.int64)
851      ind = indexed_slices.IndexedSlices(
852          constant_op.constant(values), constant_op.constant(indices),
853          constant_op.constant(dense_shape))
854      # Single fetch, use as tuple
855      ind_out = s.run(ind)
856      values_out, indices_out, dense_shape_out = ind_out
857      self.assertAllEqual(values_out, values)
858      self.assertAllEqual(indices_out, indices)
859      self.assertAllEqual(dense_shape_out, dense_shape)
860      # Single fetch, use as IndexedSlicesValue
861      ind_out = s.run(ind)
862      self.assertAllEqual(ind_out.values, values)
863      self.assertAllEqual(ind_out.indices, indices)
864      self.assertAllEqual(ind_out.dense_shape, dense_shape)
865      # Tuple fetch, use as tuple
866      values_out, indices_out, dense_shape_out = s.run(ind)
867      self.assertAllEqual(values_out, values)
868      self.assertAllEqual(indices_out, indices)
869      self.assertAllEqual(dense_shape_out, dense_shape)
870      # List fetch, use as tuple
871      (values_out, indices_out, dense_shape_out), = s.run([ind])
872      self.assertAllEqual(values_out, values)
873      self.assertAllEqual(indices_out, indices)
874      self.assertAllEqual(dense_shape_out, dense_shape)
875      # List fetch, use as IndexedSlicesValue
876      ind_out, = s.run([ind])
877      self.assertAllEqual(ind_out.values, values)
878      self.assertAllEqual(ind_out.indices, indices)
879      self.assertAllEqual(ind_out.dense_shape, dense_shape)
880
881  def testFeedIndexedSlices(self):
882    with session.Session() as s:
883      values = np.array([1.0, 2.0]).astype(np.float32)
884      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
885      dense_shape = np.array([7, 9, 2]).astype(np.int64)
886      ind = indexed_slices.IndexedSlices(
887          array_ops.placeholder(dtype=np.float32, shape=(2,)),
888          array_ops.placeholder(dtype=np.int64, shape=(2, 3)),
889          array_ops.placeholder(dtype=np.int64, shape=(3,)),
890      )
891      ind_values = array_ops.identity(ind.values)
892      ind_indices = array_ops.identity(ind.indices)
893      ind_dense_shape = array_ops.identity(ind.dense_shape)
894      ind2 = indexed_slices.IndexedSlices(ind_values, ind_indices,
895                                          ind_dense_shape)
896      # Feed with tuple
897      values_out, indices_out, dense_shape_out = s.run(
898          [ind_values, ind_indices, ind_dense_shape], {
899              ind: (values, indices, dense_shape)
900          })
901      self.assertAllEqual(values_out, values)
902      self.assertAllEqual(indices_out, indices)
903      self.assertAllEqual(dense_shape_out, dense_shape)
904      # Feed with IndexedSlicesValue
905      values_out, indices_out, dense_shape_out = s.run([
906          ind_values, ind_indices, ind_dense_shape
907      ], {ind: indexed_slices.IndexedSlicesValue(values, indices, dense_shape)})
908      self.assertAllEqual(values_out, values)
909      self.assertAllEqual(indices_out, indices)
910      self.assertAllEqual(dense_shape_out, dense_shape)
911      # Feed with IndexedSlicesValue, fetch IndexedSlicesValue
912      ind2_out = s.run(ind2, {
913          ind: indexed_slices.IndexedSlicesValue(values, indices, dense_shape)
914      })
915      self.assertAllEqual(ind2_out.values, values)
916      self.assertAllEqual(ind2_out.indices, indices)
917      self.assertAllEqual(ind2_out.dense_shape, dense_shape)
918
919  def testFetchIndexedSlicesWithoutDenseShape(self):
920    with session.Session() as s:
921      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
922      values = np.array([1.0, 2.0]).astype(np.float32)
923      dense_shape = None
924      ind = indexed_slices.IndexedSlices(
925          constant_op.constant(values), constant_op.constant(indices), None)
926      # Single fetch, use as tuple
927      ind_out = s.run(ind)
928      values_out, indices_out, dense_shape_out = ind_out
929      self.assertAllEqual(values_out, values)
930      self.assertAllEqual(indices_out, indices)
931      self.assertAllEqual(dense_shape_out, dense_shape)
932      # Single fetch, use as IndexedSlicesValue
933      ind_out = s.run(ind)
934      self.assertAllEqual(ind_out.values, values)
935      self.assertAllEqual(ind_out.indices, indices)
936      self.assertAllEqual(ind_out.dense_shape, dense_shape)
937      # Tuple fetch, use as tuple
938      values_out, indices_out, dense_shape_out = s.run(ind)
939      self.assertAllEqual(values_out, values)
940      self.assertAllEqual(indices_out, indices)
941      self.assertAllEqual(dense_shape_out, dense_shape)
942      # List fetch, use as tuple
943      (values_out, indices_out, dense_shape_out), = s.run([ind])
944      self.assertAllEqual(values_out, values)
945      self.assertAllEqual(indices_out, indices)
946      self.assertAllEqual(dense_shape_out, dense_shape)
947      # List fetch, use as IndexedSlicesValue
948      ind_out, = s.run([ind])
949      self.assertAllEqual(ind_out.values, values)
950      self.assertAllEqual(ind_out.indices, indices)
951      self.assertAllEqual(ind_out.dense_shape, dense_shape)
952
953  def testFeedIndexedSlicesWithoutDenseShape(self):
954    with session.Session() as s:
955      values = np.array([1.0, 2.0]).astype(np.float32)
956      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
957      dense_shape = None
958      ind = indexed_slices.IndexedSlices(
959          array_ops.placeholder(dtype=np.float32, shape=(2,)),
960          array_ops.placeholder(dtype=np.int64, shape=(2, 3)), None)
961      ind_values = array_ops.identity(ind.values)
962      ind_indices = array_ops.identity(ind.indices)
963      ind2 = indexed_slices.IndexedSlices(ind_values, ind_indices)
964      # Feed with tuple
965      values_out, indices_out = s.run([ind_values, ind_indices], {
966          ind: (values, indices)
967      })
968      self.assertAllEqual(values_out, values)
969      self.assertAllEqual(indices_out, indices)
970      # Feed with IndexedSlicesValue
971      values_out, indices_out = s.run([ind_values, ind_indices], {
972          ind: indexed_slices.IndexedSlicesValue(values, indices, dense_shape)
973      })
974      self.assertAllEqual(values_out, values)
975      self.assertAllEqual(indices_out, indices)
976      # Feed with IndexedSlicesValue, fetch IndexedSlicesValue
977      ind2_out = s.run(ind2, {
978          ind: indexed_slices.IndexedSlicesValue(values, indices, dense_shape)
979      })
980      self.assertAllEqual(ind2_out.values, values)
981      self.assertAllEqual(ind2_out.indices, indices)
982      self.assertAllEqual(ind2_out.dense_shape, dense_shape)
983
984  def testExtendWithStatelessOperations(self):
985    with session.Session() as s:
986      a = constant_op.constant(1.0, shape=[1, 2])
987      b = constant_op.constant(2.0, shape=[2, 3])
988      c = math_ops.matmul(a, b)
989      c_val = s.run(c)
990      self.assertAllEqual([[4.0, 4.0, 4.0]], c_val)
991      d = constant_op.constant([1.0, 2.0, 3.0], shape=[3, 1])
992      e = math_ops.matmul(c, d)
993      # Extend will happen here.
994      e_val = s.run(e)
995      self.assertAllEqual([[24.0]], e_val)
996
997  def testExtendWithStatefulOperations(self):
998    with session.Session() as s:
999      a = constant_op.constant(1.0, shape=[1, 2])
1000      b = constant_op.constant(2.0, shape=[2, 3])
1001      c = math_ops.matmul(a, b)
1002      v = variables.Variable(c, name='testExtendWithStatefulOperations_v')
1003      v.initializer.run()
1004      v_val = v.eval()
1005      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
1006      d = constant_op.constant(3.0, shape=[2, 3])
1007      e = math_ops.matmul(a, d)
1008      assign_e_to_v = state_ops.assign(v, e)
1009      # Extend will happen here.
1010      e_val = e.eval()
1011      self.assertAllEqual([[6.0, 6.0, 6.0]], e_val)
1012      v_val = v.eval()
1013      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
1014      s.run(assign_e_to_v)
1015      v_val = v.eval()
1016      self.assertAllEqual([[6.0, 6.0, 6.0]], v_val)
1017
1018  def testExtendWithGroupBy(self):
1019    with session.Session() as s:
1020      a = constant_op.constant(1.0, shape=[1, 2])
1021      p = variables.Variable(a, name='testExtendWithGroupBy_p')
1022      a_val = a.eval()  # Force an Extend after this op.
1023      self.assertAllEqual([[1.0, 1.0]], a_val)
1024
1025      b = constant_op.constant(2.0, shape=[1, 2])
1026      q = variables.Variable(b, name='testExtendWithGroupBy_q')
1027      # Extend will happen here.
1028      init = control_flow_ops.group(p.initializer, q.initializer)
1029      s.run(init)
1030      p_val, q_val = s.run([p, q])
1031
1032      self.assertAllEqual([[1.0, 1.0]], p_val)
1033      self.assertAllEqual([[2.0, 2.0]], q_val)
1034
1035  def testTensorGetMethod(self):
1036    with session.Session():
1037      a = constant_op.constant(1.0, shape=[1, 2])
1038      b = constant_op.constant(2.0, shape=[2, 3])
1039      c = math_ops.matmul(a, b)
1040
1041      c_val = c.eval()
1042      self.assertAllEqual([[4.0, 4.0, 4.0]], c_val)
1043
1044      fed_c_val = c.eval(feed_dict={a.name: [[4.0, 4.0]]})
1045      self.assertAllEqual([[16.0, 16.0, 16.0]], fed_c_val)
1046
1047  @test_util.run_v1_only('b/120545219')
1048  def testOperationRunMethod(self):
1049    with session.Session():
1050      a = constant_op.constant(1.0, shape=[1, 2])
1051      b = constant_op.constant(2.0, shape=[1, 2], name='b')
1052      v = variables.VariableV1(a, a.dtype)
1053      assign_a_to_v = state_ops.assign(v, a)
1054
1055      assign_a_to_v.eval()
1056
1057      v_val = v.eval()
1058      self.assertAllEqual([[1.0, 1.0]], v_val)
1059
1060      assign_b_to_v = state_ops.assign(v, b)
1061
1062      assign_b_to_v.eval()
1063      v_val = v.eval()
1064      self.assertAllEqual([[2.0, 2.0]], v_val)
1065
1066      assign_b_to_v.eval(feed_dict={'b:0': [[3.0, 3.0]]})
1067      v_val = v.eval()
1068      self.assertAllEqual([[3.0, 3.0]], v_val)
1069
1070  def testDefaultGraph(self):
1071    with session.Session() as s:
1072      self.assertEqual(ops.get_default_graph(), s.graph)
1073      a = constant_op.constant(1.0, shape=[1, 2])
1074      b = constant_op.constant(2.0, shape=[2, 3])
1075      self.assertEqual(ops.get_default_graph(), a.graph)
1076      self.assertEqual(ops.get_default_graph(), b.graph)
1077      c = math_ops.matmul(a, b)
1078      v = variables.Variable(c, name='testDefaultGraph_v')
1079      v.initializer.run()
1080      v_val = v.eval()
1081      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
1082      d = constant_op.constant(3.0, shape=[2, 3])
1083      e = math_ops.matmul(a, d)
1084      assign_e_to_v = state_ops.assign(v, e)
1085      e_val = e.eval()
1086      self.assertAllEqual([[6.0, 6.0, 6.0]], e_val)
1087      v_val = v.eval()
1088      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
1089      s.run(assign_e_to_v)
1090      v_val = v.eval()
1091      self.assertAllEqual([[6.0, 6.0, 6.0]], v_val)
1092      self.assertEqual(ops.get_default_graph(), s.graph)
1093
1094  def _testDefaultGraphInThread(self, constructed_event, continue_event, i):
1095    with session.Session() as s:
1096      self.assertEqual(ops.get_default_graph(), s.graph)
1097      a = constant_op.constant(1.0, shape=[1, 2])
1098      b = constant_op.constant(2.0, shape=[2, 3])
1099      c = math_ops.matmul(a, b)
1100      v = variables.Variable(c, name='var_%d' % i)
1101
1102      # Block here until all threads have constructed their graph.
1103      constructed_event.set()
1104      continue_event.wait()
1105
1106      assign_c_to_v = state_ops.assign(v, c)
1107      v.initializer.run()
1108      assign_c_to_v.eval()
1109      v_val = v.eval()
1110      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
1111      d = constant_op.constant(3.0, shape=[2, 3])
1112      e = math_ops.matmul(a, d)
1113      assign_e_to_v = state_ops.assign(v, e)
1114      e_val = e.eval()
1115      self.assertAllEqual([[6.0, 6.0, 6.0]], e_val)
1116      v_val = v.eval()
1117      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
1118      s.run(assign_e_to_v)
1119      v_val = v.eval()
1120      self.assertAllEqual([[6.0, 6.0, 6.0]], v_val)
1121      self.assertEqual(ops.get_default_graph(), s.graph)
1122
1123  def testDefaultGraphWithThreads(self):
1124    # Fork ten threads that use their thread-local default graph.
1125    threads = []
1126    constructed_events = [threading.Event() for _ in range(10)]
1127    continue_event = threading.Event()
1128    for i, constructed_event in enumerate(constructed_events):
1129      t = self.checkedThread(
1130          target=self._testDefaultGraphInThread,
1131          args=(constructed_event, continue_event, i))
1132      threads.append(t)
1133    for t in threads:
1134      t.start()
1135    for constructed_event in constructed_events:
1136      constructed_event.wait()
1137    continue_event.set()
1138    for t in threads:
1139      t.join()
1140
1141  def testParallelRun(self):
1142    with session.Session() as sess:
1143      c = constant_op.constant(5.0)
1144      ev = threading.Event()
1145
1146      def run_step():
1147        ev.wait()
1148        val = c.eval(session=sess)
1149        self.assertEqual(val, 5.0)
1150
1151      threads = [self.checkedThread(target=run_step) for _ in range(100)]
1152      for t in threads:
1153        t.start()
1154      ev.set()
1155      for t in threads:
1156        t.join()
1157
1158  @staticmethod
1159  def _build_graph():
1160    time.sleep(random.random() * 0.1)
1161    # Do some graph construction. Try to exercise non-trivial paths.
1162    graph = ops.get_default_graph()
1163    gdef = None
1164    for _ in range(10):
1165      x = array_ops.placeholder(dtype=dtypes.float32)
1166      with ops.colocate_with(x):
1167        y = array_ops.placeholder(dtype=dtypes.float32)
1168      with ops.device('/cpu:0'):
1169        z = control_flow_ops.while_loop(
1170            lambda x, y: x < 10, lambda x, y: (x + 1, x * y), [x, y])
1171      with graph._attr_scope({'_a': attr_value_pb2.AttrValue(b=False)}):
1172        gradients_impl.gradients(z, [x, y])
1173        if gdef is None:
1174          gdef = graph.as_graph_def()
1175        else:
1176          importer.import_graph_def(gdef, name='import')
1177
1178  @test_util.run_v1_only('b/120545219')
1179  def testParallelRunAndSingleBuild(self):
1180    with session.Session() as sess:
1181      c = constant_op.constant(5.0)
1182      stop = threading.Event()
1183
1184      def run_loop():
1185        while not stop.is_set():
1186          time.sleep(random.random() * 0.1)
1187          self.assertEqual(sess.run(c), 5.0)
1188
1189      threads = [self.checkedThread(target=run_loop) for _ in range(10)]
1190      for t in threads:
1191        t.start()
1192
1193      SessionTest._build_graph()
1194
1195      stop.set()
1196      for t in threads:
1197        t.join()
1198
1199  @test_util.run_v1_only('b/120545219')
1200  def testParallelRunAndParallelBuild(self):
1201    with session.Session() as sess:
1202      c = constant_op.constant(5.0)
1203      stop = threading.Event()
1204
1205      def run_loop():
1206        while not stop.is_set():
1207          time.sleep(random.random() * 0.1)
1208          self.assertEqual(sess.run(c), 5.0)
1209
1210      run_threads = [self.checkedThread(target=run_loop) for _ in range(10)]
1211      for t in run_threads:
1212        t.start()
1213
1214      build_threads = [self.checkedThread(target=SessionTest._build_graph)
1215                       for _ in range(10)]
1216      for t in build_threads:
1217        t.start()
1218      for t in build_threads:
1219        t.join()
1220
1221      # Let the run_threads run until the build threads are finished.
1222      stop.set()
1223      for t in run_threads:
1224        t.join()
1225
1226  def testRunFeedDict(self):
1227    with session.Session() as s:
1228      x = array_ops.zeros([2])
1229
1230      y = s.run(2 * x, feed_dict={x: np.ones(2).astype(np.float32)})
1231      self.assertAllEqual(y, 2 * np.ones(2))
1232
1233      y = s.run(2 * x, feed_dict={x.name: np.ones(2).astype(np.float32)})
1234      self.assertAllEqual(y, 2 * np.ones(2))
1235
1236      y = s.run(2 * x, feed_dict={x: [1, 1]})
1237      assert (y == 2 * np.ones(2)).all()
1238
1239      # Test nested tuple keys
1240      z = (((array_ops.zeros([2]),),), array_ops.zeros([2]),
1241           (array_ops.zeros([2]),))
1242      result = [z[0][0][0] * 2, z[1] * 2, z[2][0] * 2]
1243      values = (((np.array([1, 1]),),), np.array([2, 2]), (np.array([3, 3]),))
1244      result_value = s.run(result, feed_dict={z: values})
1245      self.assertAllEqual(result_value[0], 2 * np.ones(2))
1246      self.assertAllEqual(result_value[1], 2 * np.array([2, 2]))
1247      self.assertAllEqual(result_value[2], 2 * np.array([3, 3]))
1248
1249  def testGraphDef(self):
1250    with session.Session() as sess:
1251      self.assertProtoEquals('versions { producer: %d min_consumer: %d }' %
1252                             (versions.GRAPH_DEF_VERSION,
1253                              versions.GRAPH_DEF_VERSION_MIN_CONSUMER),
1254                             sess.graph_def)
1255      c = constant_op.constant(5.0, name='c')
1256      self.assertEqual(len(sess.graph_def.node), 1)
1257      d = constant_op.constant(6.0, name='d')
1258      self.assertEqual(len(sess.graph_def.node), 2)
1259      self.assertAllEqual(c, 5.0)
1260      self.assertAllEqual(d, 6.0)
1261      e = constant_op.constant(7.0, name='e')
1262      self.assertEqual(len(sess.graph_def.node), 3)
1263      self.assertAllEqual(e, 7.0)
1264
1265  def testUseAfterClose(self):
1266    with session.Session() as sess:
1267      c = constant_op.constant(5.0)
1268      self.assertAllEqual(sess.run(c), 5.0)
1269    with self.assertRaisesWithPredicateMatch(
1270        RuntimeError, lambda e: 'Attempted to use a closed Session.' in str(e)):
1271      sess.run(c)
1272
1273  def testUseAfterCloseConcurrent(self):
1274    with session.Session() as sess:
1275      c = constant_op.constant(5.0)
1276      self.assertAllEqual(sess.run(c), 5.0)
1277
1278      def update_thread():
1279        with self.assertRaisesWithPredicateMatch(
1280            RuntimeError,
1281            lambda e: 'Attempted to use a closed Session.' in str(e)):
1282          while True:
1283            sess.run(c)
1284
1285      t = threading.Thread(target=update_thread)
1286      t.start()
1287      time.sleep(0.1)
1288      sess.close()
1289      t.join()
1290
1291  def testUseEmptyGraph(self):
1292    with session.Session() as sess:
1293      with self.assertRaisesRegex(RuntimeError, 'The Session graph is empty.'):
1294        sess.run([])
1295      with self.assertRaisesRegex(RuntimeError, 'The Session graph is empty.'):
1296        sess.run(())
1297      with self.assertRaisesRegex(RuntimeError, 'The Session graph is empty.'):
1298        sess.run({})
1299
1300  @test_util.run_v1_only('b/120545219')
1301  def testNotEntered(self):
1302    # pylint: disable=protected-access
1303    self.assertIsNone(ops._default_session_stack.get_default())
1304    # pylint: enable=protected-access
1305    with ops.device('/cpu:0'):
1306      sess = session.Session()
1307      c_1 = constant_op.constant(5.0)
1308      with sess.graph.as_default():
1309        c_2 = constant_op.constant(5.0)
1310      self.assertEqual(c_1.graph, c_2.graph)
1311      self.assertEqual(sess.run(c_2), 5.0)
1312      with self.assertRaisesWithPredicateMatch(
1313          ValueError, lambda e: 'No default session is registered.' in str(e)):
1314        c_2.eval()
1315
1316  @test_util.run_v1_only('b/120545219')
1317  def testInteractive(self):
1318    with ops.device('/cpu:0'):
1319      sess = session.InteractiveSession()
1320      a = constant_op.constant(1.0, shape=[1, 2])
1321      b = constant_op.constant(2.0, shape=[2, 3])
1322      c = math_ops.matmul(a, b)
1323      self.assertAllEqual([[4.0, 4.0, 4.0]], c)
1324      d = constant_op.constant([1.0, 2.0, 3.0], shape=[3, 1])
1325      e = math_ops.matmul(c, d)
1326      self.assertAllEqual([[24.0]], e)
1327      sess.close()
1328
1329  @test_util.run_v1_only('b/120545219')
1330  def testMultipleInteractiveSessionsWarning(self):
1331    # Reinitialize the global state to ensure that the expected warnings will
1332    # be emitted.
1333    session.InteractiveSession._active_session_count = 0  # pylint: disable=protected-access
1334
1335    sess = session.InteractiveSession()
1336    sess.run(constant_op.constant(4.0))  # Run so that the session is "opened".
1337    sess.close()
1338    # Opening and closing interactive sessions serially should not warn.
1339    with warnings.catch_warnings(record=True) as w:
1340      sess = session.InteractiveSession()
1341      sess.close()
1342    self.assertEqual(0, len(w))
1343
1344    with warnings.catch_warnings(record=True) as w:
1345      sess = session.InteractiveSession()
1346    self.assertEqual(0, len(w))
1347    with warnings.catch_warnings(record=True) as w:
1348      sess2 = session.InteractiveSession()
1349    self.assertEqual(1, len(w))
1350    self.assertIn('An interactive session is already active. This can cause '
1351                  'out-of-memory errors in some cases. You must explicitly '
1352                  'call `InteractiveSession.close()` to release resources '
1353                  'held by the other session(s).', str(w[0].message))
1354    sess2.close()
1355    sess.close()
1356
1357  @test_util.run_v1_only('b/120545219')
1358  def testInteractivePlacePrunedGraph(self):
1359    sess = session.InteractiveSession()
1360
1361    # Build a graph that has a bad op in it (no kernel).
1362    #
1363    # This test currently does not link in any GPU kernels,
1364    # which is why placing this is invalid.  If at some point
1365    # GPU kernels are added to this test, some other different
1366    # op / device combo should be chosen.
1367    with ops.device('/device:GPU:0'):
1368      a = constant_op.constant(1.0, shape=[1, 2])
1369
1370    b = constant_op.constant(1.0, shape=[1, 2])
1371
1372    # Only run the valid op, this should work.
1373    b.eval()
1374
1375    with self.assertRaises(errors.InvalidArgumentError):
1376      a.eval()
1377    sess.close()
1378
1379  @test_util.run_v1_only('b/120545219')
1380  def testDefaultSessionPlacePrunedGraph(self):
1381    sess = session.Session()
1382
1383    # Build a graph that has a bad op in it (no kernel).
1384    #
1385    # This test currently does not link in any GPU kernels,
1386    # which is why placing this is invalid.  If at some point
1387    # GPU kernels are added to this test, some other different
1388    # op / device combo should be chosen.
1389    with ops.device('/device:GPU:0'):
1390      _ = constant_op.constant(1.0, shape=[1, 2])
1391
1392    b = constant_op.constant(1.0, shape=[1, 2])
1393
1394    with self.assertRaises(errors.InvalidArgumentError):
1395      # Even though we don't run the bad op, we place the entire
1396      # graph, which should fail with a non-interactive session.
1397      sess.run(b)
1398
1399    sess.close()
1400
1401  def testSharedGraph(self):
1402    with ops.Graph().as_default() as g, ops.device('/cpu:0'):
1403      a = constant_op.constant(1.0, shape=[1, 2])
1404      b = constant_op.constant(2.0, shape=[2, 3])
1405      c = math_ops.matmul(a, b)
1406
1407    with session.Session(graph=g) as sess1:
1408      with session.Session(graph=g) as sess2:
1409        self.assertAllEqual(sess1.run(c), sess2.run(c))
1410
1411  def testDuplicatedInputs(self):
1412    with session.Session() as sess:
1413      a = constant_op.constant(1.0, shape=[1, 2])
1414      b = constant_op.constant(2.0, shape=[1, 3])
1415      a_val, b_val, a2_val = sess.run([a, b, a])
1416      self.assertAllEqual(a_val, [[1.0, 1.0]])
1417      self.assertAllEqual(b_val, [[2.0, 2.0, 2.0]])
1418      self.assertAllEqual(a2_val, [[1.0, 1.0]])
1419
1420  def testFeedAndFetch(self):
1421    with session.Session() as sess:
1422      for dtype in [
1423          dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32,
1424          dtypes.uint8, dtypes.int16, dtypes.int8, dtypes.int64, dtypes.bool,
1425          dtypes.complex64, dtypes.complex128
1426      ]:
1427        for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]:
1428          np_dtype = dtype.as_numpy_dtype
1429
1430          feed_t = array_ops.placeholder(dtype=dtype, shape=shape)
1431          out_t = array_ops.identity(feed_t)
1432
1433          np_array = np.random.randint(-10, 10, shape)
1434
1435          if dtype == dtypes.bool:
1436            np_array = np_array > 0
1437          elif dtype == dtypes.complex64:
1438            np_array = np.sqrt(np_array.astype(np_dtype))
1439          elif dtype == dtypes.complex64:
1440            np_array = np.sqrt(np_array.astype(np_dtype))
1441          else:
1442            np_array = np_array.astype(np_dtype)
1443
1444          self.assertAllEqual(np_array,
1445                              sess.run(out_t, feed_dict={
1446                                  feed_t: np_array
1447                              }))
1448          # Check that we can also get the feed back.
1449          self.assertAllEqual(np_array,
1450                              sess.run(feed_t, feed_dict={
1451                                  feed_t: np_array
1452                              }))
1453          # Also check that we can get both back.
1454          out_v, feed_v = sess.run(
1455              [out_t, feed_t], feed_dict={
1456                  feed_t: np_array
1457              })
1458          self.assertAllEqual(np_array, out_v)
1459          self.assertAllEqual(np_array, feed_v)
1460
1461          feed_fetch_runner = sess.make_callable([out_t, feed_t], [feed_t])
1462          out_v, feed_v = feed_fetch_runner(np_array)
1463          self.assertAllEqual(np_array, out_v)
1464          self.assertAllEqual(np_array, feed_v)
1465
1466  def testMakeCallableOnTensorWithRunOptions(self):
1467    with session.Session() as sess:
1468      a = constant_op.constant(42.0)
1469      tensor_runner = sess.make_callable(a, accept_options=True)
1470      run_options = config_pb2.RunOptions(
1471          trace_level=config_pb2.RunOptions.FULL_TRACE)
1472      run_metadata = config_pb2.RunMetadata()
1473      self.assertEqual(0, len(run_metadata.step_stats.dev_stats))
1474      res = tensor_runner(options=run_options, run_metadata=run_metadata)
1475      self.assertEqual(42.0, res)
1476      self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
1477
1478  def testMakeCallableOnOperationWithRunOptions(self):
1479    with session.Session() as sess:
1480      a = variables.Variable(42.0)
1481      b = state_ops.assign_add(a, 1.0)
1482      sess.run(a.initializer)
1483      tensor_runner = sess.make_callable(b.op, accept_options=True)
1484      run_options = config_pb2.RunOptions(
1485          trace_level=config_pb2.RunOptions.FULL_TRACE)
1486      run_metadata = config_pb2.RunMetadata()
1487      self.assertEqual(0, len(run_metadata.step_stats.dev_stats))
1488      tensor_runner(options=run_options, run_metadata=run_metadata)
1489      self.assertEqual(43.0, sess.run(a))
1490      self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
1491
1492  def testMakeCallableWithFeedListAndRunOptions(self):
1493    with session.Session() as sess:
1494      ph = array_ops.placeholder(dtypes.float32)
1495      a = math_ops.add(ph, 1.0)
1496      tensor_runner = sess.make_callable(
1497          a, feed_list=[ph.name], accept_options=True)
1498      run_options = config_pb2.RunOptions(
1499          trace_level=config_pb2.RunOptions.FULL_TRACE)
1500      run_metadata = config_pb2.RunMetadata()
1501      self.assertEqual(0, len(run_metadata.step_stats.dev_stats))
1502      self.assertAllClose(42.0,
1503                          tensor_runner(
1504                              41.0,
1505                              options=run_options,
1506                              run_metadata=run_metadata))
1507      self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
1508
1509  def testOptimizedMakeCallable(self):
1510    with session.Session() as sess:
1511      ph = array_ops.placeholder(dtypes.float32)
1512      a = math_ops.add(ph, 1.0)
1513      callable_opts = config_pb2.CallableOptions()
1514      callable_opts.feed.append(ph.name)
1515      callable_opts.fetch.append(a.name)
1516      for _ in range(3):
1517        callable_fn = sess._make_callable_from_options(callable_opts)
1518        for _ in range(5):
1519          self.assertEqual([2.0], callable_fn(np.array(1.0, dtype=np.float32)))
1520
1521  def testOptimizedMakeCallableWithRunMetadata(self):
1522    with session.Session() as sess:
1523      ph = array_ops.placeholder(dtypes.float32)
1524      a = math_ops.add(ph, 1.0)
1525      callable_opts = config_pb2.CallableOptions()
1526      callable_opts.feed.append(ph.name)
1527      callable_opts.fetch.append(a.name)
1528      callable_opts.run_options.trace_level = config_pb2.RunOptions.FULL_TRACE
1529      callable_fn = sess._make_callable_from_options(callable_opts)
1530      run_metadata = config_pb2.RunMetadata()
1531      self.assertEqual([2.0], callable_fn(np.array(1.0, dtype=np.float32),
1532                                          run_metadata=run_metadata))
1533      self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
1534
1535  def testFeedError(self):
1536    with session.Session() as sess:
1537      feed_t = array_ops.placeholder(dtype=dtypes.float32)
1538      out_t = array_ops.identity(feed_t)
1539      feed_val = constant_op.constant(5.0)
1540      with self.assertRaisesRegex(TypeError, 'cannot be a tf.Tensor object'):
1541        sess.run(out_t, feed_dict={feed_t: feed_val})
1542      with self.assertRaisesRegex(TypeError, 'cannot be a tf.Tensor object'):
1543        out_t.eval(feed_dict={feed_t: feed_val})
1544      with self.assertRaisesRegex(TypeError, 'cannot be a tf.Tensor object'):
1545        out_t.op.run(feed_dict={feed_t: feed_val})
1546
1547  def testFeedPrecisionLossError(self):
1548    with session.Session() as sess:
1549      largest_int64 = np.iinfo(np.int64).max
1550
1551      feed_int_implicit_int32 = constant_op.constant(1)
1552      feed_int_explicit_int32 = constant_op.constant(1, dtype=dtypes.int32)
1553
1554      out_t = constant_op.constant(1.0)
1555
1556      with self.assertRaisesRegex(TypeError,
1557                                  'is not compatible with Tensor type'):
1558        sess.run(out_t, feed_dict={feed_int_implicit_int32: largest_int64})
1559      with self.assertRaisesRegex(TypeError,
1560                                  'is not compatible with Tensor type'):
1561        sess.run(out_t, feed_dict={feed_int_explicit_int32: largest_int64})
1562
1563  def testStringFetch(self):
1564    with session.Session():
1565      for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]:
1566        size = 1
1567        for s in shape:
1568          size *= s
1569        c_list = np.array([compat.as_bytes(str(i)) for i in range(size)],
1570                          dtype=np.object_).reshape(shape) if size > 0 else []
1571        c = constant_op.constant(c_list)
1572        self.assertAllEqual(c, c_list)
1573
1574  def testStringFeed(self):
1575    with session.Session() as sess:
1576      for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]:
1577        size = 1
1578        for s in shape:
1579          size *= s
1580        c_list = np.array([compat.as_bytes(str(i)) for i in range(size)],
1581                          dtype=np.object_).reshape(shape)
1582        feed_t = array_ops.placeholder(dtype=dtypes.string, shape=shape)
1583        c = array_ops.identity(feed_t)
1584        self.assertAllEqual(sess.run(c, feed_dict={feed_t: c_list}), c_list)
1585        self.assertAllEqual(
1586            sess.run(feed_t, feed_dict={
1587                feed_t: c_list
1588            }), c_list)
1589        c_v, feed_v = sess.run([c, feed_t], feed_dict={feed_t: c_list})
1590        self.assertAllEqual(c_v, c_list)
1591        self.assertAllEqual(feed_v, c_list)
1592
1593  def testStringFeedWithNullCharacters(self):
1594    with session.Session():
1595      c_list = [b'\n\x01\x00', b'\n\x00\x01']
1596      feed_t = array_ops.placeholder(dtype=dtypes.string, shape=[2])
1597      c = array_ops.identity(feed_t)
1598      out = c.eval(feed_dict={feed_t: c_list})
1599      self.assertEqual(c_list[0], out[0])
1600      self.assertEqual(c_list[1], out[1])
1601
1602  def testStringFeedWithUnicode(self):
1603    with session.Session():
1604      c_list = [
1605          u'\n\x01\x00', u'\n\x00\x01', u'\u26a3 unicode',
1606          u'\U0001f60e deal with it'
1607      ]
1608      feed_t = array_ops.placeholder(dtype=dtypes.string, shape=[len(c_list)])
1609      c = array_ops.identity(feed_t)
1610
1611      out = c.eval(feed_dict={feed_t: c_list})
1612      for i in range(len(c_list)):
1613        self.assertEqual(c_list[i], out[i].decode('utf-8'))
1614
1615      out = c.eval(feed_dict={feed_t: np.array(c_list, dtype=np.object_)})
1616      for i in range(len(c_list)):
1617        self.assertEqual(c_list[i], out[i].decode('utf-8'))
1618
1619  def testInvalidTargetFails(self):
1620    with self.assertRaisesRegex(
1621        errors.NotFoundError,
1622        'No session factory registered for the given session options'):
1623      session.Session('INVALID_TARGET')
1624
1625  def testFetchByNameDifferentStringTypes(self):
1626    with session.Session() as sess:
1627      c = constant_op.constant(42.0, name='c')
1628      d = constant_op.constant(43.0, name=u'd')
1629      e = constant_op.constant(44.0, name=b'e')
1630      f = constant_op.constant(45.0, name=r'f')
1631
1632      self.assertIsInstance(c.name, six.text_type)
1633      self.assertIsInstance(d.name, six.text_type)
1634      self.assertIsInstance(e.name, six.text_type)
1635      self.assertIsInstance(f.name, six.text_type)
1636
1637      self.assertEqual(42.0, sess.run('c:0'))
1638      self.assertEqual(42.0, sess.run(u'c:0'))
1639      self.assertEqual(42.0, sess.run(b'c:0'))
1640      self.assertEqual(42.0, sess.run(r'c:0'))
1641
1642      self.assertEqual(43.0, sess.run('d:0'))
1643      self.assertEqual(43.0, sess.run(u'd:0'))
1644      self.assertEqual(43.0, sess.run(b'd:0'))
1645      self.assertEqual(43.0, sess.run(r'd:0'))
1646
1647      self.assertEqual(44.0, sess.run('e:0'))
1648      self.assertEqual(44.0, sess.run(u'e:0'))
1649      self.assertEqual(44.0, sess.run(b'e:0'))
1650      self.assertEqual(44.0, sess.run(r'e:0'))
1651
1652      self.assertEqual(45.0, sess.run('f:0'))
1653      self.assertEqual(45.0, sess.run(u'f:0'))
1654      self.assertEqual(45.0, sess.run(b'f:0'))
1655      self.assertEqual(45.0, sess.run(r'f:0'))
1656
1657  def testIncorrectGraph(self):
1658    with ops.Graph().as_default() as g_1:
1659      c_1 = constant_op.constant(1.0, name='c')
1660
1661    with ops.Graph().as_default() as g_2:
1662      c_2 = constant_op.constant(2.0, name='c')
1663
1664    self.assertEqual('c', c_1.op.name)
1665    self.assertEqual('c', c_2.op.name)
1666
1667    with session.Session(graph=g_1) as sess_1:
1668      self.assertEqual(1.0, sess_1.run(c_1))
1669      with self.assertRaises(ValueError):
1670        sess_1.run(c_2)
1671      with self.assertRaises(ValueError):
1672        sess_1.run(c_2.op)
1673
1674    with session.Session(graph=g_2) as sess_2:
1675      with self.assertRaises(ValueError):
1676        sess_2.run(c_1)
1677      with self.assertRaises(ValueError):
1678        sess_2.run(c_1.op)
1679      self.assertEqual(2.0, sess_2.run(c_2))
1680
1681  def testFeedDictKeyException(self):
1682    with session.Session() as sess:
1683      a = constant_op.constant(1.0, dtypes.float32, name='a')
1684      with self.assertRaisesRegex(TypeError, 'Cannot interpret feed_dict'):
1685        sess.run(a, feed_dict={'a': [2.0]})
1686
1687  def testPerStepTrace(self):
1688    run_options = config_pb2.RunOptions(
1689        trace_level=config_pb2.RunOptions.SOFTWARE_TRACE)
1690    run_metadata = config_pb2.RunMetadata()
1691
1692    with ops.device('/cpu:0'):
1693      with session.Session() as sess:
1694        sess.run(constant_op.constant(1.0))
1695        self.assertFalse(run_metadata.HasField('step_stats'))
1696
1697        sess.run(constant_op.constant(1.0), run_metadata=run_metadata)
1698        self.assertFalse(run_metadata.HasField('step_stats'))
1699
1700        sess.run(
1701            constant_op.constant(1.0),
1702            options=run_options,
1703            run_metadata=run_metadata)
1704
1705        self.assertTrue(run_metadata.HasField('step_stats'))
1706        self.assertEqual(len(run_metadata.step_stats.dev_stats), 1)
1707
1708  def testRunOptionsRunMetadata(self):
1709    run_options = config_pb2.RunOptions(
1710        trace_level=config_pb2.RunOptions.SOFTWARE_TRACE)
1711    run_metadata = config_pb2.RunMetadata()
1712
1713    with ops.device('/cpu:0'):
1714      with session.Session() as sess:
1715        # all combinations are valid
1716        sess.run(constant_op.constant(1.0), options=None, run_metadata=None)
1717        sess.run(
1718            constant_op.constant(1.0), options=None, run_metadata=run_metadata)
1719        self.assertFalse(run_metadata.HasField('step_stats'))
1720
1721        sess.run(
1722            constant_op.constant(1.0), options=run_options, run_metadata=None)
1723        self.assertFalse(run_metadata.HasField('step_stats'))
1724
1725        sess.run(
1726            constant_op.constant(1.0),
1727            options=run_options,
1728            run_metadata=run_metadata)
1729
1730        self.assertTrue(run_metadata.HasField('step_stats'))
1731        self.assertEqual(len(run_metadata.step_stats.dev_stats), 1)
1732
1733  def testFeedShapeCompatibility(self):
1734    with session.Session() as sess:
1735      some_tensor = constant_op.constant([2.0, 2.0, 2.0, 2.0])
1736      new_shape = constant_op.constant([2, 2])
1737      reshaped_tensor = array_ops.reshape(some_tensor, new_shape)
1738
1739      with self.assertRaisesRegex(ValueError, 'Cannot feed value of shape'):
1740        sess.run(reshaped_tensor, feed_dict={some_tensor: [1.0, 2.0, 3.0]})
1741
1742      with self.assertRaisesRegex(
1743          errors.InvalidArgumentError,
1744          'Input to reshape is a tensor with 4 values, '
1745          'but the requested shape has 21'):
1746        sess.run(reshaped_tensor, feed_dict={new_shape: [3, 7]})
1747
1748  def testInferShapesFalse(self):
1749    with ops.Graph().as_default(), ops.device('/cpu:0'):
1750      a = constant_op.constant([[1, 2]])
1751      sess = session.Session()
1752      self.assertNotIn('_output_shapes', sess.graph_def.node[0].attr)
1753      # Avoid lint error regarding 'unused' var a.
1754      self.assertEqual(a, a)
1755
1756  def testInferShapesTrue(self):
1757    config_pb = config_pb2.ConfigProto(
1758        graph_options=config_pb2.GraphOptions(infer_shapes=True))
1759    with ops.Graph().as_default(), ops.device('/cpu:0'):
1760      a = constant_op.constant([[1, 2]])
1761      sess = session.Session(config=config_pb)
1762      self.assertIn('_output_shapes', sess.graph_def.node[0].attr)
1763      # Avoid lint error regarding 'unused' var a.
1764      self.assertEqual(a, a)
1765
1766  def testBuildCostModel(self):
1767    run_options = config_pb2.RunOptions()
1768    config_pb = config_pb2.ConfigProto(
1769        allow_soft_placement=True,
1770        graph_options=config_pb2.GraphOptions(build_cost_model=100))
1771    with session.Session(config=config_pb) as sess:
1772      with ops.device('/device:GPU:0'):
1773        a = array_ops.placeholder(dtypes.float32, shape=[])
1774        b = math_ops.add(a, a)
1775        c = array_ops.identity(b)
1776        d = math_ops.multiply(c, c)
1777      for step in range(120):
1778        run_metadata = config_pb2.RunMetadata()
1779        sess.run(
1780            d,
1781            feed_dict={a: 1.0},
1782            options=run_options,
1783            run_metadata=run_metadata)
1784        if step == 99:
1785          self.assertTrue(run_metadata.HasField('cost_graph'))
1786        else:
1787          self.assertFalse(run_metadata.HasField('cost_graph'))
1788
1789  def runTestOutputPartitionGraphs(self, sess):
1790    run_options = config_pb2.RunOptions(output_partition_graphs=True)
1791    a = constant_op.constant(1)
1792    run_metadata = config_pb2.RunMetadata()
1793    sess.run(a, options=run_options, run_metadata=run_metadata)
1794    self.assertGreater(len(run_metadata.partition_graphs), 0)
1795    sess.run(a, run_metadata=run_metadata)
1796    self.assertEqual(len(run_metadata.partition_graphs), 0)
1797
1798  @test_util.run_v1_only('b/120545219')
1799  def testOutputPartitionGraphsDirect(self):
1800    self.runTestOutputPartitionGraphs(session.Session())
1801
1802  @test_util.run_v1_only('b/120545219')
1803  def testOutputPartitionGraphsDistributed(self):
1804    server = server_lib.Server.create_local_server()
1805    self.runTestOutputPartitionGraphs(session.Session(server.target))
1806
1807  def testNonInteractiveSessionNesting(self):
1808    sess1 = session.Session()
1809    sess1_controller = sess1.as_default()
1810    sess1_controller.__enter__()
1811
1812    sess2 = session.Session()
1813    sess2_controller = sess2.as_default()
1814    sess2_controller.__enter__()
1815
1816    with self.assertRaisesRegex(AssertionError, 'Nesting violated'):
1817      sess1_controller.__exit__(None, None, None)
1818
1819    ops._default_session_stack.reset()
1820
1821  def testInteractiveSessionNesting(self):
1822    sess1 = session.InteractiveSession()
1823    sess2 = session.InteractiveSession()
1824    del sess1
1825    del sess2
1826
1827  @test_util.run_v1_only('b/120545219')
1828  def testAsDefault(self):
1829    c = constant_op.constant(37)
1830    sess = session.Session()
1831    with sess.as_default():
1832      self.assertEqual(37, c.eval())
1833
1834    # Ensure that the session remains valid even when it is not captured.
1835    with session.Session().as_default():
1836      self.assertEqual(37, c.eval())
1837
1838  def testReentry(self):
1839    sess = session.Session()
1840    with self.assertRaisesRegex(RuntimeError, 'not re-entrant'):
1841      with sess:
1842        with sess:
1843          pass
1844
1845  def testInvalidArgument(self):
1846    with self.assertRaisesRegex(TypeError,
1847                                'Argument `target` must be a string'):
1848      session.Session(37)
1849    with self.assertRaisesRegex(TypeError,
1850                                'Argument `config` must be a tf.ConfigProto'):
1851      session.Session(config=37)
1852    with self.assertRaisesRegex(TypeError,
1853                                'Argument `graph` must be a tf.Graph'):
1854      session.Session(graph=37)
1855
1856  @test_util.run_v1_only('b/120545219')
1857  def testTimeoutWithShortOperations(self):
1858    num_epochs = 5
1859    q = data_flow_ops.FIFOQueue(capacity=50, dtypes=[dtypes.int32], shapes=[()])
1860    enqueue_op = q.enqueue_many(constant_op.constant([1, 2]))
1861
1862    # Use a 10-second timeout, which should be longer than any
1863    # non-blocking enqueue_many op.
1864    config_pb = config_pb2.ConfigProto(operation_timeout_in_ms=10000)
1865    with session.Session(config=config_pb) as sess:
1866      for _ in range(num_epochs):
1867        sess.run(enqueue_op)
1868      self.assertEqual(sess.run(q.size()), num_epochs * 2)
1869
1870  @test_util.run_v1_only('b/120545219')
1871  def testRegisterFetchAndFeedConversionFunctions(self):
1872
1873    class SquaredTensor(object):
1874
1875      def __init__(self, tensor):
1876        self.sq = math_ops.square(tensor)
1877
1878    fetch_fn = lambda squared_tensor: ([squared_tensor.sq], lambda val: val[0])
1879    feed_fn1 = lambda feed, feed_val: [(feed.sq, feed_val)]
1880    feed_fn2 = lambda feed: [feed.sq]
1881
1882    session.register_session_run_conversion_functions(SquaredTensor, fetch_fn,
1883                                                      feed_fn1, feed_fn2)
1884    with self.assertRaises(ValueError):
1885      session.register_session_run_conversion_functions(SquaredTensor, fetch_fn,
1886                                                        feed_fn1, feed_fn2)
1887    with self.cached_session() as sess:
1888      np1 = np.array([1.0, 1.5, 2.0, 2.5])
1889      np2 = np.array([3.0, 3.5, 4.0, 4.5])
1890      squared_tensor = SquaredTensor(np2)
1891      squared_eval = sess.run(squared_tensor)
1892      self.assertAllClose(np2 * np2, squared_eval)
1893      squared_eval = sess.run(
1894          squared_tensor, feed_dict={
1895              squared_tensor: np1 * np1
1896          })
1897      self.assertAllClose(np1 * np1, squared_eval)
1898      partial_run = sess.partial_run_setup([squared_tensor], [])
1899      squared_eval = sess.partial_run(partial_run, squared_tensor)
1900      self.assertAllClose(np2 * np2, squared_eval)
1901
1902  def testDefaultLogDevicePlacement(self):
1903
1904    class CaptureStderr(str):
1905      """Class to capture stderr from C++ shared library."""
1906
1907      def __enter__(self):
1908        self._esc = compat.as_str('\b')
1909        self._output = compat.as_str('')
1910        self._stderr = sys.stderr
1911        self._fd = self._stderr.fileno()
1912        self._out_pipe, in_pipe = os.pipe()
1913        # Save the original io stream.
1914        self._dup_fd = os.dup(self._fd)
1915        # Replace the original io stream with in pipe.
1916        os.dup2(in_pipe, self._fd)
1917        return self
1918
1919      def __exit__(self, *args):
1920        self._stderr.write(self._esc)
1921        self._stderr.flush()
1922        self.read()
1923        os.close(self._out_pipe)
1924        # Restore the original io stream.
1925        os.dup2(self._dup_fd, self._fd)
1926
1927      def read(self):
1928        while True:
1929          data = os.read(self._out_pipe, 1)
1930          if not data or compat.as_str(data) == self._esc:
1931            break
1932          self._output += compat.as_str(data)
1933
1934      def __str__(self):
1935        return self._output
1936
1937    context.set_log_device_placement(True)
1938    if context.executing_eagerly():
1939      with CaptureStderr() as log:
1940        a = constant_op.constant(1)
1941        b = constant_op.constant(2)
1942        c = a + b
1943        # Ensure if the same kernel with the same arguments is executed then its
1944        # execution is logged.
1945        d = a + b
1946    else:
1947      # Passing the config to the server, but not the session should still
1948      # result in logging device placement.
1949      config_pb = config_pb2.ConfigProto(log_device_placement=True)
1950      server = server_lib.Server.create_local_server(config=config_pb)
1951      a = constant_op.constant(1)
1952      b = constant_op.constant(2)
1953      c = a + b
1954      d = a + b
1955      with session.Session(server.target) as sess:
1956        with CaptureStderr() as log:
1957          c, d = sess.run([c, d])
1958
1959    self.assertEqual(c, 3)
1960    self.assertEqual(d, 3)
1961
1962    # Ensure that we did log device placement.
1963    # We have three modes of execution at the moment:
1964    # (1) TF1 Graph (2) TF2 eager (3) TF2 eager with function wrapping.
1965    # The codepaths taken by each are slightly different in all resulting in
1966    # slightly different logging messages.
1967    log_msg = ('Executing op AddV2'
1968               if ops.executing_eagerly_outside_functions() else 'AddV2')
1969    add_executions = [l for l in str(log).splitlines() if log_msg in l]
1970    self.assertEqual(len(add_executions), 2)
1971
1972    @def_function.function
1973    def fn(a, b):
1974      c = a + b
1975      # These two AddV2 cannot use the same argument in tf.function since an
1976      # optimization pass will remove duplicate ops and only run it once.
1977      d = a + c
1978      return c, d
1979
1980    with CaptureStderr() as log:
1981      c, d = self.evaluate(fn(constant_op.constant(1), constant_op.constant(2)))
1982    self.assertEqual(c, 3)
1983    self.assertEqual(d, 4)
1984    # Ensure that we did log device placement.
1985    add_executions = [l for l in str(log).splitlines() if 'AddV2' in l]
1986    self.assertEqual(len(add_executions), 2)
1987
1988  @test_util.run_v1_only('b/120545219')
1989  def testLocalMasterSessionTimeout(self):
1990    # Test that the timeout passed in a config to the session works correctly.
1991    config_pb = config_pb2.ConfigProto(operation_timeout_in_ms=1000)
1992    server = server_lib.Server.create_local_server()
1993    q = data_flow_ops.FIFOQueue(1, dtypes.float32)
1994    dequeued_t = q.dequeue()
1995
1996    with session.Session(server.target, config=config_pb) as sess:
1997      # Intentionally do not run any enqueue_ops so that dequeue will block
1998      # until operation_timeout_in_ms.
1999      with self.assertRaises(errors.DeadlineExceededError):
2000        sess.run(dequeued_t)
2001
2002  @test_util.run_v1_only('b/120545219')
2003  def testDefaultServerTimeout(self):
2004    # Test that the default server config timeout gets used when no Session
2005    # config is provided.
2006    config_pb = config_pb2.ConfigProto(operation_timeout_in_ms=1000)
2007    server = server_lib.Server.create_local_server(config=config_pb)
2008    q = data_flow_ops.FIFOQueue(1, dtypes.float32)
2009    dequeued_t = q.dequeue()
2010
2011    with session.Session(server.target) as sess:
2012      # Intentionally do not run any enqueue_ops so that dequeue will block
2013      # until operation_timeout_in_ms.
2014      with self.assertRaises(errors.DeadlineExceededError):
2015        sess.run(dequeued_t)
2016
2017  def runTestBuildGraphError(self, sess):
2018    # Ensure that errors from building the graph get propagated.
2019    data = array_ops.placeholder(dtypes.float32, shape=[])
2020    # pylint: disable=protected-access
2021    enter_1 = gen_control_flow_ops.enter(data, 'foo_1', False)
2022    enter_2 = gen_control_flow_ops.enter(data, 'foo_2', False)
2023    # pylint: enable=protected-access
2024    res = math_ops.add(enter_1, enter_2)
2025    with self.assertRaisesOpError('has inputs from different frames'):
2026      sess.run(res, feed_dict={data: 1.0})
2027
2028  @test_util.run_v1_only('b/120545219')
2029  def testBuildGraphErrorDirect(self):
2030    self.runTestBuildGraphError(session.Session())
2031
2032  @test_util.run_v1_only('b/120545219')
2033  def testBuildGraphErrorDist(self):
2034    server = server_lib.Server.create_local_server()
2035    self.runTestBuildGraphError(session.Session(server.target))
2036
2037  def testDeviceAttributes(self):
2038    attrs = session._DeviceAttributes(
2039        '/job:worker/replica:0/task:3/device:CPU:2', 'TYPE', 1337, 1000000)
2040    self.assertEqual(1337, attrs.memory_limit_bytes)
2041    self.assertEqual('/job:worker/replica:0/task:3/device:CPU:2', attrs.name)
2042    self.assertEqual('TYPE', attrs.device_type)
2043    self.assertEqual(1000000, attrs.incarnation)
2044    str_repr = '%s' % attrs
2045    self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr)
2046
2047  def testDeviceAttributesCanonicalization(self):
2048    attrs = session._DeviceAttributes('/job:worker/replica:0/task:3/cpu:1',
2049                                      'TYPE', 1337, 1000000)
2050    self.assertEqual(1337, attrs.memory_limit_bytes)
2051    self.assertEqual('/job:worker/replica:0/task:3/device:CPU:1', attrs.name)
2052    self.assertEqual('TYPE', attrs.device_type)
2053    self.assertEqual(1000000, attrs.incarnation)
2054    str_repr = '%s' % attrs
2055    self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr)
2056
2057  def runTestAddFunctionToSession(self, target=''):
2058    """Add a function to a session after the graph has already been run."""
2059
2060    @function.Defun(dtypes.float32)
2061    def foo(x):
2062      return x + 1
2063
2064    x = constant_op.constant(1.0)
2065    with session.Session(target=target) as sess:
2066      sess.run(x)
2067      f = foo(x)
2068      result = sess.run(f)
2069      self.assertEqual(result, 2.0)
2070
2071  @test_util.run_v1_only('b/120545219')
2072  def testAddFunctionToSession(self):
2073    self.runTestAddFunctionToSession()
2074
2075  @test_util.run_v1_only('b/120545219')
2076  def testAddFunctionToGrpcSession(self):
2077    server = server_lib.Server.create_local_server()
2078    self.runTestAddFunctionToSession(server.target)
2079
2080  def testOpenAndCloseGrpcSession(self):
2081    server = server_lib.Server.create_local_server()
2082    with session.Session(server.target):
2083      pass
2084
2085  def testOpenAndCloseSession(self):
2086    with session.Session():
2087      pass
2088
2089  @test_util.run_v1_only('b/120545219')
2090  def testAutoConvertAndCheckData(self):
2091    with self.cached_session() as sess:
2092      a = array_ops.placeholder(dtype=dtypes.string)
2093      with self.assertRaisesRegex(
2094          TypeError, r'Type of feed value 1 with type <(\w+) \'int\'> is not'):
2095        sess.run(a, feed_dict={a: 1})
2096
2097  @test_util.run_v1_only('b/120545219')
2098  def testOptimizerOptions(self):
2099    config.set_optimizer_experimental_options({'min_graph_nodes': -1})
2100
2101    with ops.Graph().as_default():
2102      sess = session.Session()
2103      self.assertEqual(
2104          sess._config.graph_options.rewrite_options.min_graph_nodes, -1)
2105
2106
# Hand control to the googletest runner when this file is run as a script.
if __name__ == '__main__':
  googletest.main()
2109