# Description:
#   Utilities for reading and writing object-based checkpoints.

load(
    "//tensorflow/tools/test:performance.bzl",
    "tf_py_logged_benchmark",
)

# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "cuda_py_test")

# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_py_test")

package(
    default_visibility = [
        "//tensorflow:internal",
    ],
    licenses = ["notice"],
)

# Umbrella target that bundles the public checkpointing libraries of this
# package; it has no sources of its own.
py_library(
    name = "checkpoint_lib",
    deps = [
        ":checkpoint",
        ":checkpoint_management",
        ":checkpoint_options",
        ":functional_saver",
        ":graph_view",
        ":saveable_compat",
        ":util",
    ],
)

py_library(
    name = "checkpoint",
    srcs = [
        "__init__.py",
        "checkpoint.py",
    ],
    srcs_version = "PY3",
    deps = [
        # Same-package targets use short labels, consistent with the rest of
        # this file.
        ":checkpoint_management",
        ":checkpoint_options",
        ":checkpoint_view",
        ":functional_saver",
        ":graph_view",
        ":restore",
        ":save_util_v1",
        ":util",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:errors",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:io_ops_gen",
        "//tensorflow/python:lib",
        "//tensorflow/python:platform",
        "//tensorflow/python:pywrap_tensorflow",
        "//tensorflow/python:saver",
        "//tensorflow/python:session",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:tensor_util",
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/saved_model:utils",
        "//tensorflow/python/trackable:autotrackable",
        "//tensorflow/python/trackable:base",
        "//tensorflow/python/trackable:data_structures",
        "//tensorflow/python/training/saving:saveable_object_util",
    ],
)

tf_py_test(
    name = "checkpoint_test",
    srcs = ["checkpoint_test.py"],
    tags = [
        "no_windows",  # TODO(b/201457117)
        "notsan",  # TODO(b/74395663)
    ],
    deps = [
        ":checkpoint",
        ":checkpoint_management",
        ":checkpoint_options",
        ":graph_view",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:pywrap_tensorflow",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:saver",
        "//tensorflow/python:session",
        "//tensorflow/python:state_ops",
        "//tensorflow/python:template",
        "//tensorflow/python:training_util",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/eager:backprop",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/saved_model:save",
        "//tensorflow/python/trackable:autotrackable",
        "//tensorflow/python/trackable:base",
        "@absl_py//absl/testing:parameterized",
    ],
)

tf_py_test(
    name = "checkpoint_with_v1_optimizers_test",
    srcs = ["checkpoint_with_v1_optimizers_test.py"],
    tags = [
        "notsan",  # TODO(b/74395663)
    ],
    deps = [
        ":checkpoint",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:session",
        "//tensorflow/python:state_ops",
        "//tensorflow/python:template",
        "//tensorflow/python:training",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/trackable:autotrackable",
    ],
)

tf_py_test(
    name = "checkpoint_metrics_test",
    srcs = ["checkpoint_metrics_test.py"],
    deps = [
        ":checkpoint",
        "//tensorflow/python:platform_test",
    ],
)

py_library(
    name = "checkpoint_view",
    srcs = ["checkpoint_view.py"],
    srcs_version = "PY3",
    tags = ["no_pip"],
    deps = [
        ":trackable_view",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:platform",
        "//tensorflow/python/framework:errors",
        "//tensorflow/python/trackable:base",
        "//tensorflow/python/training:py_checkpoint_reader",
        "//tensorflow/python/util:tf_export",
    ],
)

tf_py_test(
    name = "checkpoint_view_test",
    srcs = ["checkpoint_view_test.py"],
    tags = ["no_pip"],
    deps = [
        ":checkpoint_view",
        "//tensorflow/python:variables",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/trackable:base",
    ],
)

py_library(
    name = "graph_view",
    srcs = ["graph_view.py"],
    srcs_version = "PY3",
    deps = [
        ":save_util_v1",
        ":trackable_view",
        "//tensorflow/python:util",
        "//tensorflow/python/trackable:base",
        "//tensorflow/python/trackable:converter",
    ],
)

py_library(
    name = "save_util_v1",
    srcs = ["save_util_v1.py"],
    srcs_version = "PY3",
    deps = [
        ":saveable_compat",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:util",
        "//tensorflow/python/saved_model/registration",
        "//tensorflow/python/trackable:base",
        "//tensorflow/python/trackable:python_state",
        "//tensorflow/python/trackable:trackable_utils",
        "//tensorflow/python/training/saving:saveable_object",
        "//tensorflow/python/training/saving:saveable_object_util",
    ],
)

tf_py_test(
    name = "save_util_v1_test",
    srcs = ["save_util_v1_test.py"],
    deps = [
        ":graph_view",
        ":save_util_v1",
        "//tensorflow/python:util",
        "//tensorflow/python:variables",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/saved_model/registration",
        "//tensorflow/python/trackable:autotrackable",
    ],
)

