# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module that encodes (decodes) nested structures into (from) protos.

The intended use is to serialize everything needed to restore a `Function` that
was saved into a SavedModel. This may include concrete function inputs and
outputs, signatures, function specs, etc.

Example use:
# Encode into proto.
signature_proto = nested_structure_coder.encode_structure(
    function.input_signature)
# Decode into a Python object.
restored_signature = nested_structure_coder.decode_proto(signature_proto)
"""

import collections
import functools
import warnings

from tensorflow.core.protobuf import struct_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import optional_ops
from tensorflow.python.distribute import values
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import extension_type
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import row_partition
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export


class NotEncodableError(Exception):
  """Error raised when a coder cannot encode an object."""


def register_codec(x):
  """Registers a codec to use for encoding/decoding.

  The codec is appended to the module-level `_codecs` list, so it is consulted
  after every codec that was registered before it.

  Args:
    x: The codec object to register. The object must implement can_encode,
      do_encode, can_decode, and do_decode. See the various _*Codec classes for
      examples.
  """
  _codecs.append(x)


def _get_encoders():
  """Returns `(can_encode, do_encode)` pairs for all registered codecs."""
  return [(c.can_encode, c.do_encode) for c in _codecs]


def _get_decoders():
  """Returns `(can_decode, do_decode)` pairs for all registered codecs."""
  return [(c.can_decode, c.do_decode) for c in _codecs]


def _map_structure(pyobj, coders):
  """Converts `pyobj` with the first coder in `coders` that accepts it.

  Args:
    pyobj: The value to convert (a Python object when encoding, a proto when
      decoding).
    coders: A list of `(can, do)` callable pairs. `can(pyobj)` reports whether
      the pair applies; `do(pyobj, recursion_fn)` performs the conversion.

  Returns:
    The result of the first applicable `do` callable.

  Raises:
    NotEncodableError: If no pair in `coders` accepts `pyobj`.
  """
  for can, do in coders:
    if can(pyobj):
      # Nested values are converted with the same set of coders.
      recursion_fn = functools.partial(_map_structure, coders=coders)
      return do(pyobj, recursion_fn)
  # NOTE: this helper is shared by the encode and decode paths, so the message
  # below is also what is raised when no decoder matches. The text is kept
  # verbatim because callers may match on it.
  raise NotEncodableError(
      f"No encoder for object {str(pyobj)} of type {type(pyobj)}.")


@tf_export("__internal__.saved_model.encode_structure", v1=[])
def encode_structure(nested_structure):
  """Encodes nested structures composed of encodable types into a proto.

  Args:
    nested_structure: Structure to encode.

  Returns:
    Encoded proto.

  Raises:
    NotEncodableError: For values for which there are no encoders.
  """
  return _map_structure(nested_structure, _get_encoders())


def can_encode(nested_structure):
  """Determines whether a nested structure can be encoded into a proto.

  The check is performed by actually encoding the structure and discarding the
  result.

  Args:
    nested_structure: Structure to encode.

  Returns:
    True if the nested structured can be encoded.
  """
  try:
    encode_structure(nested_structure)
  except NotEncodableError:
    return False
  return True


@tf_export("__internal__.saved_model.decode_proto", v1=[])
def decode_proto(proto):
  """Decodes proto representing a nested structure.

  Args:
    proto: Proto to decode.

  Returns:
    Decoded structure.

  Raises:
    NotEncodableError: For values for which there are no decoders.
  """
  return _map_structure(proto, _get_decoders())


# Codec protocol used by every `_*Codec` class below:
#   can_encode(pyobj)            -> bool, whether this codec handles `pyobj`.
#   do_encode(pyobj, encode_fn)  -> `struct_pb2.StructuredValue`; `encode_fn`
#                                   recursively encodes nested values.
#   can_decode(value)            -> bool, whether the proto holds this codec's
#                                   field.
#   do_decode(value, decode_fn)  -> Python object; `decode_fn` recursively
#                                   decodes nested protos.


class _ListCodec:
  """Codec for lists."""

  def can_encode(self, pyobj):
    return isinstance(pyobj, list)

  def do_encode(self, list_value, encode_fn):
    encoded_list = struct_pb2.StructuredValue()
    # CopyFrom an empty message so the field is set even for an empty list.
    encoded_list.list_value.CopyFrom(struct_pb2.ListValue())
    for element in list_value:
      encoded_list.list_value.values.add().CopyFrom(encode_fn(element))
    return encoded_list

  def can_decode(self, value):
    return value.HasField("list_value")

  def do_decode(self, value, decode_fn):
    return [decode_fn(element) for element in value.list_value.values]


def _is_tuple(obj):
  """Returns True iff `obj` is a tuple that is not a `namedtuple`."""
  return not _is_named_tuple(obj) and isinstance(obj, tuple)


def _is_named_tuple(instance):
  """Returns True iff `instance` is a `namedtuple`.

  Args:
    instance: An instance of a Python object.

  Returns:
    True if `instance` is a `namedtuple`.
  """
  if not isinstance(instance, tuple):
    return False
  return (hasattr(instance, "_fields") and
          isinstance(instance._fields, collections_abc.Sequence) and
          all(isinstance(f, str) for f in instance._fields))


class _TupleCodec:
  """Codec for tuples."""

  def can_encode(self, pyobj):
    return _is_tuple(pyobj)

  def do_encode(self, tuple_value, encode_fn):
    encoded_tuple = struct_pb2.StructuredValue()
    # CopyFrom an empty message so the field is set even for an empty tuple.
    encoded_tuple.tuple_value.CopyFrom(struct_pb2.TupleValue())
    for element in tuple_value:
      encoded_tuple.tuple_value.values.add().CopyFrom(encode_fn(element))
    return encoded_tuple

  def can_decode(self, value):
    return value.HasField("tuple_value")

  def do_decode(self, value, decode_fn):
    return tuple(decode_fn(element) for element in value.tuple_value.values)


class _DictCodec:
  """Codec for dicts.

  Dict keys are used directly as keys of the proto's `fields` map; only the
  values are encoded recursively.
  """

  def can_encode(self, pyobj):
    return isinstance(pyobj, dict)

  def do_encode(self, dict_value, encode_fn):
    encoded_dict = struct_pb2.StructuredValue()
    # CopyFrom an empty message so the field is set even for an empty dict.
    encoded_dict.dict_value.CopyFrom(struct_pb2.DictValue())
    for key, value in dict_value.items():
      encoded_dict.dict_value.fields[key].CopyFrom(encode_fn(value))
    return encoded_dict

  def can_decode(self, value):
    return value.HasField("dict_value")

  def do_decode(self, value, decode_fn):
    return {key: decode_fn(val) for key, val in value.dict_value.fields.items()}


class _NamedTupleCodec:
  """Codec for namedtuples.

  Encoding and decoding a namedtuple reconstructs a namedtuple with a different
  actual Python type, but with the same `typename` and `fields`.
  """

  def can_encode(self, pyobj):
    return _is_named_tuple(pyobj)

  def do_encode(self, named_tuple_value, encode_fn):
    encoded_named_tuple = struct_pb2.StructuredValue()
    encoded_named_tuple.named_tuple_value.CopyFrom(struct_pb2.NamedTupleValue())
    encoded_named_tuple.named_tuple_value.name = (
        named_tuple_value.__class__.__name__)
    # Build the field dict once instead of once per field.
    field_values = named_tuple_value._asdict()
    for key in named_tuple_value._fields:
      pair = encoded_named_tuple.named_tuple_value.values.add()
      pair.key = key
      pair.value.CopyFrom(encode_fn(field_values[key]))
    return encoded_named_tuple

  def can_decode(self, value):
    return value.HasField("named_tuple_value")

  def do_decode(self, value, decode_fn):
    key_value_pairs = value.named_tuple_value.values
    items = [(pair.key, decode_fn(pair.value)) for pair in key_value_pairs]
    # A fresh namedtuple class is created on every decode; it shares only the
    # type name and field names with the class that was encoded.
    named_tuple_type = collections.namedtuple(value.named_tuple_value.name,
                                              [item[0] for item in items])
    return named_tuple_type(**dict(items))


class _Float64Codec:
  """Codec for floats."""

  def can_encode(self, pyobj):
    return isinstance(pyobj, float)

  def do_encode(self, float64_value, encode_fn):
    del encode_fn  # Leaf value: nothing to recurse into.
    value = struct_pb2.StructuredValue()
    value.float64_value = float64_value
    return value

  def can_decode(self, value):
    return value.HasField("float64_value")

  def do_decode(self, value, decode_fn):
    del decode_fn  # Leaf value: nothing to recurse into.
    return value.float64_value


class _Int64Codec:
  """Codec for Python integers (limited to 64 bit values)."""

  def can_encode(self, pyobj):
    # `bool` is a subclass of `int`; exclude it here so that booleans are
    # handled by `_BoolCodec` instead.
    return not isinstance(pyobj, bool) and isinstance(pyobj, int)

  def do_encode(self, int_value, encode_fn):
    del encode_fn  # Leaf value: nothing to recurse into.
    value = struct_pb2.StructuredValue()
    value.int64_value = int_value
    return value

  def can_decode(self, value):
    return value.HasField("int64_value")

  def do_decode(self, value, decode_fn):
    del decode_fn  # Leaf value: nothing to recurse into.
    return int(value.int64_value)


class _StringCodec:
  """Codec for strings.

  See StructuredValue.string_value in proto/struct.proto for more detailed
  explanation.
  """

  def can_encode(self, pyobj):
    return isinstance(pyobj, str)

  def do_encode(self, string_value, encode_fn):
    del encode_fn  # Leaf value: nothing to recurse into.
    value = struct_pb2.StructuredValue()
    value.string_value = string_value
    return value

  def can_decode(self, value):
    return value.HasField("string_value")

  def do_decode(self, value, decode_fn):
    del decode_fn  # Leaf value: nothing to recurse into.
    return compat.as_str(value.string_value)


class _NoneCodec:
  """Codec for None."""

  def can_encode(self, pyobj):
    return pyobj is None

  def do_encode(self, none_value, encode_fn):
    del encode_fn, none_value
    value = struct_pb2.StructuredValue()
    # CopyFrom an empty message so that `HasField("none_value")` is true.
    value.none_value.CopyFrom(struct_pb2.NoneValue())
    return value

  def can_decode(self, value):
    return value.HasField("none_value")

  def do_decode(self, value, decode_fn):
    del decode_fn, value
    return None


class _BoolCodec:
  """Codec for booleans."""

  def can_encode(self, pyobj):
    return isinstance(pyobj, bool)

  def do_encode(self, bool_value, encode_fn):
    del encode_fn  # Leaf value: nothing to recurse into.
    value = struct_pb2.StructuredValue()
    value.bool_value = bool_value
    return value

  def can_decode(self, value):
    return value.HasField("bool_value")

  def do_decode(self, value, decode_fn):
    del decode_fn  # Leaf value: nothing to recurse into.
    return value.bool_value


class _TensorShapeCodec:
  """Codec for `TensorShape`."""

  def can_encode(self, pyobj):
    return isinstance(pyobj, tensor_shape.TensorShape)

  def do_encode(self, tensor_shape_value, encode_fn):
    del encode_fn  # Leaf value: nothing to recurse into.
    encoded_tensor_shape = struct_pb2.StructuredValue()
    encoded_tensor_shape.tensor_shape_value.CopyFrom(
        tensor_shape_value.as_proto())
    return encoded_tensor_shape

  def can_decode(self, value):
    return value.HasField("tensor_shape_value")

  def do_decode(self, value, decode_fn):
    del decode_fn  # Leaf value: nothing to recurse into.
    return tensor_shape.TensorShape(value.tensor_shape_value)


class _TensorTypeCodec:
  """Codec for `TensorType`."""

  def can_encode(self, pyobj):
    return isinstance(pyobj, dtypes.DType)

  def do_encode(self, tensor_dtype_value, encode_fn):
    del encode_fn  # Leaf value: nothing to recurse into.
    encoded_tensor_type = struct_pb2.StructuredValue()
    encoded_tensor_type.tensor_dtype_value = tensor_dtype_value.as_datatype_enum
    return encoded_tensor_type

  def can_decode(self, value):
    return value.HasField("tensor_dtype_value")

  def do_decode(self, value, decode_fn):
    del decode_fn  # Leaf value: nothing to recurse into.
    return dtypes.DType(value.tensor_dtype_value)


class _TensorSpecCodec:
  """Codec for `TensorSpec`."""

  def can_encode(self, pyobj):
    # BoundedTensorSpec has its own codec (`_BoundedTensorSpecCodec`).
    return (isinstance(pyobj, tensor_spec.TensorSpec) and
            not isinstance(pyobj, tensor_spec.BoundedTensorSpec))

  def do_encode(self, tensor_spec_value, encode_fn):
    encoded_tensor_spec = struct_pb2.StructuredValue()
    # Shape and dtype are encoded through `encode_fn` and then unwrapped from
    # the resulting StructuredValue.
    encoded_tensor_spec.tensor_spec_value.CopyFrom(
        struct_pb2.TensorSpecProto(
            shape=encode_fn(tensor_spec_value.shape).tensor_shape_value,
            dtype=encode_fn(tensor_spec_value.dtype).tensor_dtype_value,
            name=tensor_spec_value.name))
    return encoded_tensor_spec

  def can_decode(self, value):
    return value.HasField("tensor_spec_value")

  def do_decode(self, value, decode_fn):
    name = value.tensor_spec_value.name
    return tensor_spec.TensorSpec(
        shape=decode_fn(
            struct_pb2.StructuredValue(
                tensor_shape_value=value.tensor_spec_value.shape)),
        dtype=decode_fn(
            struct_pb2.StructuredValue(
                tensor_dtype_value=value.tensor_spec_value.dtype)),
        # An empty proto name is decoded as "no name".
        name=(name if name else None))


class _BoundedTensorSpecCodec:
  """Codec for `BoundedTensorSpec`."""

  def can_encode(self, pyobj):
    return isinstance(pyobj, tensor_spec.BoundedTensorSpec)

  def do_encode(self, bounded_tensor_spec_value, encode_fn):
    """Returns an encoded proto for the given `tf.BoundedTensorSpec`."""
    encoded_bounded_tensor_spec = struct_pb2.StructuredValue()
    encoded_bounded_tensor_spec.bounded_tensor_spec_value.CopyFrom(
        struct_pb2.BoundedTensorSpecProto(
            shape=encode_fn(bounded_tensor_spec_value.shape).tensor_shape_value,
            dtype=encode_fn(bounded_tensor_spec_value.dtype).tensor_dtype_value,
            name=bounded_tensor_spec_value.name,
            minimum=tensor_util.make_tensor_proto(
                bounded_tensor_spec_value.minimum),
            maximum=tensor_util.make_tensor_proto(
                bounded_tensor_spec_value.maximum)))
    return encoded_bounded_tensor_spec

  def can_decode(self, value):
    return value.HasField("bounded_tensor_spec_value")

  def do_decode(self, value, decode_fn):
    btsv = value.bounded_tensor_spec_value
    name = btsv.name
    return tensor_spec.BoundedTensorSpec(
        shape=decode_fn(
            struct_pb2.StructuredValue(tensor_shape_value=btsv.shape)),
        dtype=decode_fn(
            struct_pb2.StructuredValue(tensor_dtype_value=btsv.dtype)),
        minimum=tensor_util.MakeNdarray(btsv.minimum),
        maximum=tensor_util.MakeNdarray(btsv.maximum),
        # An empty proto name is decoded as "no name".
        name=(name if name else None))


# TODO(b/238903802): Use TraceType serialization and specific protos.
class _TypeSpecCodec:
  """Codec for `tf.TypeSpec`."""

  # Mapping from enum value to type (TypeSpec subclass).
  TYPE_SPEC_CLASS_FROM_PROTO = {
      struct_pb2.TypeSpecProto.SPARSE_TENSOR_SPEC:
          sparse_tensor.SparseTensorSpec,
      struct_pb2.TypeSpecProto.INDEXED_SLICES_SPEC:
          indexed_slices.IndexedSlicesSpec,
      struct_pb2.TypeSpecProto.RAGGED_TENSOR_SPEC:
          ragged_tensor.RaggedTensorSpec,
      struct_pb2.TypeSpecProto.TENSOR_ARRAY_SPEC:
          tensor_array_ops.TensorArraySpec,
      struct_pb2.TypeSpecProto.DATA_DATASET_SPEC:
          dataset_ops.DatasetSpec,
      struct_pb2.TypeSpecProto.DATA_ITERATOR_SPEC:
          iterator_ops.IteratorSpec,
      struct_pb2.TypeSpecProto.OPTIONAL_SPEC:
          optional_ops.OptionalSpec,
      struct_pb2.TypeSpecProto.PER_REPLICA_SPEC:
          values.PerReplicaSpec,
      struct_pb2.TypeSpecProto.VARIABLE_SPEC:
          resource_variable_ops.VariableSpec,
      struct_pb2.TypeSpecProto.ROW_PARTITION_SPEC:
          row_partition.RowPartitionSpec,
  }

  # Mapping from type (TypeSpec subclass) to enum value.
  TYPE_SPEC_CLASS_TO_PROTO = {
      cls: enum for enum, cls in TYPE_SPEC_CLASS_FROM_PROTO.items()
  }

  def can_encode(self, pyobj):
    """Returns true if `pyobj` can be encoded as a TypeSpec."""
    # Exact type match: subclasses of the built-in specs are not covered by
    # this table and fall through to the registry check below.
    if type(pyobj) in self.TYPE_SPEC_CLASS_TO_PROTO:  # pylint: disable=unidiomatic-typecheck
      return True

    # Check if it's a registered type.
    if isinstance(pyobj, type_spec.TypeSpec):
      try:
        type_spec.get_name(type(pyobj))
        return True
      except ValueError:
        return False

    return False

  def do_encode(self, type_spec_value, encode_fn):
    """Returns an encoded proto for the given `tf.TypeSpec`."""
    type_spec_class = self.TYPE_SPEC_CLASS_TO_PROTO.get(type(type_spec_value))
    type_spec_class_name = type(type_spec_value).__name__

    if type_spec_class is None:
      # Not one of the built-in specs: record the registered name instead of
      # the bare class name.
      type_spec_class_name = type_spec.get_name(type(type_spec_value))
      if isinstance(type_spec_value, extension_type.ExtensionTypeSpec):
        type_spec_class = struct_pb2.TypeSpecProto.EXTENSION_TYPE_SPEC
      else:
        type_spec_class = struct_pb2.TypeSpecProto.REGISTERED_TYPE_SPEC
        # Support for saving registered TypeSpecs is currently experimental.
        # Issue a warning to indicate the limitations.
        warnings.warn("Encoding a StructuredValue with type %s; loading this "
                      "StructuredValue will require that this type be "
                      "imported and registered." % type_spec_class_name)

    type_state = type_spec_value._serialize()  # pylint: disable=protected-access
    num_flat_components = len(
        nest.flatten(type_spec_value._component_specs, expand_composites=True))  # pylint: disable=protected-access
    encoded_type_spec = struct_pb2.StructuredValue()
    encoded_type_spec.type_spec_value.CopyFrom(
        struct_pb2.TypeSpecProto(
            type_spec_class=type_spec_class,
            type_state=encode_fn(type_state),
            type_spec_class_name=type_spec_class_name,
            num_flat_components=num_flat_components))
    return encoded_type_spec

  def can_decode(self, value):
    return value.HasField("type_spec_value")

  def do_decode(self, value, decode_fn):
    """Returns the `tf.TypeSpec` encoded by the proto `value`.

    Args:
      value: A `StructuredValue` proto whose `type_spec_value` field is set.
      decode_fn: Callable used to decode the nested `type_state` proto.

    Returns:
      The result of `_deserialize` on the resolved TypeSpec class.

    Raises:
      ValueError: If the proto names a registered type that has not been
        registered, or uses an enum value missing from
        `TYPE_SPEC_CLASS_FROM_PROTO`.
    """
    type_spec_proto = value.type_spec_value
    type_spec_class_enum = type_spec_proto.type_spec_class
    class_name = type_spec_proto.type_spec_class_name

    if type_spec_class_enum == struct_pb2.TypeSpecProto.REGISTERED_TYPE_SPEC:
      try:
        type_spec_class = type_spec.lookup(class_name)
      except ValueError as e:
        raise ValueError(
            f"The type '{class_name}' has not been registered.  It must be "
            "registered before you load this object (typically by importing "
            "its module).") from e
    elif type_spec_class_enum == struct_pb2.TypeSpecProto.EXTENSION_TYPE_SPEC:
      try:
        type_spec_class = type_spec.lookup(class_name)
      except ValueError:
        # Unregistered extension types degrade to an anonymous spec rather
        # than failing the load.
        type_spec_class = extension_type.AnonymousExtensionTypeSpec
        # The `%r` placeholder must be filled in; without the format argument
        # the warning would show a literal "%r" instead of the type name.
        warnings.warn("The type %r has not been registered.  Falling back to "
                      "using AnonymousExtensionTypeSpec instead." % class_name)
    else:
      if type_spec_class_enum not in self.TYPE_SPEC_CLASS_FROM_PROTO:
        raise ValueError(
            f"The type '{class_name}' is not supported by this version of "
            "TensorFlow. (The object you are loading must have been created "
            "with a newer version of TensorFlow.)")
      type_spec_class = self.TYPE_SPEC_CLASS_FROM_PROTO[type_spec_class_enum]

    # pylint: disable=protected-access
    return type_spec_class._deserialize(decode_fn(type_spec_proto.type_state))


# Codecs are tried in list order and the first match wins (see
# `_map_structure`); `register_codec` appends to this list.
_codecs = [
    _ListCodec(),
    _TupleCodec(),
    _NamedTupleCodec(),
    _StringCodec(),
    _Float64Codec(),
    _NoneCodec(),
    _Int64Codec(),
    _TensorShapeCodec(),
    _BoolCodec(),
    _BoundedTensorSpecCodec(),
    _TensorTypeCodec(),
    _DictCodec(),
    _TensorSpecCodec(),
    _TypeSpecCodec(),
]