# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities specific to the manipulation of tensors and operators."""

from typing import Any, Callable, Optional, Union

import tensorflow as tf


######################################################################
# Helper functions for names and naming.
#
def bare_name(v) -> str:
  """Returns the user-specified part of the name of `v`.

  The name is obtained via `name_or_str`. A single leading '^' (the marker
  used for control inputs in GraphDef node inputs) is removed, and everything
  from the first colon onward (e.g. an output suffix such as ':0') is dropped.

  Args:
    v: An object with a string `name` attribute (e.g. a graph tensor or op),
      or any object whose `str()` is a name.

  Returns:
    The bare name, e.g. 'foo' for both 'foo:0' and '^foo'. An empty name
    yields ''.
  """
  name = name_or_str(v)
  # `startswith` (instead of indexing `name[0]`) keeps an empty name from
  # raising IndexError.
  if name.startswith('^'):
    name = name[1:]
  # User specified names are everything up to the first colon. User supplied
  # names cannot contain colons, TensorFlow will raise an error on invalid name.
  return name.partition(':')[0]


def name_or_str(v) -> str:
  """Returns the name of `v`, or `str(v)` if `v` has no `name` attribute.

  Args:
    v: Any object; typically a graph tensor or operation.

  Returns:
    `v.name` when the attribute exists, otherwise `str(v)`.

  Raises:
    AssertionError: If `v.name` exists but is not a `str` (only when
      assertions are enabled).
  """
  if hasattr(v, 'name'):
    name = v.name
    assert isinstance(name, str)
    return name
  return str(v)


######################################################################
# Helper function for graphs.
#


def import_graph_def_from_any(an) -> tf.compat.v1.GraphDef:
  """Parses a tf.compat.v1.GraphDef from an Any message.

  Args:
    an: An 'Any' message, which contains a serialized tf.compat.v1.GraphDef. The
      type_url field of the Any message must identify a supported type;
      currently, the only supported type is
      'type.googleapis.com/tensorflow.GraphDef'.

  Returns:
    A tf.compat.v1.GraphDef object.

  Raises:
    AssertionError: If `an` is None/falsy or does not hold a GraphDef (only
      when assertions are enabled).
    ValueError: If `an` cannot be unpacked into a GraphDef. This still guards
      against a wrong payload when the assertions are stripped (`python -O`).
  """
  assert an
  # The only kind of supported graph is a TensorFlow GraphDef.
  assert an.Is(tf.compat.v1.GraphDef.DESCRIPTOR)
  g = tf.compat.v1.GraphDef()
  # `Unpack` returns False (leaving `g` empty) on a type mismatch; never hand
  # back a silently empty graph.
  if not an.Unpack(g):
    raise ValueError(
        'Expected an Any message holding a tensorflow.GraphDef, got type_url '
        f'{an.type_url!r}.'
    )
  return g


######################################################################
# Helper functions for savers or saverdefs.
#


def save(
    filename: Union[tf.Tensor, str],
    tensor_names: list[str],
    tensors: list[tf.Tensor],
    tensor_slices: Optional[list[str]] = None,
    name: str = 'save',
    save_op: Callable[..., Any] = tf.raw_ops.SaveSlices,
) -> Optional[tf.Operation]:
  """Saves a list of tensors to file.

  This function always passes a value for the `tensor_slices` argument in order
  to use the `SaveSlices` op (instead of a `Save` op).

  Args:
    filename: A string or a scalar tensor of dtype string that specifies the
      path to file.
    tensor_names: A list of strings naming the entries of `tensors`.
    tensors: A list of tensors to be saved.
    tensor_slices: An optional list of strings, that specifies the shape and
      slices of a larger virtual tensor that each tensor is a part of. If not
      specified (or empty), each tensor is saved as a full slice.
    name: An optional name for the op.
    save_op: A callable that creates the op(s) to use for performing the tensor
      save. Defaults to `tf.raw_ops.SaveSlices`. It is called with the keyword
      arguments `filename`, `tensor_names`, `shapes_and_slices`, `data` and
      `name`.

  Returns:
    Whatever `save_op` returns: with the default, a `SaveSlices` op in graph
    mode or None in eager mode.
  """
  # An empty slice spec ('') means "save the whole tensor".
  tensor_slices = tensor_slices if tensor_slices else ([''] * len(tensors))
  return save_op(
      filename=filename,
      tensor_names=tensor_names,
      shapes_and_slices=tensor_slices,
      data=tensors,
      name=name,
  )


def restore(
    filename: Union[tf.Tensor, str],
    tensor_name: str,
    tensor_type: tf.DType,
    tensor_shape: Optional[tf.TensorShape] = None,
    name: str = 'restore',
) -> tf.Tensor:
  """Restores a tensor from the file.

  It is a wrapper of `tf.raw_ops.RestoreV2`. When used in graph mode, it adds a
  `RestoreV2` op to the graph.

  Args:
    filename: A string or a scalar tensor of dtype string that specifies the
      path to file.
    tensor_name: The name of the tensor to restore.
    tensor_type: The type of the tensor to restore.
    tensor_shape: Optional. The shape of the tensor to restore, as a
      `tf.TensorShape` or anything `tf.TensorShape` accepts (e.g. a list of
      ints). A shape spec is only sent to `RestoreV2` when the shape is fully
      defined and has rank > 0. Otherwise (None, scalar, unknown rank or
      partially known dims) the full tensor is restored without a spec.
    name: An optional name for the op.

  Returns:
    A tensor of dtype `tensor_type`.
  """
  shape_and_slice = ''
  if tensor_shape is not None:
    tensor_shape = tf.TensorShape(tensor_shape)
    # The spec needs a concrete size for every dimension. Unknown rank or
    # unknown dims previously crashed (`None > 0` / `'%d' % None`); treat them
    # like "no shape" instead.
    if tensor_shape.is_fully_defined() and tensor_shape.rank > 0:
      dims = tensor_shape.as_list()
      shape_str = ' '.join('%d' % d for d in dims) + ' '
      # Ideally we want to pass an empty string to slice, but this is not
      # allowed because the size of the slice string list (after the string is
      # split by separator ':') needs to match the rank of the tensor (see
      # b/197779415 for more information).
      slice_str = ':-' * len(dims)
      shape_and_slice = shape_str + slice_str
  restored_tensors = tf.raw_ops.RestoreV2(
      prefix=filename,
      tensor_names=[tensor_name],
      shape_and_slices=[shape_and_slice],
      dtypes=[tensor_type],
      name=name,
  )
  return restored_tensors[0]