# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for util.ir_data_fields."""

import dataclasses
import enum
import sys
from typing import Optional
import unittest

from compiler.util import ir_data
from compiler.util import ir_data_fields


class TestEnum(enum.Enum):
    """Used to test python Enum handling."""

    UNKNOWN = 0
    VALUE_1 = 1
    VALUE_2 = 2


@dataclasses.dataclass
class Opaque(ir_data.Message):
    """Used for testing data field helpers."""


@dataclasses.dataclass
class ClassWithUnion(ir_data.Message):
    """Used for testing data field helpers; all union members share one group."""

    opaque: Optional[Opaque] = ir_data_fields.oneof_field("type")
    integer: Optional[int] = ir_data_fields.oneof_field("type")
    boolean: Optional[bool] = ir_data_fields.oneof_field("type")
    enumeration: Optional[TestEnum] = ir_data_fields.oneof_field("type")
    non_union_field: int = 0


@dataclasses.dataclass
class ClassWithTwoUnions(ir_data.Message):
    """Used for testing data field helpers; has two independent oneof groups."""

    opaque: Optional[Opaque] = ir_data_fields.oneof_field("type_1")
    integer: Optional[int] = ir_data_fields.oneof_field("type_1")
    boolean: Optional[bool] = ir_data_fields.oneof_field("type_2")
    enumeration: Optional[TestEnum] = ir_data_fields.oneof_field("type_2")
    non_union_field: int = 0
    seq_field: list[int] = ir_data_fields.list_field(int)


@dataclasses.dataclass
class NestedClass(ir_data.Message):
    """Used for testing data field helpers on nested messages."""

    one_union_class: Optional[ClassWithUnion] = None
    two_union_class: Optional[ClassWithTwoUnions] = None


@dataclasses.dataclass
class ListCopyTestClass(ir_data.Message):
    """Used to test behavior of extending a sequence."""

    non_union_field: int = 0
    seq_field: list[int] = ir_data_fields.list_field(int)


@dataclasses.dataclass
class OneofFieldTest(ir_data.Message):
    """Basic test class for oneof fields."""

    # Both int fields belong to oneof group "type_1"; `normal_field` is not in
    # any group.
    int_field_1: Optional[int] = ir_data_fields.oneof_field("type_1")
    int_field_2: Optional[int] = ir_data_fields.oneof_field("type_1")
    normal_field: bool = True


class OneOfTest(unittest.TestCase):
    """Tests for the various oneof field helpers."""

    def test_field_attribute(self):
        """Test the `oneof_field` helper."""
        test_field = ir_data_fields.oneof_field("type_1")
        self.assertIsNotNone(test_field)
        self.assertTrue(test_field.init)
        self.assertIsInstance(test_field.default, ir_data_fields.OneOfField)
        self.assertEqual(test_field.metadata.get("oneof"), "type_1")

    def test_init_default(self):
        """Test creating an instance with default fields."""
        one_of_field_test = OneofFieldTest()
        self.assertIsNone(one_of_field_test.int_field_1)
        self.assertIsNone(one_of_field_test.int_field_2)
        self.assertTrue(one_of_field_test.normal_field)

    def test_init(self):
        """Test creating an instance with non-default fields."""
        one_of_field_test = OneofFieldTest(int_field_1=10, normal_field=False)
        self.assertEqual(one_of_field_test.int_field_1, 10)
        self.assertIsNone(one_of_field_test.int_field_2)
        self.assertFalse(one_of_field_test.normal_field)

    def test_set_oneof_field(self):
        """Tests setting oneof fields causes others in the group to be unset."""
        one_of_field_test = OneofFieldTest()
        one_of_field_test.int_field_1 = 10
        self.assertEqual(one_of_field_test.int_field_1, 10)
        self.assertIsNone(one_of_field_test.int_field_2)
        one_of_field_test.int_field_2 = 20
        self.assertIsNone(one_of_field_test.int_field_1)
        self.assertEqual(one_of_field_test.int_field_2, 20)

        # Do it again, to check that switching back and forth keeps working.
        one_of_field_test.int_field_1 = 10
        self.assertEqual(one_of_field_test.int_field_1, 10)
        self.assertIsNone(one_of_field_test.int_field_2)
        one_of_field_test.int_field_2 = 20
        self.assertIsNone(one_of_field_test.int_field_1)
        self.assertEqual(one_of_field_test.int_field_2, 20)

        # Now create a new instance and make sure changes to it are not
        # reflected on the original object.
        one_of_field_test_2 = OneofFieldTest()
        one_of_field_test_2.int_field_1 = 1000
        self.assertEqual(one_of_field_test_2.int_field_1, 1000)
        self.assertIsNone(one_of_field_test_2.int_field_2)
        self.assertIsNone(one_of_field_test.int_field_1)
        self.assertEqual(one_of_field_test.int_field_2, 20)

    def test_set_to_none(self):
        """Tests explicitly setting a oneof field to None."""
        one_of_field_test = OneofFieldTest(int_field_1=10, normal_field=False)
        self.assertEqual(one_of_field_test.int_field_1, 10)
        self.assertIsNone(one_of_field_test.int_field_2)
        self.assertFalse(one_of_field_test.normal_field)

        # Clear the set field.
        one_of_field_test.int_field_1 = None
        self.assertIsNone(one_of_field_test.int_field_1)
        self.assertIsNone(one_of_field_test.int_field_2)
        self.assertFalse(one_of_field_test.normal_field)

        # Set another field.
        one_of_field_test.int_field_2 = 200
        self.assertIsNone(one_of_field_test.int_field_1)
        self.assertEqual(one_of_field_test.int_field_2, 200)
        self.assertFalse(one_of_field_test.normal_field)

        # Clearing the already-unset field must not clear its sibling.
        one_of_field_test.int_field_1 = None
        self.assertIsNone(one_of_field_test.int_field_1)
        self.assertEqual(one_of_field_test.int_field_2, 200)
        self.assertFalse(one_of_field_test.normal_field)

    def test_oneof_specs(self):
        """Tests the `oneof_field_specs` filter."""
        expected = {
            "int_field_1": ir_data_fields.make_field_spec(
                "int_field_1", int, ir_data_fields.FieldContainer.OPTIONAL, "type_1"
            ),
            "int_field_2": ir_data_fields.make_field_spec(
                "int_field_2", int, ir_data_fields.FieldContainer.OPTIONAL, "type_1"
            ),
        }
        actual = ir_data_fields.IrDataclassSpecs.get_specs(
            OneofFieldTest
        ).oneof_field_specs
        self.assertDictEqual(actual, expected)

    def test_oneof_mappings(self):
        """Tests the `oneof_mappings` function."""
        expected = (("int_field_1", "type_1"), ("int_field_2", "type_1"))
        actual = ir_data_fields.IrDataclassSpecs.get_specs(
            OneofFieldTest
        ).oneof_mappings
        self.assertTupleEqual(actual, expected)


