xref: /aosp_15_r20/external/tensorflow/tensorflow/python/checkpoint/BUILD (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Description:
2#   Utilities for reading and writing object-based checkpoints.
3
4load(
5    "//tensorflow/tools/test:performance.bzl",
6    "tf_py_logged_benchmark",
7)
8
9# buildifier: disable=same-origin-load
10load("//tensorflow:tensorflow.bzl", "cuda_py_test")
11
12# buildifier: disable=same-origin-load
13load("//tensorflow:tensorflow.bzl", "tf_py_test")
14
15package(
16    default_visibility = [
17        "//tensorflow:internal",
18    ],
19    licenses = ["notice"],
20)
21
22py_library(
23    name = "checkpoint_lib",
24    deps = [
25        ":checkpoint",
26        ":checkpoint_management",
27        ":checkpoint_options",
28        ":functional_saver",
29        ":graph_view",
30        ":saveable_compat",
31        ":util",
32    ],
33)
34
35py_library(
36    name = "checkpoint",
37    srcs = [
38        "__init__.py",
39        "checkpoint.py",
40    ],
41    srcs_version = "PY3",
42    deps = [
43        ":checkpoint_options",
44        ":checkpoint_view",
45        ":functional_saver",
46        ":graph_view",
47        ":restore",
48        ":save_util_v1",
49        ":util",
50        "//tensorflow/core:protos_all_py",
51        "//tensorflow/python:array_ops",
52        "//tensorflow/python:constant_op",
53        "//tensorflow/python:dtypes",
54        "//tensorflow/python:errors",
55        "//tensorflow/python:framework_ops",
56        "//tensorflow/python:init_ops",
57        "//tensorflow/python:io_ops_gen",
58        "//tensorflow/python:lib",
59        "//tensorflow/python:platform",
60        "//tensorflow/python:pywrap_tensorflow",
61        "//tensorflow/python:saver",
62        "//tensorflow/python:session",
63        "//tensorflow/python:tensor_shape",
64        "//tensorflow/python:tensor_util",
65        "//tensorflow/python:util",
66        "//tensorflow/python:variable_scope",
67        "//tensorflow/python:variables",
68        "//tensorflow/python/checkpoint:checkpoint_management",
69        "//tensorflow/python/eager:context",
70        "//tensorflow/python/eager:def_function",
71        "//tensorflow/python/saved_model:utils",
72        "//tensorflow/python/trackable:autotrackable",
73        "//tensorflow/python/trackable:base",
74        "//tensorflow/python/trackable:data_structures",
75        "//tensorflow/python/training/saving:saveable_object_util",
76    ],
77)
78
79tf_py_test(
80    name = "checkpoint_test",
81    srcs = ["checkpoint_test.py"],
82    tags = [
83        "no_windows",  # TODO(b/201457117)
84        "notsan",  # TODO(b/74395663)
85    ],
86    deps = [
87        ":checkpoint",
88        ":checkpoint_options",
89        ":graph_view",
90        "//tensorflow/python:client_testlib",
91        "//tensorflow/python:constant_op",
92        "//tensorflow/python:control_flow_ops",
93        "//tensorflow/python:dtypes",
94        "//tensorflow/python:framework_ops",
95        "//tensorflow/python:framework_test_lib",
96        "//tensorflow/python:init_ops",
97        "//tensorflow/python:platform",
98        "//tensorflow/python:pywrap_tensorflow",
99        "//tensorflow/python:resource_variable_ops",
100        "//tensorflow/python:saver",
101        "//tensorflow/python:session",
102        "//tensorflow/python:state_ops",
103        "//tensorflow/python:template",
104        "//tensorflow/python:training_util",
105        "//tensorflow/python:variable_scope",
106        "//tensorflow/python:variables",
107        "//tensorflow/python/checkpoint:checkpoint_management",
108        "//tensorflow/python/eager:backprop",
109        "//tensorflow/python/eager:context",
110        "//tensorflow/python/eager:def_function",
111        "//tensorflow/python/eager:test",
112        "//tensorflow/python/saved_model:save",
113        "//tensorflow/python/trackable:autotrackable",
114        "//tensorflow/python/trackable:base",
115        "@absl_py//absl/testing:parameterized",
116    ],
117)
118
119tf_py_test(
120    name = "checkpoint_with_v1_optimizers_test",
121    srcs = ["checkpoint_with_v1_optimizers_test.py"],
122    tags = [
123        "notsan",  # b/74395663
124    ],
125    deps = [
126        ":checkpoint",
127        "//tensorflow/python:framework_ops",
128        "//tensorflow/python:framework_test_lib",
129        "//tensorflow/python:init_ops",
130        "//tensorflow/python:resource_variable_ops",
131        "//tensorflow/python:session",
132        "//tensorflow/python:state_ops",
133        "//tensorflow/python:template",
134        "//tensorflow/python:training",
135        "//tensorflow/python:variable_scope",
136        "//tensorflow/python/eager:context",
137        "//tensorflow/python/eager:test",
138        "//tensorflow/python/trackable:autotrackable",
139    ],
140)
141
142tf_py_test(
143    name = "checkpoint_metrics_test",
144    srcs = ["checkpoint_metrics_test.py"],
145    deps = [
146        ":checkpoint",
147        "//tensorflow/python:platform_test",
148    ],
149)
150
151py_library(
152    name = "checkpoint_view",
153    srcs = ["checkpoint_view.py"],
154    srcs_version = "PY3",
155    tags = ["no_pip"],
156    deps = [
157        ":trackable_view",
158        "//tensorflow/core:protos_all_py",
159        "//tensorflow/python:platform",
160        "//tensorflow/python/framework:errors",
161        "//tensorflow/python/trackable:base",
162        "//tensorflow/python/training:py_checkpoint_reader",
163        "//tensorflow/python/util:tf_export",
164    ],
165)
166
167tf_py_test(
168    name = "checkpoint_view_test",
169    srcs = ["checkpoint_view_test.py"],
170    tags = ["no_pip"],
171    deps = [
172        ":checkpoint_view",
173        "//tensorflow/python:variables",
174        "//tensorflow/python/eager:test",
175        "//tensorflow/python/trackable:base",
176    ],
177)
178
179py_library(
180    name = "graph_view",
181    srcs = ["graph_view.py"],
182    srcs_version = "PY3",
183    deps = [
184        ":save_util_v1",
185        ":trackable_view",
186        "//tensorflow/python:util",
187        "//tensorflow/python/trackable:base",
188        "//tensorflow/python/trackable:converter",
189    ],
190)
191
192py_library(
193    name = "save_util_v1",
194    srcs = ["save_util_v1.py"],
195    srcs_version = "PY3",
196    deps = [
197        ":saveable_compat",
198        "//tensorflow/python:constant_op",
199        "//tensorflow/python:dtypes",
200        "//tensorflow/python:framework_ops",
201        "//tensorflow/python:util",
202        "//tensorflow/python/saved_model/registration",
203        "//tensorflow/python/trackable:base",
204        "//tensorflow/python/trackable:python_state",
205        "//tensorflow/python/trackable:trackable_utils",
206        "//tensorflow/python/training/saving:saveable_object",
207        "//tensorflow/python/training/saving:saveable_object_util",
208    ],
209)
210
211tf_py_test(
212    name = "save_util_v1_test",
213    srcs = ["save_util_v1_test.py"],
214    deps = [
215        ":graph_view",
216        ":save_util_v1",
217        "//tensorflow/python:util",
218        "//tensorflow/python:variables",
219        "//tensorflow/python/eager:test",
220        "//tensorflow/python/saved_model/registration",
221        "//tensorflow/python/trackable:autotrackable",
222    ],
223)
224
225py_library(
226    name = "trackable_view",
227    srcs = ["trackable_view.py"],
228    srcs_version = "PY3",
229    tags = ["no_pip"],
230    deps = [
231        "//tensorflow/python:util",
232        "//tensorflow/python/trackable:base",
233        "//tensorflow/python/trackable:converter",
234        "//tensorflow/python/util:tf_export",
235    ],
236)
237
238tf_py_test(
239    name = "trackable_view_test",
240    srcs = ["trackable_view_test.py"],
241    deps = [
242        ":trackable_view",
243        "//tensorflow/python/eager:test",
244        "//tensorflow/python/trackable:base",
245    ],
246)
247
248py_library(
249    name = "util",
250    srcs = ["util.py"],
251    srcs_version = "PY3",
252    deps = [
253        "//tensorflow/core:protos_all_py",
254        "//tensorflow/python:resource_variable_ops",
255        "//tensorflow/python:util",
256        "//tensorflow/python:variables",
257        "//tensorflow/python/trackable:trackable_utils",
258        "//tensorflow/python/training:optimizer",
259    ],
260)
261
262py_library(
263    name = "restore",
264    srcs = ["restore.py"],
265    srcs_version = "PY3",
266    deps = [
267        ":saveable_compat",
268        "//tensorflow/python:array_ops",
269        "//tensorflow/python:framework_ops",
270        "//tensorflow/python:io_ops_gen",
271        "//tensorflow/python:platform",
272        "//tensorflow/python/eager:context",
273        "//tensorflow/python/saved_model/registration",
274        "//tensorflow/python/trackable:constants",
275        "//tensorflow/python/trackable:python_state",
276        "//tensorflow/python/trackable:trackable_utils",
277    ],
278)
279
280tf_py_test(
281    name = "restore_test",
282    srcs = ["restore_test.py"],
283    deps = [
284        ":restore",
285        "//tensorflow/python/eager:test",
286    ],
287)
288
289tf_py_test(
290    name = "benchmarks_test",
291    srcs = ["benchmarks_test.py"],
292    deps = [
293        ":checkpoint",
294        "//tensorflow/python:framework_ops",
295        "//tensorflow/python:platform_test",
296    ],
297)
298
299tf_py_logged_benchmark(
300    name = "benchmarks",
301    target = "//tensorflow/python/checkpoint:benchmarks_test",
302)
303
304py_library(
305    name = "checkpoint_options",
306    srcs = ["checkpoint_options.py"],
307    srcs_version = "PY3",
308    deps = [
309        "//tensorflow/python/util:tf_export",
310    ],
311)
312
313py_library(
314    name = "functional_saver",
315    srcs = ["functional_saver.py"],
316    srcs_version = "PY3",
317    deps = [
318        ":checkpoint_options",
319        "//tensorflow/python/eager:def_function",
320        "//tensorflow/python/saved_model/registration",
321        "//tensorflow/python/training/saving:saveable_object",
322        "//tensorflow/python/training/saving:saveable_object_util",
323    ],
324)
325
326cuda_py_test(
327    name = "functional_saver_test",
328    size = "medium",
329    srcs = [
330        "functional_saver_test.py",
331    ],
332    deps = [
333        ":checkpoint_options",
334        ":functional_saver",
335        "//tensorflow/python/eager:remote",
336        "//tensorflow/python/eager:test",
337    ],
338)
339
340py_library(
341    name = "checkpoint_management",
342    srcs = ["checkpoint_management.py"],
343    srcs_version = "PY3",
344    deps = [
345        "//tensorflow/python:errors",
346        "//tensorflow/python:framework_ops",
347        "//tensorflow/python:lib",
348        "//tensorflow/python:platform",
349        "//tensorflow/python:util",
350        "//tensorflow/python:variable_scope",
351        "//tensorflow/python/eager:context",
352        "//tensorflow/python/training:training_util",
353        "//tensorflow/python/util:tf_export",
354    ],
355)
356
357cuda_py_test(
358    name = "checkpoint_management_test",
359    size = "small",
360    srcs = [
361        "checkpoint_management_test.py",
362    ],
363    python_version = "PY3",
364    deps = [
365        ":checkpoint",
366        "//tensorflow/python:client_testlib",
367        "//tensorflow/python:dtypes",
368        "//tensorflow/python:framework_ops",
369        "//tensorflow/python:framework_test_lib",
370        "//tensorflow/python:lib",
371        "//tensorflow/python:platform",
372        "//tensorflow/python:variables",
373        "//tensorflow/python/eager:context",
374        "//tensorflow/python/training:checkpoint_management",
375        "//tensorflow/python/training:saver",
376    ],
377)
378
379py_library(
380    name = "saveable_compat",
381    srcs = [
382        "saveable_compat.py",
383    ],
384)
385
386tf_py_test(
387    name = "saveable_compat_test",
388    srcs = [
389        "saveable_compat_test.py",
390    ],
391    data = [
392        "testdata/table_legacy_saveable_object.data-00000-of-00001",
393        "testdata/table_legacy_saveable_object.index",
394    ],
395    tags = ["no_pip"],
396    deps = [
397        ":checkpoint",
398        ":saveable_compat",
399        ":testdata/generate_checkpoint",
400        "//tensorflow/python:variables",
401        "//tensorflow/python/eager:test",
402        "//tensorflow/python/trackable:base",
403        "//tensorflow/python/training/saving:saveable_object",
404    ],
405)
406
407py_binary(
408    name = "testdata/generate_checkpoint",
409    srcs = ["testdata/generate_checkpoint.py"],
410    python_version = "PY3",
411    srcs_version = "PY3",
412    deps = [
413        "//tensorflow/python:checkpoint",
414        "//tensorflow/python:dtypes",
415        "//tensorflow/python:framework_ops",
416        "//tensorflow/python:lookup_ops",
417        "//tensorflow/python:variables",
418        "//tensorflow/python/compat:v2_compat",
419        "//tensorflow/python/module",
420        "@absl_py//absl:app",
421    ],
422)
423