xref: /aosp_15_r20/external/federated-compute/fcp/demo/checkpoint_tensor_reference.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""MaterializableValueReference that reads from a TensorFlow checkpoint."""
15
16from typing import Any, Optional
17import uuid
18
19import tensorflow as tf
20import tensorflow_federated as tff
21
22
class CheckpointTensorReference(tff.program.MaterializableValueReference):
  """Value reference backed by a single named tensor of a TF checkpoint.

  The checkpoint is not available at construction time; it is obtained by
  awaiting the supplied future the first time `get_value` is called. The
  restored tensor is then kept on the instance so later calls reuse it.
  """

  def __init__(self, tensor_name: str, dtype: tf.DType, shape: Any,
               checkpoint_future: tff.async_utils.SharedAwaitable):
    """Creates a reference to `tensor_name` inside a pending checkpoint.

    Args:
      tensor_name: Name under which the tensor is stored in the checkpoint.
      dtype: Element type of the tensor.
      shape: Tensor shape; any value that can be converted to a
        `tf.TensorShape`.
      checkpoint_future: A `tff.async_utils.SharedAwaitable` which, when
        awaited, yields the bytes of the TF checkpoint.
    """
    self._checkpoint_future = checkpoint_future
    self._tensor_name = tensor_name
    # Populated on the first `get_value` call.
    self._tensor: Optional[tf.Tensor] = None
    self._type_signature = tff.TensorType(dtype, shape)

  @property
  def type_signature(self) -> tff.Type:
    """The `tff.TensorType` built from the constructor's dtype and shape."""
    return self._type_signature

  def _load_tensor(self, checkpoint: bytes) -> None:
    """Restores the referenced tensor from `checkpoint` into `self._tensor`.

    Args:
      checkpoint: The serialized TF checkpoint contents.
    """
    # Stage the bytes in TensorFlow's RamFileSystem so that no disk I/O is
    # needed; the random UUID gives every restore its own path.
    ram_path = f'ram://{uuid.uuid4()}.ckpt'
    with tf.io.gfile.GFile(ram_path, 'wb') as ckpt_file:
      ckpt_file.write(checkpoint)
    try:
      restored = tf.raw_ops.RestoreV2(
          prefix=ram_path,
          tensor_names=[self._tensor_name],
          shape_and_slices=[''],
          dtypes=[self._type_signature.dtype])
      # Exactly one tensor name was requested, so take the first result.
      self._tensor = restored[0]
    finally:
      # The staged file is removed whether or not the restore succeeded.
      tf.io.gfile.remove(ram_path)

  async def get_value(self) -> tff.program.MaterializedValue:
    """Returns the referenced tensor's value via its `numpy()` method.

    The first call awaits the checkpoint future and restores the tensor;
    subsequent calls reuse the tensor stored on this instance.

    Returns:
      The result of calling `numpy()` on the restored tensor.

    Raises:
      ValueError: If the restored tensor has no `numpy` attribute, which is
        reported as not running in eager mode.
    """
    if self._tensor is None:
      self._load_tensor(await self._checkpoint_future)

    try:
      return self._tensor.numpy()
    except AttributeError as e:
      raise ValueError('get_value is only supported in eager mode.') from e
67