py_library(
    name = "trackable_view",
    srcs = ["trackable_view.py"],
    srcs_version = "PY3",
    tags = ["no_pip"],
    deps = [
        "//tensorflow/python:util",
        "//tensorflow/python/trackable:base",
        "//tensorflow/python/trackable:converter",
        "//tensorflow/python/util:tf_export",
    ],
)

tf_py_test(
    name = "trackable_view_test",
    srcs = ["trackable_view_test.py"],
    deps = [
        ":trackable_view",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/trackable:base",
    ],
)

py_library(
    name = "util",
    srcs = ["util.py"],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:util",
        "//tensorflow/python:variables",
        "//tensorflow/python/trackable:trackable_utils",
        "//tensorflow/python/training:optimizer",
    ],
)

py_library(
    name = "restore",
    srcs = ["restore.py"],
    srcs_version = "PY3",
    deps = [
        ":saveable_compat",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:io_ops_gen",
        "//tensorflow/python:platform",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/saved_model/registration",
        "//tensorflow/python/trackable:constants",
        "//tensorflow/python/trackable:python_state",
        "//tensorflow/python/trackable:trackable_utils",
    ],
)

tf_py_test(
    name = "restore_test",
    srcs = ["restore_test.py"],
    deps = [
        ":restore",
        "//tensorflow/python/eager:test",
    ],
)

tf_py_test(
    name = "benchmarks_test",
    srcs = ["benchmarks_test.py"],
    deps = [
        ":checkpoint",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:platform_test",
    ],
)

# The benchmark macro receives the test as a plain string, so the fully
# qualified label is kept here rather than the short ":benchmarks_test" form.
tf_py_logged_benchmark(
    name = "benchmarks",
    target = "//tensorflow/python/checkpoint:benchmarks_test",
)

py_library(
    name = "checkpoint_options",
    srcs = ["checkpoint_options.py"],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python/util:tf_export",
    ],
)

py_library(
    name = "functional_saver",
    srcs = ["functional_saver.py"],
    srcs_version = "PY3",
    deps = [
        ":checkpoint_options",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/saved_model/registration",
        "//tensorflow/python/training/saving:saveable_object",
        "//tensorflow/python/training/saving:saveable_object_util",
    ],
)

cuda_py_test(
    name = "functional_saver_test",
    size = "medium",
    srcs = [
        "functional_saver_test.py",
    ],
    deps = [
        ":checkpoint_options",
        ":functional_saver",
        "//tensorflow/python/eager:remote",
        "//tensorflow/python/eager:test",
    ],
)

py_library(
    name = "checkpoint_management",
    srcs = ["checkpoint_management.py"],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python:errors",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:lib",
        "//tensorflow/python:platform",
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/training:training_util",
        "//tensorflow/python/util:tf_export",
    ],
)

cuda_py_test(
    name = "checkpoint_management_test",
    size = "small",
    srcs = [
        "checkpoint_management_test.py",
    ],
    python_version = "PY3",
    deps = [
        ":checkpoint",
        # Depend directly on the library under test instead of reaching it
        # only transitively through ":checkpoint".
        ":checkpoint_management",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:lib",
        "//tensorflow/python:platform",
        "//tensorflow/python:variables",
        "//tensorflow/python/eager:context",
        # NOTE(review): legacy label kept in case the test still imports the
        # training-package module; drop once confirmed unused.
        "//tensorflow/python/training:checkpoint_management",
        "//tensorflow/python/training:saver",
    ],
)

py_library(
    name = "saveable_compat",
    srcs = [
        "saveable_compat.py",
    ],
    srcs_version = "PY3",
)

tf_py_test(
    name = "saveable_compat_test",
    srcs = [
        "saveable_compat_test.py",
    ],
    data = [
        "testdata/table_legacy_saveable_object.data-00000-of-00001",
        "testdata/table_legacy_saveable_object.index",
    ],
    tags = ["no_pip"],
    deps = [
        ":checkpoint",
        ":saveable_compat",
        ":testdata/generate_checkpoint",
        "//tensorflow/python:variables",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/trackable:base",
        "//tensorflow/python/training/saving:saveable_object",
    ],
)

# Script used to (re)generate the checked-in testdata checkpoint consumed by
# saveable_compat_test.
py_binary(
    name = "testdata/generate_checkpoint",
    srcs = ["testdata/generate_checkpoint.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        # The checkpoint library lives in this package; reference it locally
        # like every other consumer in this file.
        ":checkpoint",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:lookup_ops",
        "//tensorflow/python:variables",
        "//tensorflow/python/compat:v2_compat",
        "//tensorflow/python/module",
        "@absl_py//absl:app",
    ],
)