xref: /aosp_15_r20/external/federated-compute/fcp/artifact_building/proto_helpers.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Helper methods for proto creation logic."""
15
16from typing import Optional
17
18import tensorflow as tf
19import tensorflow_federated as tff
20
21from fcp.artifact_building import tensor_utils
22from fcp.artifact_building import type_checks
23from fcp.protos import plan_pb2
24
25
def make_tensor_spec_from_tensor(
    t: tf.Tensor, shape_hint: Optional[tf.TensorShape] = None
) -> tf.TensorSpec:
  """Builds a `tf.TensorSpec` describing `t`, optionally refined by a hint.

  The spec's shape is the static shape of `t`, unless that shape is not fully
  defined and a `shape_hint` is supplied. In that case the hint is used as-is,
  after checking that it is compatible with the partial shape of `t`. When
  `t` already has a fully defined shape, `shape_hint` is ignored entirely. It
  is not validated in that case, even if it disagrees.

  The spec's dtype is `t.dtype` and its name is `t.name`. Reading `t.name`
  presumably requires a graph (non-eager) tensor; confirm against callers.

  Args:
    t: The tensor to describe.
    shape_hint: Optional shape used to complete a partially defined static
      shape of `t`. It is expected, but not checked, to be fully defined.

  Returns:
    A `tf.TensorSpec` with the chosen shape, `t.dtype`, and `t.name`.

  Raises:
    NotImplementedError: If `t` is not a tensor according to `tf.is_tensor`.
    TypeError: If `t` has a partially defined shape and `shape_hint` is given
      but is not compatible with it.
  """
  if not tf.is_tensor(t):
    raise NotImplementedError(
        'Cannot handle type {t}: {v}'.format(t=type(t), v=t)
    )

  spec_shape = tf.TensorShape(t.shape)

  # The hint is consulted only when the tensor's own static shape is
  # incomplete; a fully defined shape always wins.
  use_hint = shape_hint is not None and not spec_shape.is_fully_defined()
  if use_hint:
    if not spec_shape.is_compatible_with(shape_hint):
      raise TypeError(
          'shape_hint is not compatible with tensor ('
          f'{shape_hint} vs {spec_shape})'
      )
    spec_shape = shape_hint

  return tf.TensorSpec(spec_shape, t.dtype, name=t.name)
62
63
def make_measurement(
    t: tf.Tensor, name: str, tff_type: tff.types.TensorType
) -> plan_pb2.Measurement:
  """Creates a `plan_pb.Measurement` descriptor for a tensor.

  The descriptor records the tensor's name as `read_op_name`, the given
  measurement `name`, and the serialized bytes of `tff_type`.

  Args:
    t: A tensor to create the measurement for.
    name: The name of the measurement (e.g. 'server/loss').
    tff_type: The `tff.Type` of the measurement. It must be a
      `tff.types.TensorType`.

  Returns:
    An instance of `plan_pb.Measurement`.

  Raises:
    ValueError: If the `dtype` of `tff_type` differs from `t.dtype`, or if
      the two shapes are incompatible. Shapes are incompatible when their
      ranks are both known and differ, or when some dimension is known in
      both and the values differ. Unknown dimensions or an unknown rank are
      treated as matching anything.
    Whatever `type_checks.check_type` raises (presumably `TypeError`) if
      `tff_type` is not a `tff.types.TensorType`.
  """
  type_checks.check_type(tff_type, tff.types.TensorType)
  if tff_type.dtype != t.dtype:
    raise ValueError(
        f'`tff_type.dtype`: {tff_type.dtype} does not match '
        f"provided tensor's dtype: {t.dtype}."
    )
  # Use compatibility rather than only comparing fully defined shapes.
  # Previously a partially known shape (e.g. [None, 4] vs [2, 3]) skipped the
  # check entirely, so real conflicts went unreported. For two fully defined
  # shapes, compatibility is exactly equality, so that case is unchanged.
  # Wrapping in `tf.TensorShape` also accepts a plain tuple/list shape.
  if not tf.TensorShape(tff_type.shape).is_compatible_with(t.shape):
    raise ValueError(
        f'`tff_type.shape`: {tff_type.shape} does not match '
        f"provided tensor's shape: {t.shape}."
    )
  return plan_pb2.Measurement(
      read_op_name=t.name,
      name=name,
      tff_type=tff.types.serialize_type(tff_type).SerializeToString(),
  )
98
99
def make_metric(v: tf.Variable, stat_name_prefix: str) -> plan_pb2.Metric:
  """Builds a `plan_pb.Metric` descriptor for a variable.

  The stat name is formed as follows:

  1. Take the variable's name as normalized by `tensor_utils.bare_name`.
  2. Drop everything up to and including the first '/'.
  3. Prepend `stat_name_prefix` followed by '/'.

  For example, prefix 'client' and bare name 'a/b/c' give 'client/b/c'.

  Args:
    v: The variable to describe. Any object with a `read_value` attribute is
      accepted, plus a `name` attribute.
    stat_name_prefix: Leading component of the stat name, without the
      trailing '/' (the separator is inserted here).

  Returns:
    A `plan_pb.Metric` whose `variable_name` is the name of the tensor
    returned by `v.read_value()`, and whose `stat_name` is built as above.
    Calling `read_value()` presumably adds a read op when building a graph;
    confirm if that matters to the caller.

  Raises:
    TypeError: If `v` has no `read_value` attribute. Also possibly raised by
      `type_checks.check_type` when `stat_name_prefix` is not a `str`
      (assumed).
    ValueError: If the bare variable name contains no '/'.
  """
  type_checks.check_type(stat_name_prefix, str, name='stat_name_prefix')
  if not hasattr(v, 'read_value'):
    raise TypeError('Expected a resource variable, found {!r}.'.format(type(v)))

  var_name = tensor_utils.bare_name(v.name)
  _, separator, remainder = var_name.partition('/')
  if not separator:
    raise ValueError(
        'Expected a prefix in the name, found none in {}.'.format(var_name)
    )

  stat_name = f'{stat_name_prefix}/{remainder}'
  read_tensor_name = v.read_value().name
  return plan_pb2.Metric(variable_name=read_tensor_name, stat_name=stat_name)
130