# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities used in tests."""

from typing import Optional

import tensorflow as tf

from fcp.protos import plan_pb2

# Output-index suffix that turns an op name into a tensor name.
_OUTPUT_INDEX_SUFFIX = ':0'


def set_checkpoint_op(
    checkpoint_op_proto: plan_pb2.CheckpointOp,
    saver: Optional[tf.compat.v1.train.Saver],
):
  """Sets the saver_def from saver onto checkpoint_op_proto and fixes a name.

  The `SaverDef` produced by `saver.as_saver_def()` is copied into
  `checkpoint_op_proto.saver_def`, and a trailing ":0" is then stripped from
  its `save_tensor_name`.

  Args:
    checkpoint_op_proto: The proto whose `saver_def` field is overwritten in
      place.
    saver: The object providing `as_saver_def()`. If falsy (e.g. `None`), the
      function returns without touching `checkpoint_op_proto`.

  Raises:
    AssertionError: If `save_tensor_name` still contains a ':' after the
      trailing ":0" has been stripped.
  """
  if not saver:
    return

  saver_def_proto = checkpoint_op_proto.saver_def
  saver_def_proto.CopyFrom(saver.as_saver_def())

  # `save_tensor_name` refers to an Op as if it were a Tensor; that works in
  # Python and breaks in C++. The Python Saver class needs the tensor form
  # because it relies on sess.run() returning the tensor's value, so the ":0"
  # is stripped only here, for plan execution, where the Saver is not used.
  #
  # Only a trailing ":0" is removed: a blanket str.replace() would also delete
  # ":0" sequences elsewhere in the name and could let a malformed name slip
  # past the check below.
  save_name = saver_def_proto.save_tensor_name
  if save_name.endswith(_OUTPUT_INDEX_SUFFIX):
    save_name = save_name[: -len(_OUTPUT_INDEX_SUFFIX)]
  saver_def_proto.save_tensor_name = save_name

  assert ':' not in save_name, (
      f'Unexpected ":" left in save_tensor_name: {save_name!r}'
  )