# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper methods for proto creation logic."""

from typing import Optional

import tensorflow as tf
import tensorflow_federated as tff

from fcp.artifact_building import tensor_utils
from fcp.artifact_building import type_checks
from fcp.protos import plan_pb2


def make_tensor_spec_from_tensor(
    t: tf.Tensor, shape_hint: Optional[tf.TensorShape] = None
) -> tf.TensorSpec:
  """Builds a `tf.TensorSpec` describing `t`, optionally using a shape hint.

  The spec always carries the dtype and name of `t`. Its shape is the shape of
  `t`, unless that shape is not fully defined and a `shape_hint` was supplied,
  in which case the hint is used instead.

  Args:
    t: The `tf.Tensor` to describe.
    shape_hint: An optional `tf.TensorShape` substituted for the shape of `t`
      when the latter is not fully defined. It is ignored when `t` already has
      a fully defined shape, and must be compatible with the shape of `t`
      otherwise. Note that the hint itself is not checked for being fully
      defined.

  Returns:
    A `tf.TensorSpec` with the selected shape and the dtype and name of `t`.

  Raises:
    NotImplementedError: If `t` is not recognized by `tf.is_tensor`.
    TypeError: If `shape_hint` is consulted and is not compatible with the
      shape of `t`.
  """
  if not tf.is_tensor(t):
    raise NotImplementedError(f'Cannot handle type {type(t)}: {t}')

  tensor_shape = tf.TensorShape(t.shape)

  # The hint only matters when the tensor's own shape has unknown dimensions.
  if tensor_shape.is_fully_defined() or shape_hint is None:
    return tf.TensorSpec(tensor_shape, t.dtype, name=t.name)

  if not tensor_shape.is_compatible_with(shape_hint):
    raise TypeError(
        'shape_hint is not compatible with tensor ('
        f'{shape_hint} vs {tensor_shape})'
    )
  return tf.TensorSpec(shape_hint, t.dtype, name=t.name)


def make_measurement(
    t: tf.Tensor, name: str, tff_type: tff.types.TensorType
) -> plan_pb2.Measurement:
  """Builds a `plan_pb2.Measurement` descriptor for a tensor.

  Args:
    t: The tensor whose name becomes the measurement's `read_op_name`.
    name: The name of the measurement (e.g. 'server/loss').
    tff_type: The `tff.types.TensorType` describing the measurement; it is
      serialized into the resulting proto.

  Returns:
    A `plan_pb2.Measurement` referring to `t`.

  Raises:
    ValueError: If the dtype of `tff_type` differs from the dtype of `t`, or
      if both shapes are fully defined and differ. Shapes are not compared
      when either one is only partially defined.
  """
  type_checks.check_type(tff_type, tff.types.TensorType)

  if tff_type.dtype != t.dtype:
    dtype_error = (
        f'`tff_type.dtype`: {tff_type.dtype} does not match '
        f"provided tensor's dtype: {t.dtype}."
    )
    raise ValueError(dtype_error)

  # Only compare shapes when neither side has unknown dimensions.
  shapes_are_static = (
      tff_type.shape.is_fully_defined() and t.shape.is_fully_defined()
  )
  if shapes_are_static and tff_type.shape.as_list() != t.shape.as_list():
    shape_error = (
        f'`tff_type.shape`: {tff_type.shape} does not match '
        f"provided tensor's shape: {t.shape}."
    )
    raise ValueError(shape_error)

  read_op_name = t.name
  serialized_type = tff.types.serialize_type(tff_type).SerializeToString()
  return plan_pb2.Measurement(
      read_op_name=read_op_name,
      name=name,
      tff_type=serialized_type,
  )


def make_metric(v: tf.Variable, stat_name_prefix: str) -> plan_pb2.Metric:
  """Builds a `plan_pb2.Metric` descriptor for a resource variable.

  The stat name is `stat_name_prefix`, a slash, and the part of the variable's
  bare name that follows its first `/`. The bare name is whatever
  `tensor_utils.bare_name` returns for `v.name` (presumably the name without
  its colon-based suffix; see that helper for the exact rules).

  Args:
    v: The variable to describe. It must expose a `read_value` method.
    stat_name_prefix: The prefix (string) for the stat name, without a
      trailing slash `/` (one is inserted automatically).

  Returns:
    A `plan_pb2.Metric` whose `variable_name` is the name of the tensor
    returned by `v.read_value()`.

  Raises:
    TypeError: If `stat_name_prefix` fails the string type check or `v` has no
      `read_value` attribute.
    ValueError: If the bare name of `v` contains no `/`, i.e. it has no leading
      name prefix to strip.
  """
  type_checks.check_type(stat_name_prefix, str, name='stat_name_prefix')
  if not hasattr(v, 'read_value'):
    raise TypeError(f'Expected a resource variable, found {type(v)!r}.')

  variable_bare_name = tensor_utils.bare_name(v.name)
  # Split at the first slash; an empty separator means no slash was present.
  _, separator, unprefixed_name = variable_bare_name.partition('/')
  if not separator:
    raise ValueError(
        f'Expected a prefix in the name, found none in {variable_bare_name}.'
    )

  stat_name = f'{stat_name_prefix}/{unprefixed_name}'
  return plan_pb2.Metric(variable_name=v.read_value().name, stat_name=stat_name)