xref: /aosp_15_r20/external/emboss/compiler/util/ir_data_fields_test.py (revision 99e0aae7469b87d12f0ad23e61142c2d74c1ef70)
1# Copyright 2024 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Tests for util.ir_data_fields."""
16
17import dataclasses
18import enum
19import sys
20from typing import Optional
21import unittest
22
23from compiler.util import ir_data
24from compiler.util import ir_data_fields
25
26
class TestEnum(enum.Enum):
  """Used to test Python Enum handling.

  Serves as the value type of the `enumeration` oneof fields declared on
  ClassWithUnion and ClassWithTwoUnions below.
  """

  # Member values are explicit integers; UNKNOWN is the zero value.
  UNKNOWN = 0
  VALUE_1 = 1
  VALUE_2 = 2
33
34
@dataclasses.dataclass
class Opaque(ir_data.Message):
  """Used for testing data field helpers.

  A message type with no fields of its own. It is used as the type of the
  `opaque` oneof members below, i.e. a nested-message value inside a oneof
  group.
  """
38
39
@dataclasses.dataclass
class ClassWithUnion(ir_data.Message):
  """Used for testing data field helpers.

  Declares a single oneof group, "type", whose members have different value
  types (a nested message, an int, a bool and an enum), plus one ordinary
  field that belongs to no group.
  """

  # Members of the "type" oneof group.
  opaque: Optional[Opaque] = ir_data_fields.oneof_field("type")
  integer: Optional[int] = ir_data_fields.oneof_field("type")
  boolean: Optional[bool] = ir_data_fields.oneof_field("type")
  enumeration: Optional[TestEnum] = ir_data_fields.oneof_field("type")
  # Ordinary field outside of any oneof group.
  non_union_field: int = 0
49
50
@dataclasses.dataclass
class ClassWithTwoUnions(ir_data.Message):
  """Used for testing data field helpers.

  Declares two independent oneof groups ("type_1" and "type_2"), an ordinary
  scalar field, and a repeated (list) field.
  """

  # Members of the "type_1" oneof group.
  opaque: Optional[Opaque] = ir_data_fields.oneof_field("type_1")
  integer: Optional[int] = ir_data_fields.oneof_field("type_1")
  # Members of the "type_2" oneof group.
  boolean: Optional[bool] = ir_data_fields.oneof_field("type_2")
  enumeration: Optional[TestEnum] = ir_data_fields.oneof_field("type_2")
  # Ordinary field outside of any oneof group.
  non_union_field: int = 0
  # Repeated int field declared through the `list_field` helper.
  seq_field: list[int] = ir_data_fields.list_field(int)
61
62
@dataclasses.dataclass
class NestedClass(ir_data.Message):
  """Used for testing data field helpers.

  Holds optional nested messages; IrDataFieldsTest.test_copy uses it to check
  that copying works on a message containing another message.
  """

  one_union_class: Optional[ClassWithUnion] = None
  two_union_class: Optional[ClassWithTwoUnions] = None
69
70
@dataclasses.dataclass
class ListCopyTestClass(ir_data.Message):
  """Used to test behavior of extending a sequence.

  IrDataFieldsTest.test_list_param_is_copied checks that a list passed for
  `seq_field` is stored as an `ir_data_fields.CopyValuesList`, not as the
  caller's list object.
  """

  non_union_field: int = 0
  # Repeated int field declared through the `list_field` helper.
  seq_field: list[int] = ir_data_fields.list_field(int)
77
78
@dataclasses.dataclass
class OneofFieldTest(ir_data.Message):
  """Basic test class for oneof fields.

  `int_field_1` and `int_field_2` share the "type_1" oneof group, so setting
  one clears the other (see OneOfTest.test_set_oneof_field). `normal_field`
  is an ordinary field that defaults to True.
  """

  # Members of the "type_1" oneof group.
  int_field_1: Optional[int] = ir_data_fields.oneof_field("type_1")
  int_field_2: Optional[int] = ir_data_fields.oneof_field("type_1")
  # Ordinary field outside of any oneof group.
  normal_field: bool = True
