xref: /aosp_15_r20/external/federated-compute/fcp/demo/test_utils.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker# Copyright 2022 Google LLC
2*14675a02SAndroid Build Coastguard Worker#
3*14675a02SAndroid Build Coastguard Worker# Licensed under the Apache License, Version 2.0 (the "License");
4*14675a02SAndroid Build Coastguard Worker# you may not use this file except in compliance with the License.
5*14675a02SAndroid Build Coastguard Worker# You may obtain a copy of the License at
6*14675a02SAndroid Build Coastguard Worker#
7*14675a02SAndroid Build Coastguard Worker#      http://www.apache.org/licenses/LICENSE-2.0
8*14675a02SAndroid Build Coastguard Worker#
9*14675a02SAndroid Build Coastguard Worker# Unless required by applicable law or agreed to in writing, software
10*14675a02SAndroid Build Coastguard Worker# distributed under the License is distributed on an "AS IS" BASIS,
11*14675a02SAndroid Build Coastguard Worker# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*14675a02SAndroid Build Coastguard Worker# See the License for the specific language governing permissions and
13*14675a02SAndroid Build Coastguard Worker# limitations under the License.
14*14675a02SAndroid Build Coastguard Worker"""Helper functions for writing tests."""
15*14675a02SAndroid Build Coastguard Worker
16*14675a02SAndroid Build Coastguard Workerimport tempfile
17*14675a02SAndroid Build Coastguard Workerfrom typing import Any, Mapping
18*14675a02SAndroid Build Coastguard Worker
19*14675a02SAndroid Build Coastguard Workerimport tensorflow as tf
20*14675a02SAndroid Build Coastguard Worker
21*14675a02SAndroid Build Coastguard Worker
def create_checkpoint(data: Mapping[str, Any]) -> bytes:
  """Serializes named values into the bytes of a TensorFlow checkpoint.

  Args:
    data: Mapping from tensor name to the value to save under that name.
      Names and values are passed to `tf.raw_ops.Save` in the mapping's
      iteration order.

  Returns:
    The raw contents of the checkpoint file written by `tf.raw_ops.Save`.
  """
  names = list(data.keys())
  values = list(data.values())
  with tempfile.NamedTemporaryFile() as scratch:
    path = scratch.name
    with tf.compat.v1.Session() as sess:
      # The op is created inside the session's `with` block (as in the
      # original) so that it is built for, and executed by, `sess.run`.
      save_op = tf.raw_ops.Save(
          filename=path, tensor_names=names, data=values)
      sess.run(save_op)
    # TF wrote the checkpoint to `path` itself, not through `scratch`, so the
    # result is read back by reopening the path.
    with open(path, 'rb') as checkpoint_file:
      return checkpoint_file.read()
33*14675a02SAndroid Build Coastguard Worker
34*14675a02SAndroid Build Coastguard Worker
def read_tensor_from_checkpoint(checkpoint: bytes, tensor_name: str,
                                dt: tf.DType) -> Any:
  """Extracts a single named tensor from serialized checkpoint bytes.

  Args:
    checkpoint: Raw checkpoint file contents (for example, the output of
      `create_checkpoint`).
    tensor_name: Name of the tensor to restore.
    dt: Dtype of the stored tensor, forwarded to `tf.raw_ops.Restore`.

  Returns:
    Whatever `Session.run` produces for the restored tensor (presumably a
    NumPy value -- NOTE(review): confirm against callers).
  """
  with tempfile.NamedTemporaryFile('wb') as scratch:
    scratch.write(checkpoint)
    # TF opens the file by name, so the buffered bytes must be pushed out
    # before the Restore op runs.
    scratch.flush()
    with tf.compat.v1.Session() as sess:
      restore_op = tf.raw_ops.Restore(
          file_pattern=scratch.name, tensor_name=tensor_name, dt=dt)
      return sess.run(restore_op)
45