"""Tests for dictionary_ops."""

from absl.testing import parameterized
import tensorflow as tf

from google.protobuf import text_format
from fcp.dictionary import dictionary_pb2
from fcp.tensorflow import dictionary_ops


class DictionaryOpsTest(tf.test.TestCase, parameterized.TestCase):
  """Covers the TF-graph lookup op and the pure-Python Dictionary wrapper."""

  def test_direct_tf_use_literal_dictionary(self):
    # A literal DictionaryDescription can be handed straight to the lookup
    # op as a serialized proto, bypassing the Dictionary wrapper entirely.
    description = dictionary_pb2.DictionaryDescription()
    text_format.Merge(
        'special_ids: < unk: 0 > '
        'vocabulary: < '
        ' index: < token: "a" token: "b" token: "c" token: "d" >'
        '>',
        description)

    lookup_op = dictionary_ops.dictionary_lookup(
        tf.constant(['a', 'b', 'a', 'a', 'd', 'X']),
        dictionary_description_proto=description.SerializeToString())
    with tf.compat.v1.Session() as session:
      ids = session.run(lookup_op)
    # Known tokens map to 1-based ids; the unseen token 'X' maps to unk (0).
    self.assertEqual([1, 2, 1, 1, 4, 0], ids.tolist())

  @parameterized.named_parameters(
      ('token_index', dictionary_ops.VocabularyType.TOKEN_INDEX))
  def test_build_dictionary_with_output_blocklist(self, vocabulary_type):
    # Build a dictionary, explicitly blocklisting the first token and
    # implicitly blocklisting the last token via output_size.
    dictionary = dictionary_ops.Dictionary.from_tokens(
        ['01', '02', '10', '11'],
        unk_id=0,
        output_blocklist_tokens=['01'],
        output_size=4,
        vocabulary_type=vocabulary_type)

    if vocabulary_type == dictionary_ops.VocabularyType.TOKEN_INDEX:
      lookup_op = dictionary_ops.dictionary_lookup(
          [['01', '02', '10', '11', '12']],
          dictionary_description_proto=dictionary.dictionary_description_proto)

      with tf.compat.v1.Session() as session:
        ids = session.run(lookup_op)
      self.assertEqual([[1, 2, 3, 4, 0]], ids.tolist())

    # Id 1 ('01') was blocklisted explicitly; id 4 ('11') implicitly
    # via output_size.
    self.assertEqual(
        [1, 4], list(dictionary.dictionary_description.output_blocklist_ids.id))

  @parameterized.named_parameters(
      ('token_index', dictionary_ops.VocabularyType.TOKEN_INDEX))
  def test_build_dictionary(self, vocabulary_type):
    dictionary = dictionary_ops.Dictionary.from_tokens(
        ['A', 'a', 'B', 'c'],
        unk_id=0,
        vocabulary_type=vocabulary_type)

    lookup_op = dictionary_ops.dictionary_lookup(
        [['A', 'a', 'B', 'b', 'C', 'c', 'D', 'd']],
        dictionary_description_proto=dictionary.dictionary_description_proto)
    with tf.compat.v1.Session() as session:
      ids = session.run(lookup_op)
    # Lookup is case-sensitive; every absent token collapses to unk (0).
    self.assertEqual([[1, 2, 3, 0, 0, 4, 0, 0]], ids.tolist())

  @parameterized.named_parameters(
      ('token_index', dictionary_ops.VocabularyType.TOKEN_INDEX))
  def test_dictionary_should_raise_with_duplicate_tokens(self, vocabulary_type):
    # '11' appears twice in the token list below.
    with self.assertRaisesRegex(ValueError, 'Duplicate tokens'):
      dictionary_ops.Dictionary.from_tokens(
          ['01', '02', '11', '10', '11'], vocabulary_type=vocabulary_type)

  @parameterized.named_parameters(
      ('token_index', dictionary_ops.VocabularyType.TOKEN_INDEX))
  def test_lookup_in_python(self, vocabulary_type):
    dictionary = dictionary_ops.Dictionary.from_tokens(
        ['01', '02', '10', '11'], unk_id=0, vocabulary_type=vocabulary_type)
    # Length counts the reserved unk id in addition to the four tokens.
    self.assertLen(dictionary, 5)
    self.assertListEqual(
        [1, 2, 3, 4, 0], dictionary.lookup(['01', '02', '10', '11', '12']))

  @parameterized.named_parameters(
      ('token_index', dictionary_ops.VocabularyType.TOKEN_INDEX))
  def test_reverse_lookup_in_python(self, vocabulary_type):
    dictionary = dictionary_ops.Dictionary.from_tokens(
        ['01', '02', '10', '11'], unk_id=0, vocabulary_type=vocabulary_type)
    self.assertLen(dictionary, 5)
    # reverse_lookup yields bytes; unk (0) maps back to the empty string.
    decoded = [
        token.decode('utf-8')
        for token in dictionary.reverse_lookup([3, 2, 1, 4, 0])
    ]
    self.assertListEqual(['10', '02', '01', '11', ''], decoded)

  def test_literal_dictionary_in_python(self):
    description = dictionary_pb2.DictionaryDescription()
    text_format.Merge(
        'special_ids: < unk: 0 > '
        'vocabulary: < '
        ' index: < token: "a" token: "b" token: "c" token: "d" >'
        '>',
        description)
    dictionary = dictionary_ops.Dictionary.from_dictionary_description(
        description)
    self.assertListEqual([b'a', b'b', b'c', b'd'], dictionary.tokens)


if __name__ == '__main__':
  # Required since the test still relies on v1 Session.run behavior.
  tf.compat.v1.disable_v2_behavior()
  tf.test.main()