# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for writing tests."""

import tempfile
from typing import Any, Mapping

import tensorflow as tf


def create_checkpoint(data: Mapping[str, Any]) -> bytes:
  """Serializes named values into the bytes of a TensorFlow checkpoint.

  Args:
    data: Mapping from tensor name to the value stored under that name.

  Returns:
    The full contents of the file written by the `Save` op.
  """
  with tempfile.NamedTemporaryFile() as tmpfile:
    checkpoint_path = tmpfile.name
    names = list(data)
    values = list(data.values())
    with tf.compat.v1.Session() as session:
      save_op = tf.raw_ops.Save(
          filename=checkpoint_path, tensor_names=names, data=values)
      session.run(save_op)
    # Reopen the file by name to pick up what the Save op wrote to that path.
    with open(checkpoint_path, 'rb') as checkpoint_file:
      return checkpoint_file.read()


def read_tensor_from_checkpoint(checkpoint: bytes, tensor_name: str,
                                dt: tf.DType) -> Any:
  """Restores one named tensor from serialized checkpoint bytes.

  Args:
    checkpoint: Raw checkpoint contents.
    tensor_name: Name of the tensor to restore.
    dt: Dtype passed to the `Restore` op for the requested tensor.

  Returns:
    The result of running the `Restore` op in a session.
  """
  with tempfile.NamedTemporaryFile('wb') as tmpfile:
    tmpfile.write(checkpoint)
    # Push the buffered bytes out before the Restore op opens the file by name.
    tmpfile.flush()
    with tf.compat.v1.Session() as session:
      restore_op = tf.raw_ops.Restore(
          file_pattern=tmpfile.name, tensor_name=tensor_name, dt=dt)
      return session.run(restore_op)