xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/tensor_name_test.py (revision 14675a029014e728ec732f129a32e299b2da0601)
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14"""Tests for `tensor_name` custom op."""
15
16import tensorflow as tf
17
18from fcp.tensorflow import tensor_name
19
20
class TensorNameTest(tf.test.TestCase):
  """Graph-mode tests for the `tensor_name.tensor_name` op."""

  # Name assigned to the placeholder node whose name the op should report.
  _PLACEHOLDER_NAME = b'placeholder_test_name'
  # Scope name passed to `import_graph_def` when re-importing the graph.
  _IMPORT_PREFIX = b'import_prefix_'

  def _build_graph_with_name_op(self, node_name):
    """Builds a new graph holding a named placeholder and a `tensor_name` op.

    Args:
      node_name: Name given to the `placeholder_with_default` node.

    Returns:
      A `(graph, name_tensor)` pair, where `name_tensor` is the output of
      `tensor_name.tensor_name` applied to the placeholder.
    """
    graph = tf.Graph()
    with graph.as_default():
      source = tf.compat.v1.placeholder_with_default(
          input='default_value', shape=(), name=node_name)
      name_tensor = tensor_name.tensor_name(source)
    return graph, name_tensor

  def _run(self, graph, fetch):
    """Evaluates `fetch` in a fresh TF1-style session bound to `graph`."""
    with tf.compat.v1.Session(graph=graph) as sess:
      return sess.run(fetch)

  def test_returns_simple_name(self):
    """The op evaluates to exactly the name given to its input placeholder."""
    graph, name_tensor = self._build_graph_with_name_op(
        self._PLACEHOLDER_NAME)
    self.assertEqual(self._PLACEHOLDER_NAME, self._run(graph, name_tensor))

  def test_returns_modified_name_after_reimport(self):
    """After `import_graph_def`, the op reports the prefixed node name."""
    inner_graph, inner_name_tensor = self._build_graph_with_name_op(
        self._PLACEHOLDER_NAME)
    outer_graph = tf.Graph()
    with outer_graph.as_default():
      # A single `return_elements` entry yields a one-element list.
      (reimported_name_tensor,) = tf.graph_util.import_graph_def(
          graph_def=inner_graph.as_graph_def(),
          input_map={},
          return_elements=[inner_name_tensor.name],
          name=self._IMPORT_PREFIX)
    # Imported nodes are scoped as `<import name>/<original node name>`.
    expected = self._IMPORT_PREFIX + b'/' + self._PLACEHOLDER_NAME
    self.assertEqual(expected, self._run(outer_graph, reimported_name_tensor))
49
50
if __name__ == '__main__':
  # Hand control to TensorFlow's test entry point when run as a script.
  tf.test.main()
53