# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MaterializableValueReference that reads from a TensorFlow checkpoint."""

from typing import Any, Optional
import uuid

import tensorflow as tf
import tensorflow_federated as tff


class CheckpointTensorReference(tff.program.MaterializableValueReference):
  """A reference to a tensor in a TF checkpoint file.

  The tensor is restored lazily on the first call to `get_value` and cached on
  the instance, so subsequent calls do not re-read the checkpoint.
  """

  def __init__(self, tensor_name: str, dtype: tf.DType, shape: Any,
               checkpoint_future: tff.async_utils.SharedAwaitable):
    """Constructs a new CheckpointTensorReference object.

    Args:
      tensor_name: The name of the tensor in the TF checkpoint.
      dtype: The type of the tensor.
      shape: The shape of the tensor, expressed as a value convertible to
        `tf.TensorShape`.
      checkpoint_future: A `tff.async_utils.SharedAwaitable` that resolves to
        the TF checkpoint bytes once they're available.
    """
    self._tensor_name = tensor_name
    self._type_signature = tff.TensorType(dtype, shape)
    self._checkpoint_future = checkpoint_future
    # Cache of the restored tensor; populated on the first `get_value` call.
    self._tensor: Optional[tf.Tensor] = None

  @property
  def type_signature(self) -> tff.Type:
    """The `tff.TensorType` built from the constructor's `dtype` and `shape`."""
    return self._type_signature

  async def get_value(self) -> tff.program.MaterializedValue:
    """Returns the referenced tensor's value.

    On the first call this awaits `checkpoint_future` and restores the tensor
    from the resulting checkpoint bytes; the restored tensor is cached.

    Returns:
      The result of calling `.numpy()` on the restored tensor.

    Raises:
      ValueError: If the restored tensor has no `.numpy()` method, which the
        original implementation attributes to not running in eager mode.
    """
    if self._tensor is None:
      checkpoint = await self._checkpoint_future
      self._tensor = self._restore_tensor(checkpoint)

    try:
      return self._tensor.numpy()
    except AttributeError as e:
      raise ValueError('get_value is only supported in eager mode.') from e

  def _restore_tensor(self, checkpoint: bytes) -> tf.Tensor:
    """Restores this reference's tensor from serialized checkpoint bytes.

    Args:
      checkpoint: The checkpoint bytes produced by `checkpoint_future`.

    Returns:
      The tensor named `tensor_name`, restored with the reference's dtype.
    """
    # Write to a file in TensorFlow's RamFileSystem to avoid disk I/O.
    tmpfile = f'ram://{uuid.uuid4()}.ckpt'
    try:
      # The write lives inside the `try` so that a failed or partial write
      # still gets cleaned up instead of leaking a file in the RAM filesystem.
      with tf.io.gfile.GFile(tmpfile, 'wb') as f:
        f.write(checkpoint)
      return tf.raw_ops.RestoreV2(
          prefix=tmpfile,
          tensor_names=[self._tensor_name],
          # An empty slice spec restores the full (non-partitioned) tensor.
          shape_and_slices=[''],
          dtypes=[self._type_signature.dtype])[0]
    finally:
      # Guard the removal so that, if the file was never created, the
      # original error is not masked by a NotFoundError raised here.
      if tf.io.gfile.exists(tmpfile):
        tf.io.gfile.remove(tmpfile)