xref: /aosp_15_r20/external/federated-compute/fcp/artifact_building/tensor_utils.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Utilities specific to the manipulation of tensors and operators."""
15
16from typing import Any, Callable, Optional, Union
17
18import tensorflow as tf
19
20
21######################################################################
22# Helper functions for names and naming.
23#
def bare_name(v) -> str:
  """Returns the name of `v` without a leading '^' or a ':...' suffix.

  The name is obtained through `name_or_str`, so `v` may be any object with a
  string `name` attribute, or any object whose `str()` form is the name.

  Args:
    v: The object (or string) whose bare name is wanted.

  Returns:
    The name with a single leading '^' removed (if present) and truncated at
    the first colon (if any). An empty name yields an empty string.
  """
  name = name_or_str(v)
  # `startswith` instead of `name[0]`, so an empty name does not raise
  # IndexError.
  if name.startswith('^'):
    name = name[1:]
  # User specified names are everything up to the first colon. User supplied
  # names cannot contain colons, TensorFlow will raise an error on invalid name.
  # `partition` returns the whole string as its first item when no colon exists.
  return name.partition(':')[0]
36
37
def name_or_str(v) -> str:
  """Returns `v.name` if `v` has a `name` attribute, otherwise `str(v)`.

  Args:
    v: Any object.

  Returns:
    The value of `v.name` when that attribute exists (it is asserted to be a
    `str`), or the string conversion of `v` when it does not.
  """
  if not hasattr(v, 'name'):
    return str(v)
  result = v.name
  assert isinstance(result, str)
  return result
45
46
47######################################################################
48# Helper function for graphs.
49#
50
51
def import_graph_def_from_any(an) -> tf.compat.v1.GraphDef:
  """Unpacks the tf.compat.v1.GraphDef carried by an 'Any' message.

  Args:
    an: An 'Any' message wrapping a serialized tf.compat.v1.GraphDef. It must be
      truthy and `an.Is(...)` must report that its payload is a GraphDef;
      GraphDef is the only payload type accepted here.

  Returns:
    A new tf.compat.v1.GraphDef populated from the payload of `an`.

  Raises:
    AssertionError: If `an` is falsy or does not hold a GraphDef.
  """
  assert an
  graph_def_type = tf.compat.v1.GraphDef
  # A TensorFlow GraphDef is the only kind of graph that is supported.
  assert an.Is(graph_def_type.DESCRIPTOR)
  graph_def = graph_def_type()
  an.Unpack(graph_def)
  return graph_def
69
70
71######################################################################
72# Helper functions for savers or saverdefs.
73#
74
75
def save(
    filename: Union[tf.Tensor, str],
    tensor_names: list[str],
    tensors: list[tf.Tensor],
    tensor_slices: Optional[list[str]] = None,
    name: str = 'save',
    save_op: Callable[..., Any] = tf.raw_ops.SaveSlices,
) -> tf.Operation:
  """Writes a list of tensors to a file through `save_op`.

  A `shapes_and_slices` value is always handed to `save_op`, so that the
  `SaveSlices` op can be used (rather than a plain `Save` op).

  Args:
    filename: The path of the file to write, either a string or a scalar string
      tensor.
    tensor_names: A list of strings, passed to `save_op` as `tensor_names`.
    tensors: The list of tensors to save.
    tensor_slices: An optional list of strings describing the shape and slice
      of a larger virtual tensor that each tensor belongs to. When omitted (or
      empty), one empty string per tensor is used, which saves each tensor as
      a full slice.
    name: An optional name for the op.
    save_op: The callable that creates the op(s) performing the save. Defaults
      to `tf.raw_ops.SaveSlices`.

  Returns:
    The result of `save_op`: a `SaveSlices` op in graph mode or None in eager
    mode.
  """
  if not tensor_slices:
    # No slice specs were given: use one empty spec per tensor.
    tensor_slices = [''] * len(tensors)
  return save_op(
      filename=filename,
      tensor_names=tensor_names,
      shapes_and_slices=tensor_slices,
      data=tensors,
      name=name,
  )
112
113
def restore(
    filename: Union[tf.Tensor, str],
    tensor_name: str,
    tensor_type: tf.DType,
    tensor_shape: Optional[tf.TensorShape] = None,
    name: str = 'restore',
) -> tf.Tensor:
  """Restores a single tensor from a file.

  It is a wrapper of `tf.raw_ops.RestoreV2`. When used in graph mode, it adds a
  `RestoreV2` op to the graph.

  Args:
    filename: A string or a scalar tensor of dtype string that specifies the
      path to file.
    tensor_name: The name of the tensor to restore.
    tensor_type: The type of the tensor to restore.
    tensor_shape: Optional. The shape of the tensor to restore. It is encoded
      into the op's `shape_and_slices` spec only when it is fully defined and
      has rank greater than zero. For `None`, a scalar shape, or a shape with
      unknown rank or unknown dimensions, an empty spec is passed instead.
    name: An optional name for the op.

  Returns:
    A tensor of dtype `tensor_type`.
  """
  shape_and_slice = ''
  # A shape with unknown rank has `rank` equal to None, and unknown dimensions
  # cannot be formatted as integers, so only fully defined shapes are encoded.
  # Anything else takes the same empty-spec path as `tensor_shape=None`.
  if (
      tensor_shape is not None
      and tensor_shape.is_fully_defined()
      and tensor_shape.rank > 0
  ):
    shape_str = ' '.join(str(d) for d in tensor_shape.as_list()) + ' '
    # Ideally we want to pass an empty string to slice, but this is not allowed
    # because the size of the slice string list (after the string is split by
    # separator ':') needs to match the rank of the tensor (see b/197779415 for
    # more information).
    slice_str = ':-' * tensor_shape.rank
    shape_and_slice = shape_str + slice_str
  restored_tensors = tf.raw_ops.RestoreV2(
      prefix=filename,
      tensor_names=[tensor_name],
      shape_and_slices=[shape_and_slice],
      dtypes=[tensor_type],
      name=name,
  )
  # RestoreV2 returns one tensor per requested name; exactly one was requested.
  return restored_tensors[0]
153  return restored_tensors[0]
154