1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities for unit-testing Keras.""" 16 17import collections 18import functools 19import itertools 20import unittest 21 22from absl.testing import parameterized 23 24from tensorflow.python import keras 25from tensorflow.python import tf2 26from tensorflow.python.eager import context 27from tensorflow.python.framework import ops 28from tensorflow.python.keras import testing_utils 29from tensorflow.python.platform import test 30from tensorflow.python.util import nest 31 32try: 33 import h5py # pylint:disable=g-import-not-at-top 34except ImportError: 35 h5py = None 36 37 38class TestCase(test.TestCase, parameterized.TestCase): 39 40 def tearDown(self): 41 keras.backend.clear_session() 42 super(TestCase, self).tearDown() 43 44 45def run_with_all_saved_model_formats( 46 test_or_class=None, 47 exclude_formats=None): 48 """Execute the decorated test with all Keras saved model formats). 49 50 This decorator is intended to be applied either to individual test methods in 51 a `keras_parameterized.TestCase` class, or directly to a test class that 52 extends it. Doing so will cause the contents of the individual test 53 method (or all test methods in the class) to be executed multiple times - once 54 for each Keras saved model format. 55 56 The Keras saved model formats include: 57 1. HDF5: 'h5' 58 2. SavedModel: 'tf' 59 60 Note: if stacking this decorator with absl.testing's parameterized decorators, 61 those should be at the bottom of the stack. 62 63 Various methods in `testing_utils` to get file path for saved models will 64 auto-generate a string of the two saved model formats. This allows unittests 65 to confirm the equivalence between the two Keras saved model formats. 66 67 For example, consider the following unittest: 68 69 ```python 70 class MyTests(testing_utils.KerasTestCase): 71 72 @testing_utils.run_with_all_saved_model_formats 73 def test_foo(self): 74 save_format = testing_utils.get_save_format() 75 saved_model_dir = '/tmp/saved_model/' 76 model = keras.models.Sequential() 77 model.add(keras.layers.Dense(2, input_shape=(3,))) 78 model.add(keras.layers.Dense(3)) 79 model.compile(loss='mse', optimizer='sgd', metrics=['acc']) 80 81 keras.models.save_model(model, saved_model_dir, save_format=save_format) 82 model = keras.models.load_model(saved_model_dir) 83 84 if __name__ == "__main__": 85 tf.test.main() 86 ``` 87 88 This test tries to save the model into the formats of 'hdf5', 'h5', 'keras', 89 'tensorflow', and 'tf'. 90 91 We can also annotate the whole class if we want this to apply to all tests in 92 the class: 93 ```python 94 @testing_utils.run_with_all_saved_model_formats 95 class MyTests(testing_utils.KerasTestCase): 96 97 def test_foo(self): 98 save_format = testing_utils.get_save_format() 99 saved_model_dir = '/tmp/saved_model/' 100 model = keras.models.Sequential() 101 model.add(keras.layers.Dense(2, input_shape=(3,))) 102 model.add(keras.layers.Dense(3)) 103 model.compile(loss='mse', optimizer='sgd', metrics=['acc']) 104 105 keras.models.save_model(model, saved_model_dir, save_format=save_format) 106 model = tf.keras.models.load_model(saved_model_dir) 107 108 if __name__ == "__main__": 109 tf.test.main() 110 ``` 111 112 Args: 113 test_or_class: test method or class to be annotated. If None, 114 this method returns a decorator that can be applied to a test method or 115 test class. If it is not None this returns the decorator applied to the 116 test or class. 117 exclude_formats: A collection of Keras saved model formats to not run. 118 (May also be a single format not wrapped in a collection). 119 Defaults to None. 120 121 Returns: 122 Returns a decorator that will run the decorated test method multiple times: 123 once for each desired Keras saved model format. 124 125 Raises: 126 ImportError: If abseil parameterized is not installed or not included as 127 a target dependency. 128 """ 129 # Exclude h5 save format if H5py isn't available. 130 if h5py is None: 131 exclude_formats.append(['h5']) 132 saved_model_formats = ['h5', 'tf', 'tf_no_traces'] 133 params = [('_%s' % saved_format, saved_format) 134 for saved_format in saved_model_formats 135 if saved_format not in nest.flatten(exclude_formats)] 136 137 def single_method_decorator(f): 138 """Decorator that constructs the test cases.""" 139 # Use named_parameters so it can be individually run from the command line 140 @parameterized.named_parameters(*params) 141 @functools.wraps(f) 142 def decorated(self, saved_format, *args, **kwargs): 143 """A run of a single test case w/ the specified model type.""" 144 if saved_format == 'h5': 145 _test_h5_saved_model_format(f, self, *args, **kwargs) 146 elif saved_format == 'tf': 147 _test_tf_saved_model_format(f, self, *args, **kwargs) 148 elif saved_format == 'tf_no_traces': 149 _test_tf_saved_model_format_no_traces(f, self, *args, **kwargs) 150 else: 151 raise ValueError('Unknown model type: %s' % (saved_format,)) 152 return decorated 153 154 return _test_or_class_decorator(test_or_class, single_method_decorator) 155 156 157def _test_h5_saved_model_format(f, test_or_class, *args, **kwargs): 158 with testing_utils.saved_model_format_scope('h5'): 159 f(test_or_class, *args, **kwargs) 160 161 162def _test_tf_saved_model_format(f, test_or_class, *args, **kwargs): 163 with testing_utils.saved_model_format_scope('tf'): 164 f(test_or_class, *args, **kwargs) 165 166 167def _test_tf_saved_model_format_no_traces(f, test_or_class, *args, **kwargs): 168 with testing_utils.saved_model_format_scope('tf', save_traces=False): 169 f(test_or_class, *args, **kwargs) 170 171 172def run_with_all_weight_formats(test_or_class=None, exclude_formats=None): 173 """Runs all tests with the supported formats for saving weights.""" 174 exclude_formats = exclude_formats or [] 175 exclude_formats.append('tf_no_traces') # Only applies to saving models 176 return run_with_all_saved_model_formats(test_or_class, exclude_formats) 177 178 179# TODO(kaftan): Possibly enable 'subclass_custom_build' when tests begin to pass 180# it. Or perhaps make 'subclass' always use a custom build method. 181def run_with_all_model_types( 182 test_or_class=None, 183 exclude_models=None): 184 """Execute the decorated test with all Keras model types. 185 186 This decorator is intended to be applied either to individual test methods in 187 a `keras_parameterized.TestCase` class, or directly to a test class that 188 extends it. Doing so will cause the contents of the individual test 189 method (or all test methods in the class) to be executed multiple times - once 190 for each Keras model type. 191 192 The Keras model types are: ['functional', 'subclass', 'sequential'] 193 194 Note: if stacking this decorator with absl.testing's parameterized decorators, 195 those should be at the bottom of the stack. 196 197 Various methods in `testing_utils` to get models will auto-generate a model 198 of the currently active Keras model type. This allows unittests to confirm 199 the equivalence between different Keras models. 200 201 For example, consider the following unittest: 202 203 ```python 204 class MyTests(testing_utils.KerasTestCase): 205 206 @testing_utils.run_with_all_model_types( 207 exclude_models = ['sequential']) 208 def test_foo(self): 209 model = testing_utils.get_small_mlp(1, 4, input_dim=3) 210 optimizer = RMSPropOptimizer(learning_rate=0.001) 211 loss = 'mse' 212 metrics = ['mae'] 213 model.compile(optimizer, loss, metrics=metrics) 214 215 inputs = np.zeros((10, 3)) 216 targets = np.zeros((10, 4)) 217 dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) 218 dataset = dataset.repeat(100) 219 dataset = dataset.batch(10) 220 221 model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) 222 223 if __name__ == "__main__": 224 tf.test.main() 225 ``` 226 227 This test tries building a small mlp as both a functional model and as a 228 subclass model. 229 230 We can also annotate the whole class if we want this to apply to all tests in 231 the class: 232 ```python 233 @testing_utils.run_with_all_model_types(exclude_models = ['sequential']) 234 class MyTests(testing_utils.KerasTestCase): 235 236 def test_foo(self): 237 model = testing_utils.get_small_mlp(1, 4, input_dim=3) 238 optimizer = RMSPropOptimizer(learning_rate=0.001) 239 loss = 'mse' 240 metrics = ['mae'] 241 model.compile(optimizer, loss, metrics=metrics) 242 243 inputs = np.zeros((10, 3)) 244 targets = np.zeros((10, 4)) 245 dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) 246 dataset = dataset.repeat(100) 247 dataset = dataset.batch(10) 248 249 model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) 250 251 if __name__ == "__main__": 252 tf.test.main() 253 ``` 254 255 256 Args: 257 test_or_class: test method or class to be annotated. If None, 258 this method returns a decorator that can be applied to a test method or 259 test class. If it is not None this returns the decorator applied to the 260 test or class. 261 exclude_models: A collection of Keras model types to not run. 262 (May also be a single model type not wrapped in a collection). 263 Defaults to None. 264 265 Returns: 266 Returns a decorator that will run the decorated test method multiple times: 267 once for each desired Keras model type. 268 269 Raises: 270 ImportError: If abseil parameterized is not installed or not included as 271 a target dependency. 272 """ 273 model_types = ['functional', 'subclass', 'sequential'] 274 params = [('_%s' % model, model) for model in model_types 275 if model not in nest.flatten(exclude_models)] 276 277 def single_method_decorator(f): 278 """Decorator that constructs the test cases.""" 279 # Use named_parameters so it can be individually run from the command line 280 @parameterized.named_parameters(*params) 281 @functools.wraps(f) 282 def decorated(self, model_type, *args, **kwargs): 283 """A run of a single test case w/ the specified model type.""" 284 if model_type == 'functional': 285 _test_functional_model_type(f, self, *args, **kwargs) 286 elif model_type == 'subclass': 287 _test_subclass_model_type(f, self, *args, **kwargs) 288 elif model_type == 'sequential': 289 _test_sequential_model_type(f, self, *args, **kwargs) 290 else: 291 raise ValueError('Unknown model type: %s' % (model_type,)) 292 return decorated 293 294 return _test_or_class_decorator(test_or_class, single_method_decorator) 295 296 297def _test_functional_model_type(f, test_or_class, *args, **kwargs): 298 with testing_utils.model_type_scope('functional'): 299 f(test_or_class, *args, **kwargs) 300 301 302def _test_subclass_model_type(f, test_or_class, *args, **kwargs): 303 with testing_utils.model_type_scope('subclass'): 304 f(test_or_class, *args, **kwargs) 305 306 307def _test_sequential_model_type(f, test_or_class, *args, **kwargs): 308 with testing_utils.model_type_scope('sequential'): 309 f(test_or_class, *args, **kwargs) 310 311 312def run_all_keras_modes(test_or_class=None, 313 config=None, 314 always_skip_v1=False, 315 always_skip_eager=False, 316 **kwargs): 317 """Execute the decorated test with all keras execution modes. 318 319 This decorator is intended to be applied either to individual test methods in 320 a `keras_parameterized.TestCase` class, or directly to a test class that 321 extends it. Doing so will cause the contents of the individual test 322 method (or all test methods in the class) to be executed multiple times - 323 once executing in legacy graph mode, once running eagerly and with 324 `should_run_eagerly` returning True, and once running eagerly with 325 `should_run_eagerly` returning False. 326 327 If Tensorflow v2 behavior is enabled, legacy graph mode will be skipped, and 328 the test will only run twice. 329 330 Note: if stacking this decorator with absl.testing's parameterized decorators, 331 those should be at the bottom of the stack. 332 333 For example, consider the following unittest: 334 335 ```python 336 class MyTests(testing_utils.KerasTestCase): 337 338 @testing_utils.run_all_keras_modes 339 def test_foo(self): 340 model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3) 341 optimizer = RMSPropOptimizer(learning_rate=0.001) 342 loss = 'mse' 343 metrics = ['mae'] 344 model.compile( 345 optimizer, loss, metrics=metrics, 346 run_eagerly=testing_utils.should_run_eagerly()) 347 348 inputs = np.zeros((10, 3)) 349 targets = np.zeros((10, 4)) 350 dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) 351 dataset = dataset.repeat(100) 352 dataset = dataset.batch(10) 353 354 model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) 355 356 if __name__ == "__main__": 357 tf.test.main() 358 ``` 359 360 This test will try compiling & fitting the small functional mlp using all 361 three Keras execution modes. 362 363 Args: 364 test_or_class: test method or class to be annotated. If None, 365 this method returns a decorator that can be applied to a test method or 366 test class. If it is not None this returns the decorator applied to the 367 test or class. 368 config: An optional config_pb2.ConfigProto to use to configure the 369 session when executing graphs. 370 always_skip_v1: If True, does not try running the legacy graph mode even 371 when Tensorflow v2 behavior is not enabled. 372 always_skip_eager: If True, does not execute the decorated test 373 with eager execution modes. 374 **kwargs: Additional kwargs for configuring tests for 375 in-progress Keras behaviors/ refactorings that we haven't fully 376 rolled out yet 377 378 Returns: 379 Returns a decorator that will run the decorated test method multiple times. 380 381 Raises: 382 ImportError: If abseil parameterized is not installed or not included as 383 a target dependency. 384 """ 385 if kwargs: 386 raise ValueError('Unrecognized keyword args: {}'.format(kwargs)) 387 388 params = [('_v2_function', 'v2_function')] 389 if not always_skip_eager: 390 params.append(('_v2_eager', 'v2_eager')) 391 if not (always_skip_v1 or tf2.enabled()): 392 params.append(('_v1_session', 'v1_session')) 393 394 def single_method_decorator(f): 395 """Decorator that constructs the test cases.""" 396 397 # Use named_parameters so it can be individually run from the command line 398 @parameterized.named_parameters(*params) 399 @functools.wraps(f) 400 def decorated(self, run_mode, *args, **kwargs): 401 """A run of a single test case w/ specified run mode.""" 402 if run_mode == 'v1_session': 403 _v1_session_test(f, self, config, *args, **kwargs) 404 elif run_mode == 'v2_eager': 405 _v2_eager_test(f, self, *args, **kwargs) 406 elif run_mode == 'v2_function': 407 _v2_function_test(f, self, *args, **kwargs) 408 else: 409 return ValueError('Unknown run mode %s' % run_mode) 410 411 return decorated 412 413 return _test_or_class_decorator(test_or_class, single_method_decorator) 414 415 416def _v1_session_test(f, test_or_class, config, *args, **kwargs): 417 with ops.get_default_graph().as_default(): 418 with testing_utils.run_eagerly_scope(False): 419 with test_or_class.test_session(config=config): 420 f(test_or_class, *args, **kwargs) 421 422 423def _v2_eager_test(f, test_or_class, *args, **kwargs): 424 with context.eager_mode(): 425 with testing_utils.run_eagerly_scope(True): 426 f(test_or_class, *args, **kwargs) 427 428 429def _v2_function_test(f, test_or_class, *args, **kwargs): 430 with context.eager_mode(): 431 with testing_utils.run_eagerly_scope(False): 432 f(test_or_class, *args, **kwargs) 433 434 435def _test_or_class_decorator(test_or_class, single_method_decorator): 436 """Decorate a test or class with a decorator intended for one method. 437 438 If the test_or_class is a class: 439 This will apply the decorator to all test methods in the class. 440 441 If the test_or_class is an iterable of already-parameterized test cases: 442 This will apply the decorator to all the cases, and then flatten the 443 resulting cross-product of test cases. This allows stacking the Keras 444 parameterized decorators w/ each other, and to apply them to test methods 445 that have already been marked with an absl parameterized decorator. 446 447 Otherwise, treat the obj as a single method and apply the decorator directly. 448 449 Args: 450 test_or_class: A test method (that may have already been decorated with a 451 parameterized decorator, or a test class that extends 452 keras_parameterized.TestCase 453 single_method_decorator: 454 A parameterized decorator intended for a single test method. 455 Returns: 456 The decorated result. 457 """ 458 def _decorate_test_or_class(obj): 459 if isinstance(obj, collections.abc.Iterable): 460 return itertools.chain.from_iterable( 461 single_method_decorator(method) for method in obj) 462 if isinstance(obj, type): 463 cls = obj 464 for name, value in cls.__dict__.copy().items(): 465 if callable(value) and name.startswith( 466 unittest.TestLoader.testMethodPrefix): 467 setattr(cls, name, single_method_decorator(value)) 468 469 cls = type(cls).__new__(type(cls), cls.__name__, cls.__bases__, 470 cls.__dict__.copy()) 471 return cls 472 473 return single_method_decorator(obj) 474 475 if test_or_class is not None: 476 return _decorate_test_or_class(test_or_class) 477 478 return _decorate_test_or_class 479