# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operators corresponding to Python builtin functions.

List of built-in functions: https://docs.python.org/3/library/functions.html
"""

import functools
import inspect

import numpy as np

from tensorflow.python.autograph.utils import py_func
from tensorflow.python.autograph.utils import tensors
from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_parsing_ops
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sort_ops
from tensorflow.python.util import lazy_loader
from tensorflow.python.util import nest

# TODO(b/145618471): Remove this dependency.
# Lazy import to work around circular dependencies
input_lib = lazy_loader.LazyLoader(
    'input_lib', globals(),
    'tensorflow.python.distribute.input_lib')
parallel_ops = lazy_loader.LazyLoader(
    'parallel_ops', globals(),
    'tensorflow.python.ops.parallel_for.control_flow_ops')

# Sentinel for "argument not supplied". It is always compared by identity
# (`is UNSPECIFIED`) by the overloads in this module.
UNSPECIFIED = object()


def overload_of(f):
  """Returns the overload registered for `f`, or `f` itself if there is none.

  Args:
    f: Any callable. If it is one of SUPPORTED_BUILTINS, its replacement is
      looked up in BUILTIN_FUNCTIONS_MAP by `f.__name__`.

  Returns:
    The overload from BUILTIN_FUNCTIONS_MAP, or `f` unchanged.
  """
  if f in SUPPORTED_BUILTINS:
    return BUILTIN_FUNCTIONS_MAP[f.__name__]
  return f


def _find_originating_frame(caller_fn_scope, innermost=True):
  """Locates the frame in which `caller_fn_scope` was defined.

  Walks the call stack from the current frame outwards, looking for a frame
  whose locals bind `caller_fn_scope.name` to `caller_fn_scope` itself.

  Args:
    caller_fn_scope: The function scope object to search for. Only its `name`
      attribute and its identity are used here.
    innermost: If True, stop at the first (most recent) matching frame.
      Otherwise keep walking and return the last match, i.e. the outermost
      matching frame.

  Returns:
    The matching frame object.
  """
  ctx_frame = inspect.currentframe()
  result = None
  while ctx_frame is not None:
    # Note it should not be normally possible to get false positives this way
    # because the function scope object is not accessible to user code (barring
    # call stack introspection).
    if ctx_frame.f_locals.get(caller_fn_scope.name, None) is caller_fn_scope:
      result = ctx_frame
      if innermost:
        break
    ctx_frame = ctx_frame.f_back

  assert result is not None, (
      'the conversion process should ensure the caller_fn_scope is always'
      ' found somewhere on the call stack')

  return result


def locals_in_original_context(caller_fn_scope):
  """Executes the locals function in the context of a specified function.

  Returns the `f_locals` of the innermost frame that owns `caller_fn_scope`.
  """
  return _find_originating_frame(caller_fn_scope, innermost=True).f_locals


def globals_in_original_context(caller_fn_scope):
  """Executes the globals function in the context of a specified function.

  Returns the `f_globals` of the innermost frame that owns `caller_fn_scope`.
  """
  return _find_originating_frame(caller_fn_scope, innermost=True).f_globals


def eval_in_original_context(f, args, caller_fn_scope):
  """Executes the eval function in the context of a specified function.

  Args:
    f: Callable, typically the eval builtin.
    args: The original call arguments: the expression, optionally followed by
      explicit globals and locals. Missing globals/locals are filled in from
      the originating frame.
    caller_fn_scope: The function scope of the converted function in which
      this call was originally made.

  Returns:
    The result of calling `f` with the completed argument tuple.
  """
  # When control flow is rewritten using functions, eval should use the
  # variables found in the same block where it was called. That is equivalent
  # to the innermost function call.
  ctx_frame = _find_originating_frame(caller_fn_scope, innermost=True)

  args = (
      args[0],
      ctx_frame.f_globals if len(args) < 2 else args[1],
      ctx_frame.f_locals if len(args) < 3 else args[2],
  )
  return f(*args)


def super_in_original_context(f, args, caller_fn_scope):
  """Executes the super function in the context of a specified function.

  See https://docs.python.org/3/library/functions.html#super for the exact
  details

  Args:
    f: Callable, typically the super builtin
    args: List[Any], the original call arguments
    caller_fn_scope: Optional[function_wrappers.FunctionScope], the function
      scope of the converted function in which this call was originally made

  Returns:
    The result of calling `f` as if it was called in the frame indicated by
      `caller_fn_scope`.
  """

  # Only the no-arg call is desugared.
  if args:
    return f(*args)

  # Inner functions seem to include their closure in f_locals, so we need
  # to find the outermost frame.
  ctx_frame = _find_originating_frame(caller_fn_scope, innermost=False)

  # When super(..) is called without arguments, it looks for __class__ cell
  # variable and the first argument passed in the enclosing function according
  # to the spec https://www.python.org/dev/peps/pep-3135/ .
  #
  # We couldn't verify if `inspect.currentframe().f_code.co_varnames[0]` is
  # guaranteed to be the first argument from an official doc or PEP, however,
  # it's fairly stable and well established:
  # - An unofficial community doc mentions it.
  #   https://python-reference.readthedocs.io/en/latest/docs/code/varnames.html
  # - CPython has tests checking that order, which was merged in 2008, and
  #   unchanged since then.
  #   https://github.com/python/cpython/blame/2f224a077a83ac9de8a12bb7dcc516642b8176d8/Lib/lib2to3/tests/data/py2_test_grammar.py#L157
  #   https://github.com/python/cpython/blame/2f224a077a83ac9de8a12bb7dcc516642b8176d8/Lib/lib2to3/tests/data/py3_test_grammar.py#L192
  #
  # Note: the name can be more reliably obtained by inspecting the calling
  # function's argspec.
  #
  # Even though methods can be declared using *args (def method(*args)),
  # that pattern is disallowed by super() -- it raises super() no arguments.
  # Method definitions using **kwargs are not allowed at all.
  # In other words, we can always assume that self is on the first positional
  # argument (for correct code).
  #
  # TODO(mdan): Consider additional checks in case the input code is incorrect.
  # For example, the error might be cryptic compared to what super() regularly
  # raises.

  type_arg = ctx_frame.f_locals['__class__']
  self_arg_name = ctx_frame.f_code.co_varnames[0]
  self_arg = ctx_frame.f_locals[self_arg_name]
  return f(type_arg, self_arg)


def abs_(x):
  """Overload of abs: dispatches on tensors, datasets and Python values."""
  if tensor_util.is_tf_type(x):
    return _tf_abs(x)
  if isinstance(x, dataset_ops.DatasetV2):
    return _tf_dataset_abs(x)
  return _py_abs(x)


def _tf_abs(x):
  """Overload of abs_ for tensors."""
  return math_ops.abs(x)


def _tf_dataset_abs(x):
  """Overload of abs_ for datasets; maps `math_ops.abs` over every element."""
  specs = nest.flatten(x.element_spec)
  if len(specs) == 1:
    return x.map(math_ops.abs, num_parallel_calls=dataset_ops.AUTOTUNE)
  # Multi-component elements arrive as separate positional arguments; apply
  # abs to each component while preserving their structure.
  return x.map(
      lambda *e: nest.map_structure(math_ops.abs, e),
      num_parallel_calls=dataset_ops.AUTOTUNE)


def _py_abs(x):
  """Overload of abs_ for plain Python values."""
  return abs(x)


def float_(x=0):
  """Overload of float: dispatches on tensors and Python values."""
  if tensor_util.is_tf_type(x):
    return _tf_float(x)
  return _py_float(x)


def _tf_float(x):
  """Overload of float_ for tensors; parses strings, casts everything else."""
  # TODO(mdan): We shouldn't assume float32.
  if x.dtype == dtypes.string:
    return gen_parsing_ops.string_to_number(x, out_type=dtypes.float32)
  return math_ops.cast(x, dtype=dtypes.float32)


def _py_float(x):
  """Overload of float_ for plain Python values."""
  return float(x)


def int_(x=0, base=UNSPECIFIED):
  """Overload of int: dispatches on tensors and Python values."""
  if tensor_util.is_tf_type(x):
    return _tf_int(x, base)
  return _py_int(x, base)


def _tf_int(x, base):
  """Overload of int_ for tensors.

  Args:
    x: Tensor to convert. String tensors are parsed, others are cast.
    base: Must be 10 or UNSPECIFIED.

  Returns:
    An int32 tensor.

  Raises:
    NotImplementedError: if `base` is anything other than 10 or UNSPECIFIED.
  """
  if base not in (10, UNSPECIFIED):
    raise NotImplementedError('base {} not supported for int'.format(base))

  # TODO(mdan): We shouldn't assume int32.
  if x.dtype == dtypes.string:
    return gen_parsing_ops.string_to_number(x, out_type=dtypes.int32)
  return math_ops.cast(x, dtype=dtypes.int32)


def _py_int(x, base):
  """Overload of int_ for plain Python values."""
  if base is UNSPECIFIED:
    return int(x)
  return int(x, base)


def len_(s):
  """Overload of len: handles TensorArrays, tensor lists, tensors, datasets."""
  if tensors.is_tensor_array(s):
    return _tf_tensor_array_len(s)
  elif tensors.is_tensor_list(s):
    return _tf_tensor_list_len(s)
  elif tensor_util.is_tf_type(s):
    return _tf_tensor_len(s)
  if isinstance(s, dataset_ops.DatasetV2):
    return _tf_dataset_len(s)
  return _py_len(s)


def _tf_tensor_array_len(s):
  """Overload of len_ for TensorArray arguments."""
  return s.size()


def _tf_tensor_list_len(s):
  """Overload of len_ for tensor list arguments."""
  return list_ops.tensor_list_length(s)


def _tf_tensor_len(s):
  """Overload of len_ for Tensor arguments.

  Returns the size of the first dimension: a Python value when it is known
  statically, otherwise a tensor computed at run time.

  Raises:
    ValueError: if `s` is statically known to be a scalar.
  """
  # Statically shaped tensors: length is known ahead of time.
  if s.shape.ndims and s.shape.dims[0].value is not None:
    return s.shape.dims[0].value

  # Static shape of unknown dimensions: use dynamic shape but statically
  # check that it's a scalar.
  shape = array_ops.shape(s)

  assert shape.shape, 'shape tensor of zero size? {}'.format(shape)

  if shape.shape[0] == 0:
    raise ValueError(
        'len requires a non-scalar tensor, got one of shape {}'.format(shape))

  if shape.shape.dims[0].value is not None:
    return array_ops.shape(s)[0]

  # Fully dynamic shape: use ops.
  rank = array_ops.rank(s)

  def raise_zero_rank_error():
    msg = gen_string_ops.string_join(
        ['len requires non-zero rank, got ',
         gen_string_ops.as_string(rank)])
    # The constant is only a placeholder so both cond branches return a value;
    # the Assert it depends on always fails when this branch runs.
    with ops.control_dependencies([control_flow_ops.Assert(False, [msg])]):
      return constant_op.constant(0, dtype=dtypes.int32)

  return control_flow_ops.cond(rank > 0, lambda: array_ops.shape(s)[0],
                               raise_zero_rank_error)


def _tf_dataset_len(s):
  """Overload of len_ for datasets.

  Returns the dataset cardinality, guarded by an assertion that it is neither
  INFINITE nor UNKNOWN.
  """
  l = cardinality.cardinality(s)
  msg = gen_string_ops.string_join([
      'len requires dataset with definitive cardinality, got ',
      gen_string_ops.as_string(l)
  ])
  # TODO (yongtang): UNKNOWN is treated as an error.
  # In case there are more UNKNOWN cases for dataset, we could
  # use dataset.reduce() to find out the length (in an expensive way).
  with ops.control_dependencies([
      control_flow_ops.Assert(
          math_ops.logical_and(
              math_ops.not_equal(l, cardinality.INFINITE),
              math_ops.not_equal(l, cardinality.UNKNOWN)), [msg])
  ]):
    l = array_ops.identity(l)

  return l


def _py_len(s):
  """Overload of len_ for plain Python values."""
  return len(s)


def print_(*objects, **kwargs):
  """Overload of the print builtin.

  Args:
    *objects: Values to print. If any of them is a tensor, printing is staged
      through a py_func; otherwise the Python print is called directly.
    **kwargs: Only 'sep', 'end', 'file' and 'flush' are accepted.

  Returns:
    The result of the py_func wrapper when a tensor is present, else None.

  Raises:
    ValueError: if an unsupported keyword argument is passed.
  """
  # Note: Python 2.6 doesn't support explicit keywords after starargs.
  unknown_kwargs = tuple(
      set(kwargs.keys()) - set(('sep', 'end', 'file', 'flush')))
  if unknown_kwargs:
    raise ValueError('invalid keyword arguments: {}'.format(unknown_kwargs))

  # TODO(mdan): Use nest.flatten(objects) instead?
  if any(tensor_util.is_tf_type(o) for o in objects):
    # TODO(mdan): use tf.print instead.
    return _tf_py_func_print(objects, kwargs)
  else:
    _py_print(*objects, **kwargs)


def _py_print(*objects, **kwargs):
  """Overload of print_ for plain Python values."""
  print(*objects, **kwargs)


def min_(*args, **kwargs):
  """Overload of min: uses the TF path if any positional arg is a tensor."""
  if any(tensor_util.is_tf_type(s) for s in args):
    return _tf_min(*args, **kwargs)
  return _py_min(*args, **kwargs)


def _tf_min(*args, **kwargs):
  """Overload of min_ for tensors.

  Supports either a single tensor of rank 0 or 1, or several scalar tensors.

  Raises:
    ValueError: if keyword arguments are passed, if the single argument has
      rank other than 0 or 1, or if any of several arguments is not a scalar.
  """
  if kwargs:
    kwargs_tuple = tuple(set(kwargs.keys()))
    raise ValueError('These keyword arguments are '
                     'currently not supported: {}'.format(kwargs_tuple))
  if len(args) == 1:
    rank = args[0].shape.rank
    if rank == 0:
      return args[0]
    if rank == 1:
      return math_ops.reduce_min(args[0], axis=0)
    raise ValueError('min(arg) currently support only tensor with rank 1, '
                     'but got a tensor with rank {}'.format(rank))
  for arg in args:
    rank = arg.shape.rank
    if rank != 0:
      # Report the shape, as the message promises (it used to print the rank).
      raise ValueError('min(arg1, arg2, *args) currently support '
                       'only scalar tensor, but got a tensor '
                       'with shape {}'.format(arg.shape))
  return math_ops.reduce_min(args, axis=0)


def _py_min(*args, **kwargs):
  """Overload of min_ for plain Python values."""
  return min(*args, **kwargs)


def max_(*args, **kwargs):
  """Overload of max: uses the TF path if any positional arg is a tensor."""
  if any(tensor_util.is_tf_type(s) for s in args):
    return _tf_max(*args, **kwargs)
  return _py_max(*args, **kwargs)


def _tf_max(*args, **kwargs):
  """Overload of max_ for tensors.

  Supports either a single tensor of rank 0 or 1, or several scalar tensors.

  Raises:
    ValueError: if keyword arguments are passed, if the single argument has
      rank other than 0 or 1, or if any of several arguments is not a scalar.
  """
  if kwargs:
    kwargs_tuple = tuple(set(kwargs.keys()))
    raise ValueError('These keyword arguments are '
                     'currently not supported: {}'.format(kwargs_tuple))
  if len(args) == 1:
    rank = args[0].shape.rank
    if rank == 0:
      return args[0]
    if rank == 1:
      return math_ops.reduce_max(args[0], axis=0)
    raise ValueError('max(arg) currently support only tensor with rank 1, '
                     'but got a tensor with rank {}'.format(rank))
  for arg in args:
    rank = arg.shape.rank
    if rank != 0:
      # Report the shape, as the message promises (it used to print the rank).
      raise ValueError('max(arg1, arg2, *args) currently support '
                       'only scalar tensor, but got a tensor '
                       'with shape {}'.format(arg.shape))
  return math_ops.reduce_max(args, axis=0)


def _py_max(*args, **kwargs):
  """Overload of max_ for plain Python values."""
  return max(*args, **kwargs)


def _tf_py_func_print(objects, kwargs):
  """Overload of print_ as a py_func implementation."""
  override_kwargs = {k: v for k, v in kwargs.items() if v is not UNSPECIFIED}
  if 'flush' not in override_kwargs:
    # Defaulting to flushing the console in graph mode, which helps reduce
    # garbled output in IPython.
    override_kwargs['flush'] = True

  def print_wrapper(*vals):
    vals = tuple(v.numpy() if tensor_util.is_tf_type(v) else v for v in vals)
    # TensorFlow doesn't seem to generate Unicode when passing strings to
    # py_func. This causes the print to add a "b'" wrapper to the output,
    # which is probably never what you want.
    vals = tuple(v.decode('utf-8') if isinstance(v, bytes) else v for v in vals)
    print(*vals, **override_kwargs)

  return py_func.wrap_py_func(
      print_wrapper, None, objects, use_dummy_return=True)


def range_(start_or_stop, stop=UNSPECIFIED, step=UNSPECIFIED):
  """Overload of range: uses the TF path if any argument is a tensor."""
  if any(tensor_util.is_tf_type(s) for s in (start_or_stop, stop, step)):
    return _tf_range(start_or_stop, stop, step)
  return _py_range(start_or_stop, stop, step)


def _tf_range(start_or_stop, stop, step):
  """Overload of range_ that generates a TF range tensor."""
  # Note: for static inputs (e.g. constants), tf.range errors out at graph
  # construction time, instead of returning an empty tensor. Preventing the
  # graph construction error aligns the semantics with Python.

  # TODO(mdan): We should optimize this when a full tensor is not required.
  if step is not UNSPECIFIED:
    # TODO(mdan): Add argument coercion similar to other cases.
    return math_ops.range(start_or_stop, stop, step)
  if stop is not UNSPECIFIED:
    # Clamp so that stop is never below start.
    stop = math_ops.maximum(start_or_stop, stop)
    return math_ops.range(start_or_stop, stop)
  # Single-argument form: clamp the limit at zero.
  start_or_stop = math_ops.maximum(start_or_stop, 0)
  return math_ops.range(start_or_stop)


def _py_range(start_or_stop, stop, step):
  """Overload of range_ for plain Python values."""
  if step is not UNSPECIFIED:
    return range(start_or_stop, stop, step)
  if stop is not UNSPECIFIED:
    return range(start_or_stop, stop)
  return range(start_or_stop)


def enumerate_(s, start=0):
  """Overload of enumerate: handles datasets and Python iterables.

  Raises:
    NotImplementedError: if `s` is a distributed iterator or dataset.
  """
  if isinstance(s, dataset_ops.DatasetV2):
    return _tf_dataset_enumerate(s, start)
  if isinstance(s,
                (input_lib.DistributedIterator, input_lib.DistributedDataset)):
    raise NotImplementedError(
        'use a for loop over the dataset and keep a separate counter')
  return _py_enumerate(s, start)


def _tf_dataset_enumerate(s, start=0):
  """Overload of enumerate_ for datasets."""
  return s.enumerate(start)


def _py_enumerate(s, start=0):
  """Overload of enumerate_ for plain Python iterables."""
  return enumerate(s, start)


def zip_(*iterables):
  """Overload of zip: uses the dataset path only if every arg is a dataset."""
  # `all` of an empty sequence is True, so require at least one iterable before
  # taking the dataset path; a bare zip_() defers to the Python builtin.
  if iterables and all(
      isinstance(x, dataset_ops.DatasetV2) for x in iterables):
    return _tf_dataset_zip(*iterables)
  return _py_zip(*iterables)


def _tf_dataset_zip(*iterables):
  """Overload of zip_ for datasets."""
  return dataset_ops.DatasetV2.zip(iterables)


def _py_zip(*iterables):
  """Overload of zip_ for plain Python iterables."""
  return zip(*iterables)


def map_(fn, *iterables):
  """Overload of map: uses the dataset path only if every arg is a dataset."""
  # `all` of an empty sequence is True, so require at least one iterable before
  # taking the dataset path; map_(fn) alone defers to the Python builtin.
  if iterables and all(
      isinstance(x, dataset_ops.DatasetV2) for x in iterables):
    return _tf_dataset_map(fn, *iterables)
  return _py_map(fn, *iterables)


def _tf_dataset_map(fn, *iterables):
  """Overload of map_ for datasets; zips the datasets, then maps `fn`."""
  return dataset_ops.DatasetV2.zip(iterables).map(fn)


def _py_map(fn, *iterables):
  """Overload of map_ for plain Python iterables."""
  return map(fn, *iterables)


def next_(iterator, default=UNSPECIFIED):
  """Overload of next: handles tf.data owned iterators and Python iterators."""
  if isinstance(iterator, iterator_ops.OwnedIterator):
    return next_tf_iterator(iterator, default)
  return next_py(iterator, default)


# TODO(mdan): These checks should be easier. Fix the nest API.
def _verify_spec_compatible(input_name, spec_name, input_, spec):
  """Verifies that a symbol has a type compatible with a given spec.

  Here, compatibility is viewed in the general TensorFlow sense: that the dtypes
  are the same after implicit conversion, if both are tensors.

  This verifier ensures consistent treatment of types across AutoGraph.

  Args:
    input_name: A name to use for `input_` in error messages.
    spec_name: A name to use for `spec` in error messages.
    input_: Any, value to verify.
    spec: TypeSpec that `input_` must be compatible with.

  Raises:
    ValueError: if `input_` is None.
    TypeError: if the dtype of `input_` differs from `spec.dtype`, or if
      `input_` has no dtype at all.
  """
  assert isinstance(spec, tensor_spec.TensorSpec)
  # Check the argument `input_`, not the `input` builtin (which is never None).
  if input_ is None:
    # TODO(mdan): raise from None when switching to Py3.
    raise ValueError('{} cannot be None'.format(input_name))

  # TODO(mdan): Use TensorCompatible when ready.
  if isinstance(input_, (bool, int, float, str, np.ndarray)):
    input_ = ops.convert_to_tensor_v2(input_)

  input_dtype = getattr(input_, 'dtype', None)

  if input_dtype != spec.dtype:
    input_dtype_str = 'no dtype' if input_dtype is None else str(input_dtype)

    raise TypeError(
        '{} must have the same dtype as {}. Expected {}, got {}'.format(
            input_name, spec_name, spec.dtype, input_dtype_str))


def _verify_structure_compatible(input_name, spec_name, input_, spec):
  """Verifies that possibly-structured symbol has types compatible with another.

  See _verify_spec_compatible for a more concrete meaning of "compatible".
  Unlike _verify_spec_compatible, which handles singular Tensor-spec objects,
  _verify_structure_compatible can process structures recognized by tf.nest.

  Args:
    input_name: A name to use for `input_` in error messages.
    spec_name: A name to use for `spec` in error messages.
    input_: Any, value to verify. May, but doesn't need to, be a structure.
    spec: Any, value that `input_` must be compatible with. May, but doesn't
      need to, be a structure.

  Raises:
    TypeError: if the two structures differ, or if a leaf fails the dtype
      check of _verify_spec_compatible.
    ValueError: if a leaf of `input_` is None.
  """
  try:
    nest.assert_same_structure(input_, spec, expand_composites=True)
  except (ValueError, TypeError) as e:
    raise TypeError(
        '{} must have the same element structure as {}.\n\n{}'.format(
            input_name, spec_name, str(e)))

  # Structures match; now check every leaf against its corresponding spec.
  nest.map_structure(
      functools.partial(_verify_spec_compatible, input_name, spec_name), input_,
      spec)


def next_tf_iterator(iterator, default=UNSPECIFIED):
  """Overload of next_ for tf.data iterators.

  Args:
    iterator: The iterator to advance.
    default: Optional value returned when the iterator is exhausted. It is
      verified against `iterator.element_spec`.

  Returns:
    The next element, or `default` if one was given and no element is left.
  """
  if default is UNSPECIFIED:
    # Without a default, fall back to the "normal" behavior which raises
    # a runtime exception.
    return next(iterator)
  opt_iterate = iterator.get_next_as_optional()
  _verify_structure_compatible('the default argument', 'the iterate', default,
                               iterator.element_spec)
  return control_flow_ops.cond(opt_iterate.has_value(), opt_iterate.get_value,
                               lambda: default)


def next_py(iterator, default=UNSPECIFIED):
  """Overload of next_ for plain Python iterators."""
  if default is UNSPECIFIED:
    return next(iterator)
  return next(iterator, default)


def filter_(function, iterable):
  """Overload of filter: handles datasets and Python iterables."""
  if isinstance(iterable, dataset_ops.DatasetV2):
    return _tf_dataset_filter(function, iterable)
  return _py_filter(function, iterable)


def _tf_dataset_filter(function, iterable):
  """Overload of filter_ for datasets."""
  return iterable.filter(function)


def _py_filter(function, iterable):
  """Overload of filter_ for plain Python iterables."""
  return filter(function, iterable)


def any_(iterable):
  """Overload of any: handles datasets and Python iterables."""
  if isinstance(iterable, dataset_ops.DatasetV2):
    return _tf_dataset_any(iterable)
  return _py_any(iterable)


# any() operation is essentially a "does a first True element exist" check.
# For that it could be translated to `filter(True)` to filter out
# only `True` element, and then `take(1)`. This works in tf.data
# as tf.data's filter+take is done in pipeline so it will stop
# as soon as `take(1)` returns.
def _tf_dataset_any(iterable):
  """Overload of any_ for datasets of bool elements.

  Raises:
    ValueError: if the element spec is not a single tf.bool component.
  """
  # check and make sure iterable.element_spec only consists of one
  # element of tf.bool.
  specs = nest.flatten(iterable.element_spec)
  if len(specs) != 1 or specs[0].dtype != dtypes.bool:
    raise ValueError('in graph mode, the "any" builtin only supports datasets '
                     'that return bool scalars; got: {}'.format(
                         iterable.element_spec))
  ds = iterable.filter(lambda x: x)
  ds = ds.take(1)
  # The initial state False is the result when no True element was found.
  ds = ds.reduce(constant_op.constant(False, dtype=dtypes.bool), lambda _, y: y)
  return ds


def _py_any(iterable):
  """Overload of any_ for plain Python iterables."""
  return any(iterable)


def all_(iterable):
  """Overload of all: handles datasets and Python iterables."""
  if isinstance(iterable, dataset_ops.DatasetV2):
    return _tf_dataset_all(iterable)
  return _py_all(iterable)


# all() operation is similar to any() and could be translated
# to `filter(False)` then `take(1)`, and check if `False` exists.
def _tf_dataset_all(iterable):
  """Overload of all_ for datasets of bool elements.

  Raises:
    ValueError: if the element spec is not a single tf.bool component.
  """
  # check and make sure iterable.element_spec only consists of one
  # element of tf.bool.
  specs = nest.flatten(iterable.element_spec)
  if len(specs) != 1 or specs[0].dtype != dtypes.bool:
    raise ValueError('in graph mode, the "all" builtin only supports datasets '
                     'that return bool scalars; got: {}'.format(
                         iterable.element_spec))
  ds = iterable.filter(lambda x: math_ops.logical_not(x))
  ds = ds.take(1)
  # The initial state True is the result when no False element was found.
  ds = ds.reduce(constant_op.constant(True, dtype=dtypes.bool), lambda _, y: y)
  return ds


def _py_all(iterable):
  """Overload of all_ for plain Python iterables."""
  return all(iterable)


def sorted_(iterable, key=UNSPECIFIED, reverse=UNSPECIFIED):
  """Overload of sorted: dispatches on tensors and Python iterables."""
  if tensor_util.is_tf_type(iterable):
    return _tf_sorted(iterable, key, reverse)
  return _py_sorted(iterable, key, reverse)


def _tf_sorted(iterable, key, reverse):
  """Overload of sorted_ for Tensor iterable.

  Args:
    iterable: Tensor to sort. Must be rank 1 (the key-mapped values must be
      rank 1 when `key` is given).
    key: Optional callable applied to each element via vectorized_map; the
      elements are then ordered by the mapped values.
    reverse: Optional. A truthy Python value sorts in descending order, a
      falsy one (or UNSPECIFIED) in ascending order.

  Returns:
    The sorted tensor.

  Raises:
    ValueError: if the static rank of the values being sorted is known and is
      not 1.
  """
  # Descending only when `reverse` was supplied and is truthy, so that
  # sorted(t, reverse=False) matches Python. A tensor-valued `reverse` is not
  # inspected here and, as before, selects descending order.
  if reverse is UNSPECIFIED or (not tensor_util.is_tf_type(reverse) and
                                not reverse):
    direction = 'ASCENDING'
  else:
    direction = 'DESCENDING'
  if key is not UNSPECIFIED:
    mapped = parallel_ops.vectorized_map(key, iterable)
    if mapped.shape.rank is not None and mapped.shape.rank != 1:
      raise ValueError('sort only supports only 1D tensors')
    with ops.control_dependencies([
        check_ops.assert_rank_v2(mapped, 1,
                                 'sort only supports only 1D tensors')
    ]):
      # Sort the keys, then reorder the original elements accordingly.
      order = sort_ops.argsort(mapped, direction=direction)
      return array_ops.gather_v2(iterable, order)
  if iterable.shape.rank is not None and iterable.shape.rank != 1:
    raise ValueError('sort only supports only 1D tensors')
  with ops.control_dependencies([
      check_ops.assert_rank_v2(iterable, 1,
                               'sort only supports only 1D tensors')
  ]):
    return sort_ops.sort(iterable, direction=direction)


def _py_sorted(iterable, key, reverse):
  """Overload of sorted_ for plain Python iterables.

  Forwards only the arguments that were actually supplied, so the builtin's
  own defaults apply to the rest.
  """
  kwargs = {}
  if key is not UNSPECIFIED:
    kwargs['key'] = key
  if reverse is not UNSPECIFIED:
    kwargs['reverse'] = reverse
  return sorted(iterable, **kwargs)


# Builtins that overload_of replaces with the overloads registered below.
# NOTE(review): 'next' has an entry in BUILTIN_FUNCTIONS_MAP but `next` is not
# listed here, and min_/max_ are registered in neither -- confirm whether that
# is intentional.
SUPPORTED_BUILTINS = (abs, float, int, len, print, range, enumerate, zip, map,
                      filter, any, all, sorted)

BUILTIN_FUNCTIONS_MAP = {
    'abs': abs_,
    'any': any_,
    'all': all_,
    'enumerate': enumerate_,
    'filter': filter_,
    'float': float_,
    'int': int_,
    'len': len_,
    'map': map_,
    'next': next_,
    'print': print_,
    'range': range_,
    'sorted': sorted_,
    'zip': zip_,
}