86
87
class OneOfTest(unittest.TestCase):
  """Tests for the various oneof field helpers."""

  def test_field_attribute(self):
    """Test the `oneof_field` helper."""
    test_field = ir_data_fields.oneof_field("type_1")
    self.assertIsNotNone(test_field)
    self.assertTrue(test_field.init)
    # The field's default is an `ir_data_fields.OneOfField` instance, and the
    # oneof group name is recorded under the "oneof" metadata key.
    self.assertIsInstance(test_field.default, ir_data_fields.OneOfField)
    self.assertEqual(test_field.metadata.get("oneof"), "type_1")

  def test_init_default(self):
    """Test creating an instance with default fields."""
    one_of_field_test = OneofFieldTest()
    self.assertIsNone(one_of_field_test.int_field_1)
    self.assertIsNone(one_of_field_test.int_field_2)
    self.assertTrue(one_of_field_test.normal_field)

  def test_init(self):
    """Test creating an instance with non-default fields."""
    one_of_field_test = OneofFieldTest(int_field_1=10, normal_field=False)
    self.assertEqual(one_of_field_test.int_field_1, 10)
    self.assertIsNone(one_of_field_test.int_field_2)
    self.assertFalse(one_of_field_test.normal_field)

  def test_set_oneof_field(self):
    """Tests setting oneof fields causes others in the group to be unset."""
    one_of_field_test = OneofFieldTest()
    one_of_field_test.int_field_1 = 10
    self.assertEqual(one_of_field_test.int_field_1, 10)
    self.assertIsNone(one_of_field_test.int_field_2)
    one_of_field_test.int_field_2 = 20
    self.assertIsNone(one_of_field_test.int_field_1)
    self.assertEqual(one_of_field_test.int_field_2, 20)

    # Switch between the members again to make sure the group can be
    # re-selected repeatedly.
    one_of_field_test.int_field_1 = 10
    self.assertEqual(one_of_field_test.int_field_1, 10)
    self.assertIsNone(one_of_field_test.int_field_2)
    one_of_field_test.int_field_2 = 20
    self.assertIsNone(one_of_field_test.int_field_1)
    self.assertEqual(one_of_field_test.int_field_2, 20)

    # Now create a new instance and make sure changes to it are not reflected
    # on the original object (i.e. oneof state is per-instance, not shared).
    one_of_field_test_2 = OneofFieldTest()
    one_of_field_test_2.int_field_1 = 1000
    self.assertEqual(one_of_field_test_2.int_field_1, 1000)
    self.assertIsNone(one_of_field_test_2.int_field_2)
    self.assertIsNone(one_of_field_test.int_field_1)
    self.assertEqual(one_of_field_test.int_field_2, 20)

  def test_set_to_none(self):
    """Tests explicitly setting a oneof field to None."""
    one_of_field_test = OneofFieldTest(int_field_1=10, normal_field=False)
    self.assertEqual(one_of_field_test.int_field_1, 10)
    self.assertIsNone(one_of_field_test.int_field_2)
    self.assertFalse(one_of_field_test.normal_field)

    # Clear the set fields
    one_of_field_test.int_field_1 = None
    self.assertIsNone(one_of_field_test.int_field_1)
    self.assertIsNone(one_of_field_test.int_field_2)
    self.assertFalse(one_of_field_test.normal_field)

    # Set another field
    one_of_field_test.int_field_2 = 200
    self.assertIsNone(one_of_field_test.int_field_1)
    self.assertEqual(one_of_field_test.int_field_2, 200)
    self.assertFalse(one_of_field_test.normal_field)

    # Clearing a member that is already unset must not disturb the member
    # that is currently set.
    one_of_field_test.int_field_1 = None
    self.assertIsNone(one_of_field_test.int_field_1)
    self.assertEqual(one_of_field_test.int_field_2, 200)
    self.assertFalse(one_of_field_test.normal_field)

  def test_oneof_specs(self):
    """Tests the `oneof_field_specs` filter."""
    expected = {
        "int_field_1": ir_data_fields.make_field_spec(
            "int_field_1", int, ir_data_fields.FieldContainer.OPTIONAL, "type_1"
        ),
        "int_field_2": ir_data_fields.make_field_spec(
            "int_field_2", int, ir_data_fields.FieldContainer.OPTIONAL, "type_1"
        ),
    }
    actual = ir_data_fields.IrDataclassSpecs.get_specs(
        OneofFieldTest
    ).oneof_field_specs
    self.assertDictEqual(actual, expected)

  def test_oneof_mappings(self):
    """Tests the `oneof_mappings` function."""
    # Each entry pairs a field name with its oneof group name.
    expected = (("int_field_1", "type_1"), ("int_field_2", "type_1"))
    actual = ir_data_fields.IrDataclassSpecs.get_specs(
        OneofFieldTest
    ).oneof_mappings
    self.assertTupleEqual(actual, expected)