class IrDataFieldsTest(unittest.TestCase):
    """Tests misc methods in ir_data_fields."""

    def test_copy(self):
        """Tests copying a data class works as expected."""
        union = ClassWithTwoUnions(
            opaque=Opaque(), boolean=True, non_union_field=10, seq_field=[1, 2, 3]
        )
        nested_class = NestedClass(two_union_class=union)
        nested_class_copy = ir_data_fields.copy(nested_class)
        self.assertIsNotNone(nested_class_copy)
        self.assertIsNot(nested_class, nested_class_copy)
        self.assertEqual(nested_class_copy, nested_class)

        # Copying None yields None rather than raising.
        empty_copy = ir_data_fields.copy(None)
        self.assertIsNone(empty_copy)

    def test_copy_values_list(self):
        """Tests that CopyValuesList copies values."""
        data_list = ir_data_fields.CopyValuesList(ListCopyTestClass)
        self.assertEqual(len(data_list), 0)

        list_test = ListCopyTestClass(non_union_field=2, seq_field=[5, 6, 7])
        list_tests = [ir_data_fields.copy(list_test) for _ in range(4)]
        data_list.extend(list_tests)
        self.assertEqual(len(data_list), 4)
        for item in data_list:
            self.assertEqual(item, list_test)

    def test_list_param_is_copied(self):
        """Test that lists passed to constructors are converted to CopyValuesList."""
        seq_field = [5, 6, 7]
        list_test = ListCopyTestClass(non_union_field=2, seq_field=seq_field)
        self.assertEqual(len(list_test.seq_field), len(seq_field))
        self.assertIsNot(list_test.seq_field, seq_field)
        self.assertEqual(list_test.seq_field, seq_field)
        self.assertIsInstance(list_test.seq_field, ir_data_fields.CopyValuesList)

    def test_copy_oneof(self):
        """Tests copying an IR data class that has oneof fields."""
        oneof_test = OneofFieldTest()
        oneof_test.int_field_1 = 10
        oneof_test.normal_field = False
        self.assertEqual(oneof_test.int_field_1, 10)
        self.assertFalse(oneof_test.normal_field)

        oneof_copy = ir_data_fields.copy(oneof_test)
        self.assertIsNotNone(oneof_copy)
        # The copy must be a distinct object, as asserted in `test_copy`.
        self.assertIsNot(oneof_copy, oneof_test)
        self.assertEqual(oneof_copy.int_field_1, 10)
        self.assertIsNone(oneof_copy.int_field_2)
        self.assertFalse(oneof_copy.normal_field)

        # Switching the oneof on the copy must not affect the original.
        oneof_copy.int_field_2 = 100
        self.assertEqual(oneof_copy.int_field_2, 100)
        self.assertIsNone(oneof_copy.int_field_1)
        self.assertEqual(oneof_test.int_field_1, 10)
        self.assertFalse(oneof_test.normal_field)


# NOTE(review): presumably precomputes field specs for the ir_data.Message
# subclasses defined in this module; confirm against ir_data_fields.
ir_data_fields.cache_message_specs(
    sys.modules[OneofFieldTest.__module__], ir_data.Message
)

if __name__ == "__main__":
    unittest.main()