xref: /aosp_15_r20/external/federated-compute/fcp/artifact_building/test_utils.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Utilities used in tests."""
15
16import tensorflow as tf
17
18from fcp.protos import plan_pb2
19
20
def set_checkpoint_op(
    checkpoint_op_proto: plan_pb2.CheckpointOp,
    saver: tf.compat.v1.train.Saver,
) -> None:
  """Copies `saver`'s SaverDef onto `checkpoint_op_proto` and fixes a name.

  The SaverDef produced by `saver.as_saver_def()` is copied into
  `checkpoint_op_proto.saver_def`. Then a trailing ":0" output suffix is
  stripped from its `save_tensor_name`, so the field names an op rather than
  a tensor.

  Args:
    checkpoint_op_proto: The `CheckpointOp` whose `saver_def` field is
      overwritten in place.
    saver: A `tf.compat.v1.train.Saver` (anything exposing `as_saver_def()`).
      If it is `None` (or otherwise falsy), `checkpoint_op_proto` is left
      untouched.

  Raises:
    AssertionError: If `save_tensor_name` still contains a ':' after the
      trailing ":0" has been removed.
  """
  if not saver:
    return
  saver_def_proto = checkpoint_op_proto.saver_def
  saver_def_proto.CopyFrom(saver.as_saver_def())

  # The SaverDef calls an op output a "tensor". Python's Saver needs the ":0"
  # tensor name so that sess.run() returns the tensor's value. C++ plan
  # execution expects a bare op name instead, so the suffix is dropped here.
  # Only a trailing ":0" is removed, so the rest of the name is never altered.
  save_tensor_name = saver_def_proto.save_tensor_name
  if save_tensor_name.endswith(':0'):
    save_tensor_name = save_tensor_name[:-len(':0')]
  saver_def_proto.save_tensor_name = save_tensor_name
  assert ':' not in save_tensor_name, (
      f'Unexpected save_tensor_name {save_tensor_name!r}: expected an op name'
      ' or a tensor name ending in ":0".'
  )
41