187
188
class IrDataFieldsTest(unittest.TestCase):
  """Tests misc methods in ir_data_fields."""

  def test_copy(self):
    """Tests copying a data class works as expected."""
    union = ClassWithTwoUnions(
        opaque=Opaque(), boolean=True, non_union_field=10, seq_field=[1, 2, 3]
    )
    nested_class = NestedClass(two_union_class=union)
    nested_class_copy = ir_data_fields.copy(nested_class)
    self.assertIsNotNone(nested_class_copy)
    self.assertIsNot(nested_class, nested_class_copy)
    self.assertEqual(nested_class_copy, nested_class)

    # Copying `None` yields `None` rather than raising.
    self.assertIsNone(ir_data_fields.copy(None))

  def test_copy_values_list(self):
    """Tests that CopyValuesList copies values."""
    data_list = ir_data_fields.CopyValuesList(ListCopyTestClass)
    self.assertEqual(len(data_list), 0)

    list_test = ListCopyTestClass(non_union_field=2, seq_field=[5, 6, 7])
    list_tests = [ir_data_fields.copy(list_test) for _ in range(4)]
    data_list.extend(list_tests)
    self.assertEqual(len(data_list), 4)
    for stored, inserted in zip(data_list, list_tests):
      self.assertEqual(stored, list_test)
      # Equality alone does not show a copy was made; the list must hold
      # distinct objects from the ones passed to `extend`.
      self.assertIsNot(stored, inserted)

  def test_list_param_is_copied(self):
    """Test that lists passed to constructors are converted to CopyValuesList."""
    seq_field = [5, 6, 7]
    list_test = ListCopyTestClass(non_union_field=2, seq_field=seq_field)
    self.assertEqual(len(list_test.seq_field), len(seq_field))
    self.assertIsNot(list_test.seq_field, seq_field)
    self.assertEqual(list_test.seq_field, seq_field)
    self.assertIsInstance(list_test.seq_field, ir_data_fields.CopyValuesList)

  def test_copy_oneof(self):
    """Tests copying an IR data class that has oneof fields."""
    oneof_test = OneofFieldTest()
    oneof_test.int_field_1 = 10
    oneof_test.normal_field = False
    self.assertEqual(oneof_test.int_field_1, 10)
    self.assertFalse(oneof_test.normal_field)

    oneof_copy = ir_data_fields.copy(oneof_test)
    self.assertIsNotNone(oneof_copy)
    self.assertEqual(oneof_copy.int_field_1, 10)
    self.assertIsNone(oneof_copy.int_field_2)
    self.assertFalse(oneof_copy.normal_field)

    # Switching the oneof member on the copy must not affect the original.
    oneof_copy.int_field_2 = 100
    self.assertEqual(oneof_copy.int_field_2, 100)
    self.assertIsNone(oneof_copy.int_field_1)
    self.assertEqual(oneof_test.int_field_1, 10)
    self.assertFalse(oneof_test.normal_field)
248
249
# Register this module's ir_data.Message subclasses with ir_data_fields
# (presumably building and caching their field specs). The module object
# itself is passed, so this presumably must run after all of the test
# dataclasses above have been defined.
ir_data_fields.cache_message_specs(sys.modules[__name__], ir_data.Message)

if __name__ == "__main__":
  unittest.main()
255