xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/dictionary_ops_test.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1from absl.testing import parameterized
2import tensorflow as tf
3from google.protobuf import text_format
4from fcp.dictionary import dictionary_pb2
5from fcp.tensorflow import dictionary_ops
6
7
class DictionaryOpsTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for fcp dictionary ops and the Python `Dictionary` wrapper.

  Covers both the raw `dictionary_lookup` TF op (driven from a serialized
  `DictionaryDescription` proto) and the higher-level `Dictionary` class
  (construction, lookup, reverse lookup, and blocklist handling).
  """

  def test_direct_tf_use_literal_dictionary(self):
    """dictionary_lookup works against a hand-written proto description."""
    desc = dictionary_pb2.DictionaryDescription()
    text_format.Merge(
        'special_ids: < unk: 0 > '
        'vocabulary: < '
        '  index: < token: "a" token: "b" token: "c" token: "d" >'
        '>',
        desc)

    lookup_op = dictionary_ops.dictionary_lookup(
        tf.constant(['a', 'b', 'a', 'a', 'd', 'X']),
        dictionary_description_proto=desc.SerializeToString())
    with tf.compat.v1.Session() as session:
      ids = session.run(lookup_op)
    # 'X' is out of vocabulary and maps to the unk id (0); in-vocab tokens
    # map to 1-based ids because id 0 is reserved for unk.
    self.assertEqual([1, 2, 1, 1, 4, 0], ids.tolist())

  @parameterized.named_parameters(
      ('token_index', dictionary_ops.VocabularyType.TOKEN_INDEX))
  def test_build_dictionary_with_output_blocklist(self, vocabulary_type):
    """Blocklisted ids are recorded both explicitly and via output_size."""
    # '01' is blocklisted by name; '11' is blocklisted implicitly because
    # output_size=4 cuts off the last vocabulary entry.
    vocab = dictionary_ops.Dictionary.from_tokens(
        ['01', '02', '10', '11'],
        unk_id=0,
        output_blocklist_tokens=['01'],
        output_size=4,
        vocabulary_type=vocabulary_type)

    if vocabulary_type in (
        dictionary_ops.VocabularyType.TOKEN_INDEX,
    ):
      lookup_op = dictionary_ops.dictionary_lookup(
          [['01', '02', '10', '11', '12']],
          dictionary_description_proto=vocab.dictionary_description_proto)

    with tf.compat.v1.Session() as session:
      ids = session.run(lookup_op)
    # Lookup itself is unaffected by the blocklist; '12' is OOV -> unk (0).
    self.assertEqual([[1, 2, 3, 4, 0]], ids.tolist())
    self.assertEqual(
        [1, 4], list(vocab.dictionary_description.output_blocklist_ids.id))

  @parameterized.named_parameters(
      ('token_index', dictionary_ops.VocabularyType.TOKEN_INDEX))
  def test_build_dictionary(self, vocabulary_type):
    """Lookups are case-sensitive exact matches; misses yield unk (0)."""
    vocab = dictionary_ops.Dictionary.from_tokens(
        ['A', 'a', 'B', 'c'],
        unk_id=0,
        vocabulary_type=vocabulary_type)

    lookup_op = dictionary_ops.dictionary_lookup(
        [['A', 'a', 'B', 'b', 'C', 'c', 'D', 'd']],
        dictionary_description_proto=vocab.dictionary_description_proto)
    want = [[1, 2, 3, 0, 0, 4, 0, 0]]
    with tf.compat.v1.Session() as session:
      ids = session.run(lookup_op)
    self.assertEqual(want, ids.tolist())

  @parameterized.named_parameters(
      ('token_index', dictionary_ops.VocabularyType.TOKEN_INDEX))
  def test_dictionary_should_raise_with_duplicate_tokens(self, vocabulary_type):
    """Constructing from a token list with repeats is rejected."""
    with self.assertRaisesRegex(ValueError, 'Duplicate tokens'):
      dictionary_ops.Dictionary.from_tokens(['01', '02', '11', '10', '11'],
                                            vocabulary_type=vocabulary_type)

  @parameterized.named_parameters(
      ('token_index', dictionary_ops.VocabularyType.TOKEN_INDEX))
  def test_lookup_in_python(self, vocabulary_type):
    """Dictionary.lookup mirrors the TF op without a session."""
    vocab = dictionary_ops.Dictionary.from_tokens(
        ['01', '02', '10', '11'], unk_id=0, vocabulary_type=vocabulary_type)
    # Length includes the reserved unk slot: 4 tokens + 1.
    self.assertLen(vocab, 5)
    self.assertListEqual([1, 2, 3, 4, 0],
                         vocab.lookup(['01', '02', '10', '11', '12']))

  @parameterized.named_parameters(
      ('token_index', dictionary_ops.VocabularyType.TOKEN_INDEX))
  def test_reverse_lookup_in_python(self, vocabulary_type):
    """reverse_lookup maps ids back to (byte) tokens; unk maps to ''."""
    vocab = dictionary_ops.Dictionary.from_tokens(
        ['01', '02', '10', '11'], unk_id=0, vocabulary_type=vocabulary_type)
    self.assertLen(vocab, 5)
    tokens = [
        t.decode('utf-8') for t in vocab.reverse_lookup([3, 2, 1, 4, 0])
    ]
    self.assertListEqual(['10', '02', '01', '11', ''], tokens)

  def test_literal_dictionary_in_python(self):
    """Dictionary.from_dictionary_description exposes tokens as bytes."""
    desc = dictionary_pb2.DictionaryDescription()
    text_format.Merge(
        'special_ids: < unk: 0 > '
        'vocabulary: < '
        '  index: < token: "a" token: "b" token: "c" token: "d" >'
        '>',
        desc)
    vocab = dictionary_ops.Dictionary.from_dictionary_description(desc)
    self.assertListEqual([b'a', b'b', b'c', b'd'], vocab.tokens)
105
106
if __name__ == '__main__':
  # These tests build graphs and execute them via tf.compat.v1.Session, so
  # eager/v2 behavior must be disabled before the test runner spins up.
  tf.compat.v1.disable_v2_behavior()
  tf.test.main()
111