# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for util.traverse_ir."""

import collections
import unittest

from compiler.util import ir_data
from compiler.util import ir_data_utils
from compiler.util import traverse_ir

# Shared fixture: two modules.  "t.emb" holds a structure `Foo` (with a nested
# structure `Bar`) and an enumeration `Bar`; the module with the empty file name
# holds the external type `UInt`.  The expected values in the tests below are
# the numeric constants that appear in this JSON.
_EXAMPLE_IR = ir_data_utils.IrDataSerializer.from_json(ir_data.EmbossIr, """{
  "module": [
    {
      "type": [
        {
          "structure": {
            "field": [
              {
                "location": {
                  "start": { "constant": { "value": "0" } },
                  "size": { "constant": { "value": "8" } }
                },
                "type": {
                  "atomic_type": {
                    "reference": {
                      "canonical_name": {
                        "module_file": "",
                        "object_path": ["UInt"]
                      }
                    }
                  }
                },
                "name": { "name": { "text": "field1" } }
              },
              {
                "location": {
                  "start": { "constant": { "value": "8" } },
                  "size": { "constant": { "value": "16" } }
                },
                "type": {
                  "array_type": {
                    "base_type": {
                      "atomic_type": {
                        "reference": {
                          "canonical_name": {
                            "module_file": "",
                            "object_path": ["UInt"]
                          }
                        }
                      }
                    },
                    "element_count": { "constant": { "value": "8" } }
                  }
                },
                "name": { "name": { "text": "field2" } }
              }
            ]
          },
          "name": { "name": { "text": "Foo" } },
          "subtype": [
            {
              "structure": {
                "field": [
                  {
                    "location": {
                      "start": { "constant": { "value": "24" } },
                      "size": { "constant": { "value": "32" } }
                    },
                    "type": {
                      "atomic_type": {
                        "reference": {
                          "canonical_name": {
                            "module_file": "",
                            "object_path": ["UInt"]
                          }
                        }
                      }
                    },
                    "name": { "name": { "text": "bar_field1" } }
                  },
                  {
                    "location": {
                      "start": { "constant": { "value": "32" } },
                      "size": { "constant": { "value": "320" } }
                    },
                    "type": {
                      "array_type": {
                        "base_type": {
                          "array_type": {
                            "base_type": {
                              "atomic_type": {
                                "reference": {
                                  "canonical_name": {
                                    "module_file": "",
                                    "object_path": ["UInt"]
                                  }
                                }
                              }
                            },
                            "element_count": { "constant": { "value": "16" } }
                          }
                        },
                        "automatic": { }
                      }
                    },
                    "name": { "name": { "text": "bar_field2" } }
                  }
                ]
              },
              "name": { "name": { "text": "Bar" } }
            }
          ]
        },
        {
          "enumeration": {
            "value": [
              {
                "name": { "name": { "text": "ONE" } },
                "value": { "constant": { "value": "1" } }
              },
              {
                "name": { "name": { "text": "TWO" } },
                "value": {
                  "function": {
                    "function": "ADDITION",
                    "args": [
                      { "constant": { "value": "1" } },
                      { "constant": { "value": "1" } }
                    ],
                    "function_name": { "text": "+" }
                  }
                }
              }
            ]
          },
          "name": { "name": { "text": "Bar" } }
        }
      ],
      "source_file_name": "t.emb"
    },
    {
      "type": [
        {
          "external": { },
          "name": {
            "name": { "text": "UInt" },
            "canonical_name": { "module_file": "", "object_path": ["UInt"] }
          },
          "attribute": [
            {
              "name": { "text": "statically_sized" },
              "value": {
                "expression": { "boolean_constant": { "value": true } }
              }
            },
            {
              "name": { "text": "size_in_bits" },
              "value": { "expression": { "constant": { "value": "64" } } }
            }
          ]
        }
      ],
      "source_file_name": ""
    }
  ]
}""")


def _count_entries(sequence):
    """Returns a Counter mapping each entry of `sequence` to its frequency.

    The tests compare Counters rather than lists so that the assertions do not
    depend on the order in which the traversal visits nodes.
    """
    return collections.Counter(sequence)


def _record_constant(constant, constant_list):
    """Appends the integer value of `constant` to `constant_list`."""
    constant_list.append(int(constant.value))


def _record_field_name_and_constant(constant, constant_list, field):
    """Appends (name text of `field`, integer value of `constant`)."""
    constant_list.append((field.name.name.text, int(constant.value)))


def _record_file_name_and_constant(constant, constant_list, source_file_name):
    """Appends (`source_file_name`, integer value of `constant`)."""
    constant_list.append((source_file_name, int(constant.value)))


def _record_location_parameter_and_constant(constant, constant_list,
                                            location=None):
    """Appends (`location`, integer value of `constant`)."""
    constant_list.append((location, int(constant.value)))


def _record_kind_and_constant(constant, constant_list, type_definition):
    """Appends (kind, integer value of `constant`) to `constant_list`.

    `kind` is "enumeration", "structure" or "external", whichever of those
    fields is set on `type_definition`.

    Raises:
      AssertionError: if none of the three fields is set.
    """
    value = int(constant.value)
    for kind in ("enumeration", "structure", "external"):
        if type_definition.HasField(kind):
            constant_list.append((kind, value))
            return
    # Raised explicitly (instead of `assert False`) so the check is not
    # stripped when Python runs with -O.
    raise AssertionError("Shouldn't be here.")


