# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for writing tests."""

import os
import tempfile
from typing import Any, Mapping

import tensorflow as tf


def create_checkpoint(data: Mapping[str, Any]) -> bytes:
  """Creates a TensorFlow checkpoint.

  The entries of `data` are written with `tf.raw_ops.Save` to a temporary
  location, and the resulting file is read back into memory.

  Args:
    data: Mapping from tensor name to a value that `tf.raw_ops.Save` can
      convert to a tensor. Each value is saved under its key.

  Returns:
    The raw bytes of the checkpoint file written by `tf.raw_ops.Save`.
  """
  with tempfile.TemporaryDirectory() as tmpdir:
    # Use a fresh path inside a temporary directory instead of the name of an
    # open NamedTemporaryFile: reopening an open NamedTemporaryFile by name is
    # not supported on every platform.
    path = os.path.join(tmpdir, 'checkpoint')
    # Build the op in a private graph so repeated calls do not keep adding ops
    # to the global default graph, and so this works with eager execution on.
    with tf.Graph().as_default(), tf.compat.v1.Session() as session:
      session.run(
          tf.raw_ops.Save(
              filename=path,
              tensor_names=list(data.keys()),
              data=list(data.values())))
    with open(path, 'rb') as f:
      return f.read()


def read_tensor_from_checkpoint(checkpoint: bytes, tensor_name: str,
                                dt: tf.DType) -> Any:
  """Reads a single tensor from a checkpoint.

  Args:
    checkpoint: Raw checkpoint bytes, e.g. as returned by `create_checkpoint`.
    tensor_name: Name of the tensor to restore.
    dt: Dtype of the stored tensor, passed through to `tf.raw_ops.Restore`.

  Returns:
    The value that `Session.run` fetches for the restored tensor.
  """
  with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, 'checkpoint')
    # Write and close the file before TensorFlow opens it by name.
    with open(path, 'wb') as f:
      f.write(checkpoint)
    # Private graph: avoids growing the default graph and works under eager.
    with tf.Graph().as_default(), tf.compat.v1.Session() as session:
      return session.run(
          tf.raw_ops.Restore(
              file_pattern=path, tensor_name=tensor_name, dt=dt))