# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for checkpoint_tensor_reference."""

import unittest

from absl.testing import absltest
import numpy
import tensorflow as tf
import tensorflow_federated as tff

from fcp.demo import checkpoint_tensor_reference as ctr
from fcp.demo import test_utils

TENSOR_NAME = 'test'
DTYPE = tf.int32
SHAPE = (2, 3)
TEST_VALUE = tf.zeros(SHAPE, DTYPE).numpy()


async def get_test_checkpoint():
  """Returns a checkpoint holding TEST_VALUE under the name TENSOR_NAME.

  The checkpoint is built by `test_utils.create_checkpoint`.
  """
  return test_utils.create_checkpoint({TENSOR_NAME: TEST_VALUE})


class CheckpointTensorReferenceTest(absltest.TestCase,
                                    unittest.IsolatedAsyncioTestCase):
  """Tests for `ctr.CheckpointTensorReference`."""

  def _closing_test_checkpoint(self):
    """Returns a `get_test_checkpoint()` coroutine that is closed on cleanup.

    Some tests never await the checkpoint. Closing the coroutine afterwards
    avoids a "coroutine ... was never awaited" RuntimeWarning. Closing a
    coroutine that already ran to completion is a no-op, so this is safe even
    if the code under test does await it.
    """
    checkpoint = get_test_checkpoint()
    self.addCleanup(checkpoint.close)
    return checkpoint

  def test_type_signature(self):
    """type_signature reflects the dtype and shape passed to the reference."""
    ref = ctr.CheckpointTensorReference(
        TENSOR_NAME, DTYPE, SHAPE,
        tff.async_utils.SharedAwaitable(self._closing_test_checkpoint()))
    self.assertEqual(ref.type_signature, tff.TensorType(DTYPE, SHAPE))

  async def test_get_value(self):
    """get_value returns the tensor stored in the checkpoint."""
    ref = ctr.CheckpointTensorReference(
        TENSOR_NAME, DTYPE, SHAPE,
        tff.async_utils.SharedAwaitable(get_test_checkpoint()))
    # Unlike numpy.array_equiv, this rejects broadcast-compatible but
    # differently shaped results and reports a diff on failure.
    numpy.testing.assert_array_equal(await ref.get_value(), TEST_VALUE)

  async def test_get_value_in_graph_mode(self):
    """get_value raises ValueError outside of eager mode."""
    with tf.compat.v1.Graph().as_default():
      ref = ctr.CheckpointTensorReference(
          TENSOR_NAME, DTYPE, SHAPE,
          tff.async_utils.SharedAwaitable(self._closing_test_checkpoint()))
      with self.assertRaisesRegex(ValueError,
                                  'get_value is only supported in eager mode'):
        await ref.get_value()

  async def test_get_value_not_found(self):
    """get_value raises NotFoundError if the tensor is absent."""

    async def get_not_found_checkpoint():
      # The checkpoint only contains a tensor under a different name.
      return test_utils.create_checkpoint({'other': TEST_VALUE})

    ref = ctr.CheckpointTensorReference(
        TENSOR_NAME, DTYPE, SHAPE,
        tff.async_utils.SharedAwaitable(get_not_found_checkpoint()))
    with self.assertRaises(tf.errors.NotFoundError):
      await ref.get_value()

  async def test_get_value_with_invalid_checkpoint(self):
    """get_value raises DataLossError for a malformed checkpoint."""

    async def get_invalid_checkpoint():
      return b'invalid'

    ref = ctr.CheckpointTensorReference(
        TENSOR_NAME, DTYPE, SHAPE,
        tff.async_utils.SharedAwaitable(get_invalid_checkpoint()))
    with self.assertRaises(tf.errors.DataLossError):
      await ref.get_value()


if __name__ == '__main__':
  absltest.main()