xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/dtypes_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.python.framework.dtypes."""
16
17from absl.testing import parameterized
18import numpy as np
19
20# pylint: disable=g-bad-import-order
21from tensorflow.python.framework import _dtypes
22# pylint: enable=g-bad-import-order
23
24from tensorflow.core.framework import types_pb2
25from tensorflow.core.function import trace_type
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import test_util
28from tensorflow.python.platform import googletest
29
30
def _is_numeric_dtype_enum(datatype_enum):
  """Returns False only for the invalid, variant and resource enums.

  Every other `types_pb2.DataType` value -- including DT_STRING -- is treated
  as "numeric" by the tests in this file.

  Args:
    datatype_enum: A `types_pb2.DataType` enum value.

  Returns:
    True unless `datatype_enum` is DT_INVALID, DT_VARIANT(_REF) or
    DT_RESOURCE(_REF).
  """
  excluded_enums = (
      types_pb2.DT_INVALID,
      types_pb2.DT_VARIANT,
      types_pb2.DT_VARIANT_REF,
      types_pb2.DT_RESOURCE,
      types_pb2.DT_RESOURCE_REF,
  )
  if datatype_enum in excluded_enums:
    return False
  return True
37
38
class TypesTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests for `dtypes.DType` construction, conversion, predicates, limits."""

  def _assertDTypeProperty(self, property_name, true_names, false_names):
    """Checks a boolean `DType` property for a set of dtype names.

    Args:
      property_name: Name of the `DType` property to read, e.g. "is_integer".
      true_names: Dtype names (as accepted by `dtypes.as_dtype`) for which the
        property must be truthy.
      false_names: Dtype names for which the property must be falsy.
    """
    for name in true_names:
      with self.subTest(dtype=name):
        self.assertTrue(getattr(dtypes.as_dtype(name), property_name))
    for name in false_names:
      with self.subTest(dtype=name):
        self.assertFalse(getattr(dtypes.as_dtype(name), property_name))

  def testAllTypesConstructible(self):
    # Every valid DataType enum must round-trip through the DType constructor.
    for datatype_enum in types_pb2.DataType.values():
      if datatype_enum == types_pb2.DT_INVALID:
        continue
      self.assertEqual(datatype_enum,
                       dtypes.DType(datatype_enum).as_datatype_enum)

  def testAllTypesConvertibleToDType(self):
    # Every valid DataType enum must round-trip through `as_dtype`.
    for datatype_enum in types_pb2.DataType.values():
      if datatype_enum == types_pb2.DT_INVALID:
        continue
      dt = dtypes.as_dtype(datatype_enum)
      self.assertEqual(datatype_enum, dt.as_datatype_enum)

  def testAllTypesConvertibleToNumpyDtype(self):
    for datatype_enum in types_pb2.DataType.values():
      if not _is_numeric_dtype_enum(datatype_enum):
        continue
      dtype = dtypes.as_dtype(datatype_enum)
      numpy_dtype = dtype.as_numpy_dtype
      # The numpy type must be usable to allocate an array.
      _ = np.empty((1, 1, 1, 1), dtype=numpy_dtype)
      if dtype.base_dtype != dtypes.bfloat16:
        # NOTE(touts): Intentionally no way to feed a DT_BFLOAT16.
        self.assertEqual(
            dtypes.as_dtype(datatype_enum).base_dtype,
            dtypes.as_dtype(numpy_dtype))

  def testAllPybind11DTypeConvertibleToDType(self):
    # The raw pybind11 `_dtypes.DType` must compare equal to the Python DType.
    for datatype_enum in types_pb2.DataType.values():
      if datatype_enum == types_pb2.DT_INVALID:
        continue
      dtype = _dtypes.DType(datatype_enum)
      self.assertEqual(dtypes.as_dtype(datatype_enum), dtype)

  def testInvalid(self):
    with self.assertRaises(TypeError):
      dtypes.DType(types_pb2.DT_INVALID)
    with self.assertRaises(TypeError):
      dtypes.as_dtype(types_pb2.DT_INVALID)

  def testNumpyConversion(self):
    expected_conversions = (
        (dtypes.float32, np.float32),
        (dtypes.float64, np.float64),
        (dtypes.int32, np.int32),
        (dtypes.int64, np.int64),
        (dtypes.uint8, np.uint8),
        (dtypes.uint16, np.uint16),
        (dtypes.int16, np.int16),
        (dtypes.int8, np.int8),
        (dtypes.complex64, np.complex64),
        (dtypes.complex128, np.complex128),
        (dtypes.string, np.object_),
        # A fixed-width numpy unicode string dtype also maps to tf.string.
        (dtypes.string, np.array(["foo", "bar"]).dtype),
        (dtypes.bool, np.bool_),
    )
    for tf_dtype, numpy_type in expected_conversions:
      with self.subTest(numpy_type=numpy_type):
        self.assertIs(tf_dtype, dtypes.as_dtype(numpy_type))

    # Structured numpy dtypes have no TF equivalent.
    with self.assertRaises(TypeError):
      dtypes.as_dtype(np.dtype([("f1", np.uint), ("f2", np.int32)]))

    # Objects exposing a `dtype` attribute are converted through it.
    class AnObject(object):
      dtype = "f4"

    self.assertIs(dtypes.float32, dtypes.as_dtype(AnObject))

    class AnotherObject(object):
      dtype = np.dtype(np.complex64)

    self.assertIs(dtypes.complex64, dtypes.as_dtype(AnotherObject))

  def testRealDtype(self):
    for dtype in [
        dtypes.float32, dtypes.float64, dtypes.bool, dtypes.uint8, dtypes.int8,
        dtypes.int16, dtypes.int32, dtypes.int64
    ]:
      self.assertIs(dtype.real_dtype, dtype)
    self.assertIs(dtypes.complex64.real_dtype, dtypes.float32)
    self.assertIs(dtypes.complex128.real_dtype, dtypes.float64)

  def testStringConversion(self):
    # Each name must resolve to the module-level DType of the same name.
    names = (
        "float32", "float64", "int32", "uint8", "uint16", "int16", "int8",
        "string", "complex64", "complex128", "int64", "bool", "qint8",
        "quint8", "qint32", "bfloat16",
        "float32_ref", "float64_ref", "int32_ref", "uint8_ref", "int16_ref",
        "int8_ref", "string_ref", "complex64_ref", "complex128_ref",
        "int64_ref", "bool_ref", "qint8_ref", "quint8_ref", "qint32_ref",
        "bfloat16_ref",
    )
    for name in names:
      with self.subTest(name=name):
        self.assertIs(getattr(dtypes, name), dtypes.as_dtype(name))
    with self.assertRaises(TypeError):
      dtypes.as_dtype("not_a_type")

  def testDTypesHaveUniqueNames(self):
    all_dtypes = [
        dtypes.as_dtype(datatype_enum)
        for datatype_enum in types_pb2.DataType.values()
        if datatype_enum != types_pb2.DT_INVALID
    ]
    names = {dtype.name for dtype in all_dtypes}
    self.assertEqual(len(all_dtypes), len(names))

  def testIsInteger(self):
    self._assertDTypeProperty(
        "is_integer",
        true_names=("int8", "int16", "int32", "int64", "uint8", "uint16"),
        false_names=("complex64", "complex128", "float", "double", "string",
                     "bool", "bfloat16", "qint8", "qint16", "qint32",
                     "quint8", "quint16"))

  def testIsFloating(self):
    self._assertDTypeProperty(
        "is_floating",
        true_names=("float32", "float64", "bfloat16"),
        false_names=("int8", "int16", "int32", "int64", "uint8", "uint16",
                     "complex64", "complex128", "string", "bool", "qint8",
                     "qint16", "qint32", "quint8", "quint16"))

  def testIsComplex(self):
    self._assertDTypeProperty(
        "is_complex",
        true_names=("complex64", "complex128"),
        false_names=("int8", "int16", "int32", "int64", "uint8", "uint16",
                     "float32", "float64", "string", "bool", "bfloat16",
                     "qint8", "qint16", "qint32", "quint8", "quint16"))

  def testIsUnsigned(self):
    self._assertDTypeProperty(
        "is_unsigned",
        true_names=("uint8", "uint16"),
        false_names=("int8", "int16", "int32", "int64", "float32", "float64",
                     "bool", "string", "complex64", "complex128", "bfloat16",
                     "qint8", "qint16", "qint32", "quint8", "quint16"))

  def testMinMax(self):
    # (numpy type, min, max) for the integer types with well-known limits.
    known_integer_limits = (
        (np.int8, -128, 127),
        (np.int16, -32768, 32767),
        (np.int32, -2147483648, 2147483647),
        (np.int64, -9223372036854775808, 9223372036854775807),
        (np.uint8, 0, 255),
        (np.uint16, 0, 65535),
        (np.uint32, 0, 4294967295),
        (np.uint64, 0, 18446744073709551615),
    )
    # Make sure min/max evaluate for all data types that have min/max.
    for datatype_enum in types_pb2.DataType.values():
      if not _is_numeric_dtype_enum(datatype_enum):
        continue
      dtype = dtypes.as_dtype(datatype_enum)
      numpy_dtype = dtype.as_numpy_dtype

      # Ignore types for which there is no minimum/maximum (or we cannot
      # compute it, such as for the q* types).
      if dtype.is_quantized or dtype.base_dtype in (
          dtypes.bool, dtypes.string, dtypes.complex64, dtypes.complex128):
        continue

      with self.subTest(dtype=dtype.name):
        # Both limits must be computable for every remaining type.
        dtype_min = dtype.min
        dtype_max = dtype.max

        # Check the values that are known. `_ref` variants share the numpy
        # type of their base dtype, so they are checked too.
        for integer_type, integer_min, integer_max in known_integer_limits:
          if numpy_dtype == integer_type:
            self.assertEqual(dtype_min, integer_min)
            self.assertEqual(dtype_max, integer_max)
        if numpy_dtype in (np.float16, np.float32, np.float64):
          self.assertEqual(dtype_min, np.finfo(numpy_dtype).min)
          self.assertEqual(dtype_max, np.finfo(numpy_dtype).max)
        if numpy_dtype == dtypes.bfloat16.as_numpy_dtype:
          self.assertEqual(dtype_min, float.fromhex("-0x1.FEp127"))
          self.assertEqual(dtype_max, float.fromhex("0x1.FEp127"))

  def testLimitsUndefinedError(self):
    with self.assertRaises(ValueError):
      dtypes.string.limits()

  def testRepr(self):
    self.skipTest("b/142725777")
    # Imported so that `eval` below can resolve the "tf." prefix of the repr.
    import tensorflow as tf  # pylint: disable=g-import-not-at-top,unused-import
    for enum, name in dtypes._TYPE_TO_STRING.items():
      # NOTE(review): presumably skips the `_ref` enums; confirm the > 100
      # threshold against types.proto.
      if enum > 100:
        continue
      dtype = dtypes.DType(enum)
      self.assertEqual(repr(dtype), "tf." + name)
      # eval() only sees the repr produced above, never external input.
      dtype2 = eval(repr(dtype))  # pylint: disable=eval-used
      self.assertEqual(type(dtype2), dtypes.DType)
      self.assertEqual(dtype, dtype2)

  def testEqWithNonTFTypes(self):
    self.assertNotEqual(dtypes.int32, int)
    self.assertNotEqual(dtypes.float64, 2.1)

  def testPythonLongConversion(self):
    self.assertIs(dtypes.int64, dtypes.as_dtype(np.array(2**32).dtype))

  def testPythonTypesConversion(self):
    self.assertIs(dtypes.float32, dtypes.as_dtype(float))
    self.assertIs(dtypes.bool, dtypes.as_dtype(bool))

  def testReduce(self):
    # Pickling support: __reduce__ must rebuild an equal dtype from its name.
    for enum in dtypes._TYPE_TO_STRING:
      dtype = dtypes.DType(enum)
      ctor, args = dtype.__reduce__()
      self.assertEqual(ctor, dtypes.as_dtype)
      self.assertEqual(args, (dtype.name,))
      reconstructed = ctor(*args)
      self.assertEqual(reconstructed, dtype)

  def testAsDtypeInvalidArgument(self):
    with self.assertRaises(TypeError):
      dtypes.as_dtype((dtypes.int32, dtypes.float32))

  def testAsDtypeReturnsInternedVersion(self):
    dt = dtypes.DType(types_pb2.DT_VARIANT)
    self.assertIs(dtypes.as_dtype(dt), dtypes.variant)

  def testDTypeSubtypes(self):
    self.assertTrue(dtypes.string.is_subtype_of(dtypes.string))
    self.assertFalse(dtypes.string.is_subtype_of(dtypes.uint32))
    self.assertTrue(dtypes.uint64.is_subtype_of(dtypes.uint64))

  def testDTypeSupertypes(self):
    self.assertEqual(dtypes.string,
                     dtypes.string.most_specific_common_supertype([]))
    self.assertEqual(
        dtypes.string,
        dtypes.string.most_specific_common_supertype([dtypes.string]))
    self.assertIsNone(
        dtypes.string.most_specific_common_supertype([dtypes.uint32]))

  @parameterized.parameters(*dtypes.TF_VALUE_DTYPES)
  def testDTypeSerialization(self, dtype):
    # Every value dtype must survive a trace_type serialize/deserialize trip.
    self.assertEqual(trace_type.deserialize(trace_type.serialize(dtype)), dtype)
362
363
if __name__ == "__main__":
  # Run this module's test cases with TensorFlow's googletest runner.
  googletest.main()
366