xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/append_slices.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
"""Provides the `append_slices` and `merge_appended_slices` operations.

This wraps the generated ops and ensures that necessary shared libraries
are loaded.
"""
19
20import tensorflow as tf
21
22from fcp.tensorflow import gen_append_slices_py
23
# Load the custom-op shared library at import time so the kernels behind the
# generated `gen_append_slices_py` wrappers are available before either wrapper
# below is called. The path is resolved via `get_path_to_datafile`, i.e.
# relative to this module's own location rather than the process CWD.
_append_slices_so = tf.load_op_library(
    tf.compat.v1.resource_loader.get_path_to_datafile('./_append_slices_op.so'))
26
27
def append_slices(filename, tensor_names, shapes_and_slices, data, name=None):
  """Appends slices to `filename`.

  Must be paired with `merge_appended_slices`.

  This op is identical to `tf.raw_ops.SaveSlices`, except that it appends the
  resulting checkpoint to `filename` rather than erasing the contents of
  `filename`.

  Note: the resulting file at `filename` will not be in checkpoint format until
  `merge_appended_slices` has been called.

  Args:
    filename: A `Tensor` of type `string`. Must have a single element. The name
      of the file to which the tensors should be appended.
    tensor_names: A `Tensor` of type `string`. Shape `[N]`. The names of the
      tensors to be saved.
    shapes_and_slices: A `Tensor` of type `string`. Shape `[N]`. The shapes and
      slice specifications to use when saving the tensors.
    data: A list of `Tensor` objects. `N` tensors to save.
    name: A name for the operation (optional).

  Returns:
    The created `Operation`.
  """
  # Thin pass-through to the generated op wrapper; `name` is passed by keyword
  # so it cannot bind to any other parameter of the generated signature.
  return gen_append_slices_py.append_slices(
      filename, tensor_names, shapes_and_slices, data, name=name)
55
56
def merge_appended_slices(filename, name=None):
  """Merges the file appended to by `append_slices` into a single checkpoint.

  The immediate file output of `append_slices` is not in checkpoint format. It
  must be converted to a checkpoint using this function `merge_appended_slices`.

  Note: Users must call `control_dependencies` or other mechanisms to ensure
  that the `append_slices` calls have executed prior to the execution of
  `merge_appended_slices`.

  Args:
    filename: The name of a file appended to by calls to `append_slices`.
      Presumably given in the same form as the `filename` argument of
      `append_slices` (a single-element `string` `Tensor`) -- TODO confirm
      against the op registration.
    name: A name for the operation (optional).

  Returns:
    The created `Operation`.
  """
  # Pass `name` by keyword, matching `append_slices`: the keyword form does not
  # depend on `name`'s position in the generated wrapper's signature, so it
  # stays correct even if the op's signature grows additional parameters.
  return gen_append_slices_py.merge_appended_slices(filename, name=name)
75