# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for `make_slices_selector_example_selector` custom op."""

import tensorflow as tf

from fcp.protos import plan_pb2
from fcp.tensorflow import make_slices_selector_example_selector


class MakeSlicesSelectorExampleSelectorTest(tf.test.TestCase):
  """Checks the serialized `ExampleSelector` produced by the custom op."""

  def test_returns_serialized_proto(self):
    """The op returns a string tensor holding a serialized ExampleSelector.

    The decoded selector must point at the federated-select collection URI
    and carry a packed `SlicesSelector` whose `served_at_id` and `keys`
    mirror the arguments passed to the op.
    """
    served_at_id = 'test_served_at_id'
    keys = [1, 3, 5, 20]
    serialized_proto_tensor = (
        make_slices_selector_example_selector
        .make_slices_selector_example_selector(served_at_id, keys))
    self.assertIsInstance(serialized_proto_tensor, tf.Tensor)
    self.assertEqual(serialized_proto_tensor.dtype, tf.string)

    # `.numpy()` needs an eager tensor; assumes the test runs with eager
    # execution enabled (the TF2 default).
    serialized_proto = serialized_proto_tensor.numpy()
    example_selector = plan_pb2.ExampleSelector.FromString(serialized_proto)
    self.assertEqual(example_selector.collection_uri,
                     'internal:/federated_select')

    # `criteria` is unpacked into a SlicesSelector; `Any.Unpack` returns
    # False when the packed type does not match the target message.
    slices_selector = plan_pb2.SlicesSelector()
    self.assertTrue(example_selector.criteria.Unpack(slices_selector))
    self.assertEqual(slices_selector.served_at_id, served_at_id)
    # Convert the repeated field to a plain list so the assertion is a true
    # list-to-list comparison (with a per-element diff on failure) instead of
    # relying on the protobuf container's cross-type `__eq__`.
    self.assertEqual(list(slices_selector.keys), keys)


if __name__ == '__main__':
  tf.test.main()