xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/keras_parameterized.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for unit-testing Keras."""
16
17import collections
18import functools
19import itertools
20import unittest
21
22from absl.testing import parameterized
23
24from tensorflow.python import keras
25from tensorflow.python import tf2
26from tensorflow.python.eager import context
27from tensorflow.python.framework import ops
28from tensorflow.python.keras import testing_utils
29from tensorflow.python.platform import test
30from tensorflow.python.util import nest
31
32try:
33  import h5py  # pylint:disable=g-import-not-at-top
34except ImportError:
35  h5py = None
36
37
class TestCase(test.TestCase, parameterized.TestCase):
  """Test base class combining `test.TestCase` and `parameterized.TestCase`.

  After every test the Keras backend session is cleared before the parent
  classes' own teardown logic runs.
  """

  def tearDown(self):
    """Clears the Keras backend session, then delegates to the parents."""
    keras.backend.clear_session()
    super().tearDown()
43
44
def run_with_all_saved_model_formats(
    test_or_class=None,
    exclude_formats=None):
  """Execute the decorated test with all Keras saved model formats.

  This decorator is intended to be applied either to individual test methods in
  a `keras_parameterized.TestCase` class, or directly to a test class that
  extends it. Doing so will cause the contents of the individual test
  method (or all test methods in the class) to be executed multiple times - once
  for each Keras saved model format.

  The Keras saved model formats are:
  1. HDF5: 'h5'
  2. SavedModel: 'tf'
  3. SavedModel with `save_traces=False`: 'tf_no_traces'

  The 'h5' format is skipped automatically when `h5py` cannot be imported.

  Note: if stacking this decorator with absl.testing's parameterized decorators,
  those should be at the bottom of the stack.

  Every generated test case runs the test body inside
  `testing_utils.saved_model_format_scope(...)` for the format under test. The
  examples below read the active format with `testing_utils.get_save_format()`
  instead of hard-coding one.

  For example, consider the following unittest:

  ```python
  class MyTests(keras_parameterized.TestCase):

    @keras_parameterized.run_with_all_saved_model_formats
    def test_foo(self):
      save_format = testing_utils.get_save_format()
      saved_model_dir = '/tmp/saved_model/'
      model = keras.models.Sequential()
      model.add(keras.layers.Dense(2, input_shape=(3,)))
      model.add(keras.layers.Dense(3))
      model.compile(loss='mse', optimizer='sgd', metrics=['acc'])

      keras.models.save_model(model, saved_model_dir, save_format=save_format)
      model = keras.models.load_model(saved_model_dir)

  if __name__ == "__main__":
    tf.test.main()
  ```

  This test is run once for each of the 'h5', 'tf' and 'tf_no_traces' formats.

  We can also annotate the whole class if we want this to apply to all tests in
  the class:
  ```python
  @keras_parameterized.run_with_all_saved_model_formats
  class MyTests(keras_parameterized.TestCase):

    def test_foo(self):
      save_format = testing_utils.get_save_format()
      saved_model_dir = '/tmp/saved_model/'
      model = keras.models.Sequential()
      model.add(keras.layers.Dense(2, input_shape=(3,)))
      model.add(keras.layers.Dense(3))
      model.compile(loss='mse', optimizer='sgd', metrics=['acc'])

      keras.models.save_model(model, saved_model_dir, save_format=save_format)
      model = tf.keras.models.load_model(saved_model_dir)

  if __name__ == "__main__":
    tf.test.main()
  ```

  Args:
    test_or_class: test method or class to be annotated. If None,
      this method returns a decorator that can be applied to a test method or
      test class. If it is not None this returns the decorator applied to the
      test or class.
    exclude_formats: A collection of Keras saved model formats to not run.
      (May also be a single format not wrapped in a collection).
      Defaults to None. The argument is flattened with `nest.flatten` into a
      new list and is never modified.

  Returns:
    Returns a decorator that will run the decorated test method multiple times:
    once for each desired Keras saved model format.

  Raises:
    ValueError: Raised by a generated test case (not by this function) if it
      is handed a saved model format other than the three listed above.
  """
  # Normalize into a private flat list. This accepts None, a bare format string
  # or any collection, and guarantees the caller's object is left untouched.
  if exclude_formats is None:
    excluded = []
  else:
    excluded = list(nest.flatten(exclude_formats))

  # Exclude h5 save format if h5py isn't available.
  if h5py is None:
    excluded.append('h5')

  saved_model_formats = ['h5', 'tf', 'tf_no_traces']
  # Each entry is (test-name suffix, value passed as `saved_format`).
  params = [('_%s' % saved_format, saved_format)
            for saved_format in saved_model_formats
            if saved_format not in excluded]

  def single_method_decorator(f):
    """Decorator that constructs the test cases."""
    # Use named_parameters so it can be individually run from the command line
    @parameterized.named_parameters(*params)
    @functools.wraps(f)
    def decorated(self, saved_format, *args, **kwargs):
      """A run of a single test case w/ the specified saved model format."""
      if saved_format == 'h5':
        _test_h5_saved_model_format(f, self, *args, **kwargs)
      elif saved_format == 'tf':
        _test_tf_saved_model_format(f, self, *args, **kwargs)
      elif saved_format == 'tf_no_traces':
        _test_tf_saved_model_format_no_traces(f, self, *args, **kwargs)
      else:
        raise ValueError('Unknown saved model format: %s' % (saved_format,))
    return decorated

  return _test_or_class_decorator(test_or_class, single_method_decorator)
155
156
def _test_h5_saved_model_format(f, test_or_class, *args, **kwargs):
  """Calls `f(test_or_class, *args, **kwargs)` within the 'h5' format scope.

  The return value of `f` is discarded.
  """
  h5_scope = testing_utils.saved_model_format_scope('h5')
  with h5_scope:
    f(test_or_class, *args, **kwargs)
160
161
def _test_tf_saved_model_format(f, test_or_class, *args, **kwargs):
  """Calls `f(test_or_class, *args, **kwargs)` within the 'tf' format scope.

  The return value of `f` is discarded.
  """
  tf_scope = testing_utils.saved_model_format_scope('tf')
  with tf_scope:
    f(test_or_class, *args, **kwargs)
165
166
def _test_tf_saved_model_format_no_traces(f, test_or_class, *args, **kwargs):
  """Calls `f(test_or_class, *args, **kwargs)` in the 'tf' scope w/o traces.

  The scope is `saved_model_format_scope('tf', save_traces=False)`. The return
  value of `f` is discarded.
  """
  no_traces_scope = testing_utils.saved_model_format_scope(
      'tf', save_traces=False)
  with no_traces_scope:
    f(test_or_class, *args, **kwargs)
170
171
def run_with_all_weight_formats(test_or_class=None, exclude_formats=None):
  """Runs all tests with the supported formats for saving weights.

  This is `run_with_all_saved_model_formats` with the 'tf_no_traces' format
  always excluded, because that format only applies to saving whole models.

  Args:
    test_or_class: test method or class to be annotated. If None, a decorator
      is returned that can be applied to a test method or test class.
    exclude_formats: A collection of formats to not run, or a single format
      not wrapped in a collection. Defaults to None. The argument is never
      modified.

  Returns:
    Whatever `run_with_all_saved_model_formats` returns for the given
    `test_or_class` and the extended exclusion list.
  """
  # Copy into a fresh flat list instead of appending to the caller's object:
  # in-place appends would leak 'tf_no_traces' into a list the caller reuses,
  # and would fail outright for a bare string or a tuple.
  excluded = list(nest.flatten(exclude_formats)) if exclude_formats else []
  excluded.append('tf_no_traces')  # Only applies to saving models
  return run_with_all_saved_model_formats(test_or_class, excluded)
177
178
179# TODO(kaftan): Possibly enable 'subclass_custom_build' when tests begin to pass
180# it. Or perhaps make 'subclass' always use a custom build method.
def run_with_all_model_types(
    test_or_class=None,
    exclude_models=None):
  """Run the decorated test once per Keras model type.

  Apply this decorator to a single test method of a
  `keras_parameterized.TestCase` class, or to a whole test class extending it.
  The body of the test method (or of every test method in the class) is then
  executed several times - once for each Keras model type.

  The Keras model types are: ['functional', 'subclass', 'sequential']

  Note: when combined with absl.testing's parameterized decorators, those
  should sit at the bottom of the decorator stack.

  Each generated test case runs the test body inside
  `testing_utils.model_type_scope(...)` for the model type under test. Various
  methods in `testing_utils` that build models will produce a model of the
  currently active type, which lets unittests confirm the equivalence between
  different Keras models.

  Example on a single method:

  ```python
  class MyTests(keras_parameterized.TestCase):

    @keras_parameterized.run_with_all_model_types(
      exclude_models = ['sequential'])
    def test_foo(self):
      model = testing_utils.get_small_mlp(1, 4, input_dim=3)
      optimizer = RMSPropOptimizer(learning_rate=0.001)
      loss = 'mse'
      metrics = ['mae']
      model.compile(optimizer, loss, metrics=metrics)

      inputs = np.zeros((10, 3))
      targets = np.zeros((10, 4))
      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
      dataset = dataset.repeat(100)
      dataset = dataset.batch(10)

      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)

  if __name__ == "__main__":
    tf.test.main()
  ```

  This test builds a small mlp both as a functional model and as a subclass
  model.

  The whole class can be annotated instead, which applies the decorator to
  every test in the class:

  ```python
  @keras_parameterized.run_with_all_model_types(exclude_models = ['sequential'])
  class MyTests(keras_parameterized.TestCase):

    def test_foo(self):
      model = testing_utils.get_small_mlp(1, 4, input_dim=3)
      optimizer = RMSPropOptimizer(learning_rate=0.001)
      loss = 'mse'
      metrics = ['mae']
      model.compile(optimizer, loss, metrics=metrics)

      inputs = np.zeros((10, 3))
      targets = np.zeros((10, 4))
      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
      dataset = dataset.repeat(100)
      dataset = dataset.batch(10)

      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)

  if __name__ == "__main__":
    tf.test.main()
  ```

  Args:
    test_or_class: test method or class to be annotated. If None,
      this method returns a decorator that can be applied to a test method or
      test class. If it is not None this returns the decorator applied to the
      test or class.
    exclude_models: A collection of Keras model types to not run.
      (May also be a single model type not wrapped in a collection).
      Defaults to None.

  Returns:
    Returns a decorator that will run the decorated test method multiple times:
    once for each desired Keras model type.
  """
  skipped_types = nest.flatten(exclude_models)
  # Each entry is (test-name suffix, value passed as `model_type`).
  params = []
  for model in ['functional', 'subclass', 'sequential']:
    if model in skipped_types:
      continue
    params.append(('_%s' % model, model))

  def single_method_decorator(f):
    """Builds the parameterized test cases for one test method."""
    # named_parameters lets each case be run individually from the command line
    @parameterized.named_parameters(*params)
    @functools.wraps(f)
    def decorated(self, model_type, *args, **kwargs):
      """Runs the wrapped test once under the given model type."""
      runners = (
          ('functional', _test_functional_model_type),
          ('subclass', _test_subclass_model_type),
          ('sequential', _test_sequential_model_type),
      )
      for known_type, runner in runners:
        if model_type == known_type:
          runner(f, self, *args, **kwargs)
          return
      raise ValueError('Unknown model type: %s' % (model_type,))
    return decorated

  return _test_or_class_decorator(test_or_class, single_method_decorator)
295
296
def _test_functional_model_type(f, test_or_class, *args, **kwargs):
  """Calls `f(test_or_class, *args, **kwargs)` in the 'functional' type scope.

  The return value of `f` is discarded.
  """
  functional_scope = testing_utils.model_type_scope('functional')
  with functional_scope:
    f(test_or_class, *args, **kwargs)
300
301
def _test_subclass_model_type(f, test_or_class, *args, **kwargs):
  """Calls `f(test_or_class, *args, **kwargs)` in the 'subclass' type scope.

  The return value of `f` is discarded.
  """
  subclass_scope = testing_utils.model_type_scope('subclass')
  with subclass_scope:
    f(test_or_class, *args, **kwargs)
305
306
def _test_sequential_model_type(f, test_or_class, *args, **kwargs):
  """Calls `f(test_or_class, *args, **kwargs)` in the 'sequential' type scope.

  The return value of `f` is discarded.
  """
  sequential_scope = testing_utils.model_type_scope('sequential')
  with sequential_scope:
    f(test_or_class, *args, **kwargs)
310
311
def run_all_keras_modes(test_or_class=None,
                        config=None,
                        always_skip_v1=False,
                        always_skip_eager=False,
                        **kwargs):
  """Execute the decorated test with all keras execution modes.

  This decorator is intended to be applied either to individual test methods in
  a `keras_parameterized.TestCase` class, or directly to a test class that
  extends it. Doing so will cause the contents of the individual test
  method (or all test methods in the class) to be executed multiple times -
  once executing in legacy graph mode ('v1_session'), once running eagerly and
  with `should_run_eagerly` returning True ('v2_eager'), and once running
  eagerly with `should_run_eagerly` returning False ('v2_function').

  If Tensorflow v2 behavior is enabled, legacy graph mode will be skipped, and
  the test will only run twice.

  Note: if stacking this decorator with absl.testing's parameterized decorators,
  those should be at the bottom of the stack.

  For example, consider the following unittest:

  ```python
  class MyTests(keras_parameterized.TestCase):

    @keras_parameterized.run_all_keras_modes
    def test_foo(self):
      model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
      optimizer = RMSPropOptimizer(learning_rate=0.001)
      loss = 'mse'
      metrics = ['mae']
      model.compile(
          optimizer, loss, metrics=metrics,
          run_eagerly=testing_utils.should_run_eagerly())

      inputs = np.zeros((10, 3))
      targets = np.zeros((10, 4))
      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
      dataset = dataset.repeat(100)
      dataset = dataset.batch(10)

      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)

  if __name__ == "__main__":
    tf.test.main()
  ```

  This test will try compiling & fitting the small functional mlp using all
  three Keras execution modes.

  Args:
    test_or_class: test method or class to be annotated. If None,
      this method returns a decorator that can be applied to a test method or
      test class. If it is not None this returns the decorator applied to the
      test or class.
    config: An optional config_pb2.ConfigProto to use to configure the
      session when executing graphs. Only used by the 'v1_session' mode.
    always_skip_v1: If True, does not try running the legacy graph mode even
      when Tensorflow v2 behavior is not enabled.
    always_skip_eager: If True, does not generate the 'v2_eager' mode. The
      'v2_function' mode is still generated.
    **kwargs: Additional kwargs for configuring tests for
     in-progress Keras behaviors/ refactorings that we haven't fully
     rolled out yet. None are currently recognized.

  Returns:
    Returns a decorator that will run the decorated test method multiple times.

  Raises:
    ValueError: If any unrecognized keyword argument is passed. A generated
      test case also raises ValueError if it is handed an unknown run mode.
  """
  if kwargs:
    raise ValueError('Unrecognized keyword args: {}'.format(kwargs))

  # Each entry is (test-name suffix, value passed as `run_mode`).
  params = [('_v2_function', 'v2_function')]
  if not always_skip_eager:
    params.append(('_v2_eager', 'v2_eager'))
  if not (always_skip_v1 or tf2.enabled()):
    params.append(('_v1_session', 'v1_session'))

  def single_method_decorator(f):
    """Decorator that constructs the test cases."""

    # Use named_parameters so it can be individually run from the command line
    @parameterized.named_parameters(*params)
    @functools.wraps(f)
    def decorated(self, run_mode, *args, **kwargs):
      """A run of a single test case w/ specified run mode."""
      if run_mode == 'v1_session':
        _v1_session_test(f, self, config, *args, **kwargs)
      elif run_mode == 'v2_eager':
        _v2_eager_test(f, self, *args, **kwargs)
      elif run_mode == 'v2_function':
        _v2_function_test(f, self, *args, **kwargs)
      else:
        # Must be raised, not returned: a returned exception object is
        # ignored by the test runner and the test would silently pass.
        raise ValueError('Unknown run mode %s' % run_mode)

    return decorated

  return _test_or_class_decorator(test_or_class, single_method_decorator)
414
415
def _v1_session_test(f, test_or_class, config, *args, **kwargs):
  """Calls `f(test_or_class, *args, **kwargs)` in legacy graph/session mode.

  The call happens with the default graph made current, inside
  `testing_utils.run_eagerly_scope(False)`, and within
  `test_or_class.test_session(config=config)`. The return value of `f` is
  discarded.
  """
  # The context managers are entered left to right, exactly like nested `with`.
  with ops.get_default_graph().as_default(), \
       testing_utils.run_eagerly_scope(False), \
       test_or_class.test_session(config=config):
    f(test_or_class, *args, **kwargs)
421
422
def _v2_eager_test(f, test_or_class, *args, **kwargs):
  """Calls `f(test_or_class, *args, **kwargs)` eagerly with run_eagerly on.

  The call happens inside `context.eager_mode()` and
  `testing_utils.run_eagerly_scope(True)`. The return value of `f` is
  discarded.
  """
  with context.eager_mode(), testing_utils.run_eagerly_scope(True):
    f(test_or_class, *args, **kwargs)
427
428
def _v2_function_test(f, test_or_class, *args, **kwargs):
  """Calls `f(test_or_class, *args, **kwargs)` eagerly with run_eagerly off.

  The call happens inside `context.eager_mode()` and
  `testing_utils.run_eagerly_scope(False)`. The return value of `f` is
  discarded.
  """
  with context.eager_mode(), testing_utils.run_eagerly_scope(False):
    f(test_or_class, *args, **kwargs)
433
434
435def _test_or_class_decorator(test_or_class, single_method_decorator):
436  """Decorate a test or class with a decorator intended for one method.
437
438  If the test_or_class is a class:
439    This will apply the decorator to all test methods in the class.
440
441  If the test_or_class is an iterable of already-parameterized test cases:
442    This will apply the decorator to all the cases, and then flatten the
443    resulting cross-product of test cases. This allows stacking the Keras
444    parameterized decorators w/ each other, and to apply them to test methods
445    that have already been marked with an absl parameterized decorator.
446
447  Otherwise, treat the obj as a single method and apply the decorator directly.
448
449  Args:
450    test_or_class: A test method (that may have already been decorated with a
451      parameterized decorator, or a test class that extends
452      keras_parameterized.TestCase
453    single_method_decorator:
454      A parameterized decorator intended for a single test method.
455  Returns:
456    The decorated result.
457  """
458  def _decorate_test_or_class(obj):
459    if isinstance(obj, collections.abc.Iterable):
460      return itertools.chain.from_iterable(
461          single_method_decorator(method) for method in obj)
462    if isinstance(obj, type):
463      cls = obj
464      for name, value in cls.__dict__.copy().items():
465        if callable(value) and name.startswith(
466            unittest.TestLoader.testMethodPrefix):
467          setattr(cls, name, single_method_decorator(value))
468
469      cls = type(cls).__new__(type(cls), cls.__name__, cls.__bases__,
470                              cls.__dict__.copy())
471      return cls
472
473    return single_method_decorator(obj)
474
475  if test_or_class is not None:
476    return _decorate_test_or_class(test_or_class)
477
478  return _decorate_test_or_class
479