xref: /aosp_15_r20/external/federated-compute/fcp/artifact_building/test_utils.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker# Copyright 2022 Google LLC
2*14675a02SAndroid Build Coastguard Worker#
3*14675a02SAndroid Build Coastguard Worker# Licensed under the Apache License, Version 2.0 (the "License");
4*14675a02SAndroid Build Coastguard Worker# you may not use this file except in compliance with the License.
5*14675a02SAndroid Build Coastguard Worker# You may obtain a copy of the License at
6*14675a02SAndroid Build Coastguard Worker#
7*14675a02SAndroid Build Coastguard Worker#      http://www.apache.org/licenses/LICENSE-2.0
8*14675a02SAndroid Build Coastguard Worker#
9*14675a02SAndroid Build Coastguard Worker# Unless required by applicable law or agreed to in writing, software
10*14675a02SAndroid Build Coastguard Worker# distributed under the License is distributed on an "AS IS" BASIS,
11*14675a02SAndroid Build Coastguard Worker# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*14675a02SAndroid Build Coastguard Worker# See the License for the specific language governing permissions and
13*14675a02SAndroid Build Coastguard Worker# limitations under the License.
14*14675a02SAndroid Build Coastguard Worker"""Utilities used in tests."""
15*14675a02SAndroid Build Coastguard Worker
16*14675a02SAndroid Build Coastguard Workerimport tensorflow as tf
17*14675a02SAndroid Build Coastguard Worker
18*14675a02SAndroid Build Coastguard Workerfrom fcp.protos import plan_pb2
19*14675a02SAndroid Build Coastguard Worker
20*14675a02SAndroid Build Coastguard Worker
def set_checkpoint_op(
    checkpoint_op_proto: plan_pb2.CheckpointOp,
    saver: tf.compat.v1.train.Saver,
) -> None:
  """Sets the saver_def from `saver` onto `checkpoint_op_proto` and fixes a name.

  Copies `saver.as_saver_def()` into `checkpoint_op_proto.saver_def`. It then
  strips a trailing ":0" output-index suffix from `save_tensor_name`, so the
  field holds a bare op name rather than a tensor name.

  Args:
    checkpoint_op_proto: The `CheckpointOp` proto whose `saver_def` field is
      overwritten in place.
    saver: The Saver whose `SaverDef` is copied. If it is falsy (e.g. `None`),
      this function is a no-op and `checkpoint_op_proto` is left unchanged.

  Raises:
    ValueError: If `save_tensor_name` still contains a ':' after the ":0"
      suffix is stripped (i.e. it names a non-zero output index).
  """
  if not saver:
    return

  saver_def_proto = checkpoint_op_proto.saver_def
  saver_def_proto.CopyFrom(saver.as_saver_def())

  # The Python Saver refers to the save op by a *tensor* name (with a ":0"
  # output suffix) because Python code needs sess.run() to return the
  # tensor's value. Treating that op as a tensor works in Python but breaks
  # in C++, so for plan execution only the bare op name is stored.
  suffix = ':0'
  save_tensor_name = saver_def_proto.save_tensor_name
  if save_tensor_name.endswith(suffix):
    save_tensor_name = save_tensor_name[: -len(suffix)]
  saver_def_proto.save_tensor_name = save_tensor_name

  # Raise explicitly rather than `assert`, which is skipped under `python -O`.
  if ':' in saver_def_proto.save_tensor_name:
    raise ValueError(
        'Expected save_tensor_name to be a bare op name after stripping ":0",'
        f' got {saver_def_proto.save_tensor_name!r}.'
    )
41