# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities used in tests."""

import tensorflow as tf

from fcp.protos import plan_pb2


def set_checkpoint_op(
    checkpoint_op_proto: plan_pb2.CheckpointOp,
    saver: tf.compat.v1.train.Saver,
):
  """Sets the saver_def from saver onto checkpoint_op_proto and fixes a name.

  Copies `saver.as_saver_def()` into `checkpoint_op_proto.saver_def`. It then
  strips a trailing ":0" from `saver_def.save_tensor_name`, so the field names
  the op rather than the op's first output tensor.

  Args:
    checkpoint_op_proto: The `plan_pb2.CheckpointOp` whose `saver_def` field
      is overwritten in place.
    saver: A `tf.compat.v1.train.Saver`, or any object exposing
      `as_saver_def()`. If it is falsy (e.g. `None`), the function returns
      without modifying `checkpoint_op_proto`.

  Raises:
    AssertionError: If `save_tensor_name` still contains a ':' after the
      trailing ":0" has been stripped.
  """
  if not saver:
    return
  saver_def_proto = checkpoint_op_proto.saver_def

  saver_def_proto.CopyFrom(saver.as_saver_def())
  # They are calling an Op a Tensor and it works in python and
  # breaks in C++. However, for use in the python Saver class, we
  # need the tensor because we need sess.run() to return the
  # tensor's value. So, we only strip the ":0" in the case of
  # plan execution, where the write_checkpoint and read_checkpoint
  # methods are used instead of the Saver.
  # NOTE(review): those methods are not defined in this file; confirm where
  # they live.
  save_tensor_name = saver_def_proto.save_tensor_name
  # Strip only the trailing output-index suffix rather than every ':0'
  # substring in the name.
  suffix = ':0'
  if save_tensor_name.endswith(suffix):
    save_tensor_name = save_tensor_name[: -len(suffix)]
  saver_def_proto.save_tensor_name = save_tensor_name
  assert ':' not in save_tensor_name, (
      'Expected save_tensor_name to be an op name without an output index, '
      f'got {save_tensor_name!r}'
  )