class TraverseIrTest(unittest.TestCase):
    """Tests for traverse_ir.fast_traverse_ir_top_down against _EXAMPLE_IR."""

    def test_filter_on_type(self):
        """A one-type pattern yields every NumericConstant in the example."""
        constants = []
        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR, [ir_data.NumericConstant], _record_constant,
            parameters={"constant_list": constants})
        self.assertEqual(
            _count_entries([0, 8, 8, 8, 16, 24, 32, 16, 32, 320, 1, 1, 1, 64]),
            _count_entries(constants))

    def test_filter_on_type_in_type(self):
        """A Function/Expression/NumericConstant pattern yields only [1, 1]."""
        constants = []
        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR,
            [ir_data.Function, ir_data.Expression, ir_data.NumericConstant],
            _record_constant,
            parameters={"constant_list": constants})
        self.assertEqual([1, 1], constants)

    def test_filter_on_type_star_type(self):
        """Patterns rooted at Structure or Enum yield only their constants."""
        struct_constants = []
        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR, [ir_data.Structure, ir_data.NumericConstant],
            _record_constant,
            parameters={"constant_list": struct_constants})
        self.assertEqual(
            _count_entries([0, 8, 8, 8, 16, 24, 32, 16, 32, 320]),
            _count_entries(struct_constants))

        enum_constants = []
        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR, [ir_data.Enum, ir_data.NumericConstant],
            _record_constant,
            parameters={"constant_list": enum_constants})
        self.assertEqual(_count_entries([1, 1, 1]),
                         _count_entries(enum_constants))

    def test_filter_on_not_type(self):
        """skip_descendants_of=(Structure,) leaves only non-struct constants."""
        notstruct_constants = []
        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR, [ir_data.NumericConstant], _record_constant,
            skip_descendants_of=(ir_data.Structure,),
            parameters={"constant_list": notstruct_constants})
        self.assertEqual(_count_entries([1, 1, 1, 64]),
                         _count_entries(notstruct_constants))

    def test_field_is_populated(self):
        """The action's `field` argument is the field enclosing each constant."""
        constants = []
        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR, [ir_data.Field, ir_data.NumericConstant],
            _record_field_name_and_constant,
            parameters={"constant_list": constants})
        self.assertEqual(_count_entries([
            ("field1", 0), ("field1", 8),
            ("field2", 8), ("field2", 8), ("field2", 16),
            ("bar_field1", 24), ("bar_field1", 32),
            ("bar_field2", 16), ("bar_field2", 32), ("bar_field2", 320),
        ]), _count_entries(constants))

    def test_file_name_is_populated(self):
        """The action's `source_file_name` is that of the enclosing module."""
        constants = []
        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR, [ir_data.NumericConstant],
            _record_file_name_and_constant,
            parameters={"constant_list": constants})
        self.assertEqual(_count_entries([
            ("t.emb", 0), ("t.emb", 8), ("t.emb", 8), ("t.emb", 8),
            ("t.emb", 16), ("t.emb", 24), ("t.emb", 32), ("t.emb", 16),
            ("t.emb", 32), ("t.emb", 320), ("t.emb", 1), ("t.emb", 1),
            ("t.emb", 1), ("", 64),
        ]), _count_entries(constants))

    def test_type_definition_is_populated(self):
        """The action's `type_definition` is the enclosing type definition."""
        constants = []
        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR, [ir_data.NumericConstant], _record_kind_and_constant,
            parameters={"constant_list": constants})
        self.assertEqual(_count_entries([
            ("structure", 0), ("structure", 8), ("structure", 8),
            ("structure", 8), ("structure", 16), ("structure", 24),
            ("structure", 32), ("structure", 16), ("structure", 32),
            ("structure", 320), ("enumeration", 1), ("enumeration", 1),
            ("enumeration", 1), ("external", 64),
        ]), _count_entries(constants))

    def test_keyword_args_dict_in_action(self):
        """Actions taking **kwargs see `field` only when inside a Field."""
        call_counts = {"populated": 0, "not": 0}

        def check_field_is_populated(node, **kwargs):
            del node  # Unused.
            self.assertTrue(kwargs["field"])
            call_counts["populated"] += 1

        def check_field_is_not_populated(node, **kwargs):
            del node  # Unused.
            # assertNotIn reports the offending kwargs on failure.
            self.assertNotIn("field", kwargs)
            call_counts["not"] += 1

        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR, [ir_data.Field, ir_data.Type],
            check_field_is_populated)
        self.assertEqual(7, call_counts["populated"])

        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR, [ir_data.Enum, ir_data.EnumValue],
            check_field_is_not_populated)
        self.assertEqual(2, call_counts["not"])

    def test_pass_only_to_sub_nodes(self):
        """Values returned by an incidental action reach only nodes below it.

        Constants inside a Field are expected to see that field's (start,
        size) pair as `location`; constants outside any Field are expected to
        see the initial `location` of None.
        """
        constants = []

        def pass_location_down(field):
            return {
                "location": (int(field.location.start.constant.value),
                             int(field.location.size.constant.value))
            }

        traverse_ir.fast_traverse_ir_top_down(
            _EXAMPLE_IR, [ir_data.NumericConstant],
            _record_location_parameter_and_constant,
            incidental_actions={ir_data.Field: pass_location_down},
            parameters={"constant_list": constants, "location": None})
        self.assertEqual(_count_entries([
            ((0, 8), 0), ((0, 8), 8),
            ((8, 16), 8), ((8, 16), 8), ((8, 16), 16),
            ((24, 32), 24), ((24, 32), 32),
            ((32, 320), 16), ((32, 320), 32), ((32, 320), 320),
            (None, 1), (None, 1), (None, 1), (None, 64),
        ]), _count_entries(constants))


if __name__ == "__main__":
    unittest.main()