xref: /aosp_15_r20/external/tensorflow/tensorflow/python/saved_model/nested_structure_coder.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Module that encodes (decodes) nested structures into (from) protos.
16
17The intended use is to serialize everything needed to restore a `Function` that
18was saved into a SavedModel. This may include concrete function inputs and
19outputs, signatures, function specs, etc.
20
21Example use:
22# Encode into proto.
23signature_proto = nested_structure_coder.encode_structure(
24    function.input_signature)
25# Decode into a Python object.
26restored_signature = nested_structure_coder.decode_proto(signature_proto)
27"""
28
29import collections
30import functools
31import warnings
32
33from tensorflow.core.protobuf import struct_pb2
34from tensorflow.python.data.ops import dataset_ops
35from tensorflow.python.data.ops import iterator_ops
36from tensorflow.python.data.ops import optional_ops
37from tensorflow.python.distribute import values
38from tensorflow.python.framework import dtypes
39from tensorflow.python.framework import extension_type
40from tensorflow.python.framework import indexed_slices
41from tensorflow.python.framework import sparse_tensor
42from tensorflow.python.framework import tensor_shape
43from tensorflow.python.framework import tensor_spec
44from tensorflow.python.framework import tensor_util
45from tensorflow.python.framework import type_spec
46from tensorflow.python.ops import resource_variable_ops
47from tensorflow.python.ops import tensor_array_ops
48from tensorflow.python.ops.ragged import ragged_tensor
49from tensorflow.python.ops.ragged import row_partition
50from tensorflow.python.util import compat
51from tensorflow.python.util import nest
52from tensorflow.python.util.compat import collections_abc
53from tensorflow.python.util.tf_export import tf_export
54
55
class NotEncodableError(Exception):
  """Raised when no registered codec accepts a value (encoding or decoding)."""
58
59
def register_codec(x):
  """Adds a codec to the end of the module-level codec registry.

  Codecs are consulted in registry order, so `x` is tried only after every
  codec already registered.

  Args:
    x: Codec object exposing `can_encode`, `do_encode`, `can_decode` and
      `do_decode`. The `_*Codec` classes in this module are examples.
  """
  registry = _codecs
  registry.append(x)
69
70
def _get_encoders():
  """Returns `(can_encode, do_encode)` pairs for all registered codecs, in order."""
  return [(codec.can_encode, codec.do_encode) for codec in _codecs]
73
74
def _get_decoders():
  """Returns `(can_decode, do_decode)` pairs for all registered codecs, in order."""
  return [(codec.can_decode, codec.do_decode) for codec in _codecs]
77
78
def _map_structure(pyobj, coders):
  """Transforms `pyobj` with the first coder whose predicate accepts it.

  Args:
    pyobj: The value to transform (a Python object when encoding, a proto
      when decoding).
    coders: Sequence of `(predicate, transform)` pairs, tried in order.

  Returns:
    The result of `transform(pyobj, recurse)`, where `recurse` applies this
    same function (with the same `coders`) to nested values.

  Raises:
    NotEncodableError: If no predicate in `coders` accepts `pyobj`.
  """
  # The callback only depends on `coders`, so build it once up front.
  recurse = functools.partial(_map_structure, coders=coders)
  for predicate, transform in coders:
    if predicate(pyobj):
      return transform(pyobj, recurse)
  raise NotEncodableError(
      f"No encoder for object {str(pyobj)} of type {type(pyobj)}.")
86
87
@tf_export("__internal__.saved_model.encode_structure", v1=[])
def encode_structure(nested_structure):
  """Converts a nested structure into a `StructuredValue` proto.

  Args:
    nested_structure: A value built only from types that have a registered
      codec (lists, tuples, dicts, namedtuples, scalars, shapes, dtypes,
      specs, ...).

  Returns:
    The encoded proto.

  Raises:
    NotEncodableError: If some value in the structure has no encoder.
  """
  encoders = _get_encoders()
  return _map_structure(nested_structure, encoders)
102
103
def can_encode(nested_structure):
  """Reports whether `encode_structure` succeeds on `nested_structure`.

  This runs a full encode and discards the result.

  Args:
    nested_structure: Structure to test.

  Returns:
    False if `encode_structure` raises `NotEncodableError`, True otherwise.
    Any other exception propagates to the caller.
  """
  try:
    encode_structure(nested_structure)
  except NotEncodableError:
    return False
  else:
    return True
118
119
@tf_export("__internal__.saved_model.decode_proto", v1=[])
def decode_proto(proto):
  """Rebuilds the nested Python structure described by `proto`.

  Args:
    proto: A `StructuredValue` proto, typically produced by
      `encode_structure`.

  Returns:
    The decoded structure.

  Raises:
    NotEncodableError: If some value in the proto has no decoder.
  """
  decoders = _get_decoders()
  return _map_structure(proto, decoders)
134
135
class _ListCodec:
  """Codec mapping Python lists to and from the `list_value` field."""

  def can_encode(self, pyobj):
    return isinstance(pyobj, list)

  def do_encode(self, list_value, encode_fn):
    encoded_list = struct_pb2.StructuredValue()
    # Copying in an empty ListValue is intended to mark `list_value` as set,
    # so that an empty Python list is still recognized by `can_decode`.
    encoded_list.list_value.CopyFrom(struct_pb2.ListValue())
    encoded_items = encoded_list.list_value.values
    for item in list_value:
      encoded_items.add().CopyFrom(encode_fn(item))
    return encoded_list

  def can_decode(self, value):
    return value.HasField("list_value")

  def do_decode(self, value, decode_fn):
    return list(map(decode_fn, value.list_value.values))
154
155
def _is_tuple(obj):
  """Returns True iff `obj` is a tuple that is not namedtuple-like."""
  return isinstance(obj, tuple) and not _is_named_tuple(obj)
158
159
def _is_named_tuple(instance):
  """Heuristically checks whether `instance` is a `namedtuple`.

  Any tuple whose `_fields` attribute is a sequence of strings qualifies. The
  class does not need to come from `collections.namedtuple`.

  Args:
    instance: Any Python object.

  Returns:
    True if `instance` looks like a `namedtuple`, False otherwise.
  """
  if not isinstance(instance, tuple):
    return False
  if not hasattr(instance, "_fields"):
    return False
  if not isinstance(instance._fields, collections_abc.Sequence):
    return False
  return all(isinstance(field_name, str) for field_name in instance._fields)
174
175
class _TupleCodec:
  """Codec mapping plain (non-named) tuples to and from `tuple_value`."""

  def can_encode(self, pyobj):
    # `_is_tuple` rejects namedtuples, which _NamedTupleCodec handles instead.
    return _is_tuple(pyobj)

  def do_encode(self, tuple_value, encode_fn):
    encoded_tuple = struct_pb2.StructuredValue()
    # Intended to mark `tuple_value` as set even when the tuple is empty.
    encoded_tuple.tuple_value.CopyFrom(struct_pb2.TupleValue())
    encoded_items = encoded_tuple.tuple_value.values
    for item in tuple_value:
      encoded_items.add().CopyFrom(encode_fn(item))
    return encoded_tuple

  def can_decode(self, value):
    return value.HasField("tuple_value")

  def do_decode(self, value, decode_fn):
    return tuple(map(decode_fn, value.tuple_value.values))
194
195
class _DictCodec:
  """Codec mapping Python dicts to and from the `dict_value` field."""

  def can_encode(self, pyobj):
    return isinstance(pyobj, dict)

  def do_encode(self, dict_value, encode_fn):
    encoded_dict = struct_pb2.StructuredValue()
    # Intended to mark `dict_value` as set even when the dict is empty.
    encoded_dict.dict_value.CopyFrom(struct_pb2.DictValue())
    # NOTE(review): keys are used directly as proto map keys; presumably they
    # must be strings — confirm against struct.proto.
    encoded_fields = encoded_dict.dict_value.fields
    for key, item in dict_value.items():
      encoded_fields[key].CopyFrom(encode_fn(item))
    return encoded_dict

  def can_decode(self, value):
    return value.HasField("dict_value")

  def do_decode(self, value, decode_fn):
    decoded = {}
    for key, encoded_item in value.dict_value.fields.items():
      decoded[key] = decode_fn(encoded_item)
    return decoded
214
215
class _NamedTupleCodec:
  """Codec for namedtuples.

  Encoding and decoding a namedtuple reconstructs a namedtuple with a different
  actual Python type, but with the same `typename` and `fields`.
  """

  def can_encode(self, pyobj):
    return _is_named_tuple(pyobj)

  def do_encode(self, named_tuple_value, encode_fn):
    """Encodes the type name plus an ordered list of (field, value) pairs."""
    encoded_named_tuple = struct_pb2.StructuredValue()
    encoded_named_tuple.named_tuple_value.CopyFrom(struct_pb2.NamedTupleValue())
    encoded_named_tuple.named_tuple_value.name = (
        named_tuple_value.__class__.__name__)
    field_names = named_tuple_value._fields
    # Build the field -> value mapping once. The previous per-field
    # `_asdict()` call made encoding quadratic in the number of fields. It is
    # still skipped for field-less values, as before.
    field_values = named_tuple_value._asdict() if field_names else {}
    for key in field_names:
      pair = encoded_named_tuple.named_tuple_value.values.add()
      pair.key = key
      pair.value.CopyFrom(encode_fn(field_values[key]))
    return encoded_named_tuple

  def can_decode(self, value):
    return value.HasField("named_tuple_value")

  def do_decode(self, value, decode_fn):
    """Rebuilds a fresh namedtuple type with the recorded name and fields."""
    key_value_pairs = value.named_tuple_value.values
    items = [(pair.key, decode_fn(pair.value)) for pair in key_value_pairs]
    named_tuple_type = collections.namedtuple(value.named_tuple_value.name,
                                              [item[0] for item in items])
    return named_tuple_type(**dict(items))
246
247
class _Float64Codec:
  """Codec for Python floats, stored in the `float64_value` field."""

  def can_encode(self, pyobj):
    return isinstance(pyobj, float)

  def do_encode(self, float64_value, encode_fn):
    del encode_fn  # Floats are leaves; nothing to recurse into.
    encoded = struct_pb2.StructuredValue()
    encoded.float64_value = float64_value
    return encoded

  def can_decode(self, value):
    return value.HasField("float64_value")

  def do_decode(self, value, decode_fn):
    del decode_fn  # Unused for leaf values.
    return value.float64_value
266
267
class _Int64Codec:
  """Codec for Python ints stored in `int64_value` (so limited to 64 bits)."""

  def can_encode(self, pyobj):
    # bool subclasses int; booleans go to _BoolCodec instead.
    return isinstance(pyobj, int) and not isinstance(pyobj, bool)

  def do_encode(self, int_value, encode_fn):
    del encode_fn  # Ints are leaves; nothing to recurse into.
    encoded = struct_pb2.StructuredValue()
    encoded.int64_value = int_value
    return encoded

  def can_decode(self, value):
    return value.HasField("int64_value")

  def do_decode(self, value, decode_fn):
    del decode_fn  # Unused for leaf values.
    return int(value.int64_value)
286
287
class _StringCodec:
  """Codec for Python `str` values, stored in the `string_value` field.

  See `StructuredValue.string_value` in proto/struct.proto for details on the
  representation.
  """

  def can_encode(self, pyobj):
    return isinstance(pyobj, str)

  def do_encode(self, string_value, encode_fn):
    del encode_fn  # Strings are leaves; nothing to recurse into.
    encoded = struct_pb2.StructuredValue()
    encoded.string_value = string_value
    return encoded

  def can_decode(self, value):
    return value.HasField("string_value")

  def do_decode(self, value, decode_fn):
    del decode_fn  # Unused for leaf values.
    # Normalize whatever the proto returns to a native `str`.
    return compat.as_str(value.string_value)
310
311
class _NoneCodec:
  """Codec for `None`, represented by an empty `none_value` message."""

  def can_encode(self, pyobj):
    return pyobj is None

  def do_encode(self, none_value, encode_fn):
    del none_value, encode_fn  # Nothing to encode beyond the tag itself.
    encoded = struct_pb2.StructuredValue()
    encoded.none_value.CopyFrom(struct_pb2.NoneValue())
    return encoded

  def can_decode(self, value):
    return value.HasField("none_value")

  def do_decode(self, value, decode_fn):
    del value, decode_fn  # The field's presence is the whole payload.
    return None
330
331
class _BoolCodec:
  """Codec for Python booleans, stored in the `bool_value` field."""

  def can_encode(self, pyobj):
    return isinstance(pyobj, bool)

  def do_encode(self, bool_value, encode_fn):
    del encode_fn  # Booleans are leaves; nothing to recurse into.
    encoded = struct_pb2.StructuredValue()
    encoded.bool_value = bool_value
    return encoded

  def can_decode(self, value):
    return value.HasField("bool_value")

  def do_decode(self, value, decode_fn):
    del decode_fn  # Unused for leaf values.
    return value.bool_value
350
351
class _TensorShapeCodec:
  """Codec for `TensorShape`, stored via `as_proto()` in `tensor_shape_value`."""

  def can_encode(self, pyobj):
    return isinstance(pyobj, tensor_shape.TensorShape)

  def do_encode(self, tensor_shape_value, encode_fn):
    del encode_fn  # Shapes are leaves; nothing to recurse into.
    shape_proto = tensor_shape_value.as_proto()
    encoded = struct_pb2.StructuredValue()
    encoded.tensor_shape_value.CopyFrom(shape_proto)
    return encoded

  def can_decode(self, value):
    return value.HasField("tensor_shape_value")

  def do_decode(self, value, decode_fn):
    del decode_fn  # Unused for leaf values.
    return tensor_shape.TensorShape(value.tensor_shape_value)
371
372
class _TensorTypeCodec:
  """Codec for `DType`, stored as its datatype enum in `tensor_dtype_value`."""

  def can_encode(self, pyobj):
    return isinstance(pyobj, dtypes.DType)

  def do_encode(self, tensor_dtype_value, encode_fn):
    del encode_fn  # Dtypes are leaves; nothing to recurse into.
    encoded = struct_pb2.StructuredValue()
    encoded.tensor_dtype_value = tensor_dtype_value.as_datatype_enum
    return encoded

  def can_decode(self, value):
    return value.HasField("tensor_dtype_value")

  def do_decode(self, value, decode_fn):
    del decode_fn  # Unused for leaf values.
    return dtypes.DType(value.tensor_dtype_value)
391
392
class _TensorSpecCodec:
  """Codec for `TensorSpec` (but not `BoundedTensorSpec`)."""

  def can_encode(self, pyobj):
    # BoundedTensorSpec has its own codec (_BoundedTensorSpecCodec).
    if isinstance(pyobj, tensor_spec.BoundedTensorSpec):
      return False
    return isinstance(pyobj, tensor_spec.TensorSpec)

  def do_encode(self, tensor_spec_value, encode_fn):
    # Shape and dtype go through the shape/dtype codecs via `encode_fn`.
    shape_proto = encode_fn(tensor_spec_value.shape).tensor_shape_value
    dtype_enum = encode_fn(tensor_spec_value.dtype).tensor_dtype_value
    spec_proto = struct_pb2.TensorSpecProto(
        shape=shape_proto, dtype=dtype_enum, name=tensor_spec_value.name)
    encoded_tensor_spec = struct_pb2.StructuredValue()
    encoded_tensor_spec.tensor_spec_value.CopyFrom(spec_proto)
    return encoded_tensor_spec

  def can_decode(self, value):
    return value.HasField("tensor_spec_value")

  def do_decode(self, value, decode_fn):
    spec_proto = value.tensor_spec_value
    shape = decode_fn(
        struct_pb2.StructuredValue(tensor_shape_value=spec_proto.shape))
    dtype = decode_fn(
        struct_pb2.StructuredValue(tensor_dtype_value=spec_proto.dtype))
    # An empty name in the proto means the spec had no name.
    return tensor_spec.TensorSpec(
        shape=shape, dtype=dtype, name=spec_proto.name or None)
423
424
class _BoundedTensorSpecCodec:
  """Codec for `BoundedTensorSpec`, including its minimum/maximum bounds."""

  def can_encode(self, pyobj):
    return isinstance(pyobj, tensor_spec.BoundedTensorSpec)

  def do_encode(self, bounded_tensor_spec_value, encode_fn):
    """Returns an encoded proto for the given `tf.BoundedTensorSpec`."""
    spec = bounded_tensor_spec_value
    # Bounds are stored as TensorProtos built by tensor_util.make_tensor_proto.
    spec_proto = struct_pb2.BoundedTensorSpecProto(
        shape=encode_fn(spec.shape).tensor_shape_value,
        dtype=encode_fn(spec.dtype).tensor_dtype_value,
        name=spec.name,
        minimum=tensor_util.make_tensor_proto(spec.minimum),
        maximum=tensor_util.make_tensor_proto(spec.maximum))
    encoded = struct_pb2.StructuredValue()
    encoded.bounded_tensor_spec_value.CopyFrom(spec_proto)
    return encoded

  def can_decode(self, value):
    return value.HasField("bounded_tensor_spec_value")

  def do_decode(self, value, decode_fn):
    spec_proto = value.bounded_tensor_spec_value
    shape = decode_fn(
        struct_pb2.StructuredValue(tensor_shape_value=spec_proto.shape))
    dtype = decode_fn(
        struct_pb2.StructuredValue(tensor_dtype_value=spec_proto.dtype))
    minimum = tensor_util.MakeNdarray(spec_proto.minimum)
    maximum = tensor_util.MakeNdarray(spec_proto.maximum)
    # An empty name in the proto means the spec had no name.
    return tensor_spec.BoundedTensorSpec(
        shape=shape,
        dtype=dtype,
        minimum=minimum,
        maximum=maximum,
        name=spec_proto.name or None)
459
460
461# TODO(b/238903802): Use TraceType serialization and specific protos.
class _TypeSpecCodec:
  """Codec for `tf.TypeSpec`.

  Spec classes listed in `TYPE_SPEC_CLASS_FROM_PROTO` are tagged with their
  `TypeSpecProto` enum value. Any other `TypeSpec` subclass with a registered
  name is tagged `EXTENSION_TYPE_SPEC` or `REGISTERED_TYPE_SPEC` and is looked
  up again by that name when decoding.
  """

  # Mapping from enum value to type (TypeSpec subclass).
  TYPE_SPEC_CLASS_FROM_PROTO = {
      struct_pb2.TypeSpecProto.SPARSE_TENSOR_SPEC:
          sparse_tensor.SparseTensorSpec,
      struct_pb2.TypeSpecProto.INDEXED_SLICES_SPEC:
          indexed_slices.IndexedSlicesSpec,
      struct_pb2.TypeSpecProto.RAGGED_TENSOR_SPEC:
          ragged_tensor.RaggedTensorSpec,
      struct_pb2.TypeSpecProto.TENSOR_ARRAY_SPEC:
          tensor_array_ops.TensorArraySpec,
      struct_pb2.TypeSpecProto.DATA_DATASET_SPEC:
          dataset_ops.DatasetSpec,
      struct_pb2.TypeSpecProto.DATA_ITERATOR_SPEC:
          iterator_ops.IteratorSpec,
      struct_pb2.TypeSpecProto.OPTIONAL_SPEC:
          optional_ops.OptionalSpec,
      struct_pb2.TypeSpecProto.PER_REPLICA_SPEC:
          values.PerReplicaSpec,
      struct_pb2.TypeSpecProto.VARIABLE_SPEC:
          resource_variable_ops.VariableSpec,
      struct_pb2.TypeSpecProto.ROW_PARTITION_SPEC:
          row_partition.RowPartitionSpec,
  }

  # Mapping from type (TypeSpec subclass) to enum value.
  TYPE_SPEC_CLASS_TO_PROTO = {
      cls: enum for enum, cls in TYPE_SPEC_CLASS_FROM_PROTO.items()
  }

  def can_encode(self, pyobj):
    """Returns True if `pyobj` can be encoded as a TypeSpec.

    Accepts instances whose exact type is a key of `TYPE_SPEC_CLASS_TO_PROTO`,
    plus any other `TypeSpec` whose class has a registered name (that is,
    `type_spec.get_name` does not raise `ValueError`).
    """
    if type(pyobj) in self.TYPE_SPEC_CLASS_TO_PROTO:  # pylint: disable=unidiomatic-typecheck
      return True

    # Check if it's a registered type.
    if isinstance(pyobj, type_spec.TypeSpec):
      try:
        type_spec.get_name(type(pyobj))
        return True
      except ValueError:
        return False

    return False

  def do_encode(self, type_spec_value, encode_fn):
    """Returns an encoded proto for the given `tf.TypeSpec`.

    Args:
      type_spec_value: A value accepted by `can_encode`.
      encode_fn: Callback used to encode the spec's `_serialize()` state.

    Returns:
      A `StructuredValue` whose `type_spec_value` holds the class tag, the
      class name (`__name__` for built-in specs, the registered name
      otherwise), the encoded state, and the number of flattened components
      of `_component_specs`.
    """
    type_spec_class = self.TYPE_SPEC_CLASS_TO_PROTO.get(type(type_spec_value))
    type_spec_class_name = type(type_spec_value).__name__

    if type_spec_class is None:
      # Not a built-in spec: identify it by its registered name instead.
      type_spec_class_name = type_spec.get_name(type(type_spec_value))
      if isinstance(type_spec_value, extension_type.ExtensionTypeSpec):
        type_spec_class = struct_pb2.TypeSpecProto.EXTENSION_TYPE_SPEC
      else:
        type_spec_class = struct_pb2.TypeSpecProto.REGISTERED_TYPE_SPEC
        # Support for saving registered TypeSpecs is currently experimental.
        # Issue a warning to indicate the limitations.
        warnings.warn("Encoding a StructuredValue with type %s; loading this "
                      "StructuredValue will require that this type be "
                      "imported and registered." % type_spec_class_name)

    type_state = type_spec_value._serialize()  # pylint: disable=protected-access
    num_flat_components = len(
        nest.flatten(type_spec_value._component_specs, expand_composites=True))  # pylint: disable=protected-access
    encoded_type_spec = struct_pb2.StructuredValue()
    encoded_type_spec.type_spec_value.CopyFrom(
        struct_pb2.TypeSpecProto(
            type_spec_class=type_spec_class,
            type_state=encode_fn(type_state),
            type_spec_class_name=type_spec_class_name,
            num_flat_components=num_flat_components))
    return encoded_type_spec

  def can_decode(self, value):
    return value.HasField("type_spec_value")

  def do_decode(self, value, decode_fn):
    """Returns the `tf.TypeSpec` encoded by the proto `value`.

    Unregistered extension types fall back to `AnonymousExtensionTypeSpec`
    and emit a warning.

    Raises:
      ValueError: If a `REGISTERED_TYPE_SPEC` name is not registered, or if
        the enum tag is not known to this version of TensorFlow.
    """
    type_spec_proto = value.type_spec_value
    type_spec_class_enum = type_spec_proto.type_spec_class
    class_name = type_spec_proto.type_spec_class_name

    if type_spec_class_enum == struct_pb2.TypeSpecProto.REGISTERED_TYPE_SPEC:
      try:
        type_spec_class = type_spec.lookup(class_name)
      except ValueError as e:
        raise ValueError(
            f"The type '{class_name}' has not been registered.  It must be "
            "registered before you load this object (typically by importing "
            "its module).") from e
    elif type_spec_class_enum == struct_pb2.TypeSpecProto.EXTENSION_TYPE_SPEC:
      try:
        type_spec_class = type_spec.lookup(class_name)
      except ValueError:
        type_spec_class = extension_type.AnonymousExtensionTypeSpec
        # Interpolate the name; previously the literal "%r" was emitted and
        # the warning did not say which type was missing.
        warnings.warn("The type %r has not been registered.  Falling back to "
                      "using AnonymousExtensionTypeSpec instead." % class_name)
    else:
      if type_spec_class_enum not in self.TYPE_SPEC_CLASS_FROM_PROTO:
        raise ValueError(
            f"The type '{class_name}' is not supported by this version of "
            "TensorFlow. (The object you are loading must have been created "
            "with a newer version of TensorFlow.)")
      type_spec_class = self.TYPE_SPEC_CLASS_FROM_PROTO[type_spec_class_enum]

    # pylint: disable=protected-access
    return type_spec_class._deserialize(decode_fn(type_spec_proto.type_state))
571
572
# Registry of codec instances, in priority order: _map_structure applies the
# first codec whose can_encode/can_decode accepts a value, and register_codec
# appends new codecs after these.
# NOTE(review): _BoundedTensorSpecCodec and _TensorSpecCodec stay ahead of
# _TypeSpecCodec, presumably because their spec types could otherwise be
# claimed by the generic TypeSpec codec — confirm before reordering.
_codecs = [
    codec_cls() for codec_cls in (
        _ListCodec,
        _TupleCodec,
        _NamedTupleCodec,
        _StringCodec,
        _Float64Codec,
        _NoneCodec,
        _Int64Codec,
        _TensorShapeCodec,
        _BoolCodec,
        _BoundedTensorSpecCodec,
        _TensorTypeCodec,
        _DictCodec,
        _TensorSpecCodec,
        _TypeSpecCodec,
    )
]
588]
589