1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Mathematical operations.""" 16# pylint: disable=g-direct-tensorflow-import 17 18import numbers 19import sys 20 21import numpy as np 22 23from tensorflow.python.framework import constant_op 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import errors 26from tensorflow.python.framework import ops 27from tensorflow.python.ops import array_ops 28from tensorflow.python.ops import bitwise_ops 29from tensorflow.python.ops import clip_ops 30from tensorflow.python.ops import control_flow_ops 31from tensorflow.python.ops import gen_math_ops 32from tensorflow.python.ops import math_ops 33from tensorflow.python.ops import nn_ops 34from tensorflow.python.ops import sort_ops 35from tensorflow.python.ops import special_math_ops 36from tensorflow.python.ops.numpy_ops import np_array_ops 37from tensorflow.python.ops.numpy_ops import np_arrays 38from tensorflow.python.ops.numpy_ops import np_dtypes 39from tensorflow.python.ops.numpy_ops import np_export 40from tensorflow.python.ops.numpy_ops import np_utils 41 42 43pi = np_export.np_export_constant(__name__, 'pi', np.pi) 44e = np_export.np_export_constant(__name__, 'e', np.e) 45inf = np_export.np_export_constant(__name__, 'inf', np.inf) 46 47 48@np_utils.np_doc_only('dot') 49def dot(a, b): # pylint: disable=missing-docstring 50 51 def f(a, b): # pylint: disable=missing-docstring 52 return np_utils.cond( 53 np_utils.logical_or( 54 math_ops.equal(array_ops.rank(a), 0), 55 math_ops.equal(array_ops.rank(b), 0)), 56 lambda: a * b, 57 lambda: np_utils.cond( # pylint: disable=g-long-lambda 58 math_ops.equal(array_ops.rank(b), 1), 59 lambda: math_ops.tensordot(a, b, axes=[[-1], [-1]]), 60 lambda: math_ops.tensordot(a, b, axes=[[-1], [-2]]))) 61 62 return _bin_op(f, a, b) 63 64 65# TODO(wangpeng): Make element-wise ops `ufunc`s 66def _bin_op(tf_fun, a, b, promote=True): 67 if promote: 68 a, b = np_array_ops._promote_dtype_binary(a, b) # pylint: disable=protected-access 69 else: 70 a = np_array_ops.array(a) 71 b = np_array_ops.array(b) 72 return tf_fun(a, b) 73 74 75@np_utils.np_doc('add') 76def add(x1, x2): 77 78 def add_or_or(x1, x2): 79 if x1.dtype == dtypes.bool: 80 assert x2.dtype == dtypes.bool 81 return math_ops.logical_or(x1, x2) 82 return math_ops.add(x1, x2) 83 84 return _bin_op(add_or_or, x1, x2) 85 86 87@np_utils.np_doc('subtract') 88def subtract(x1, x2): 89 return _bin_op(math_ops.subtract, x1, x2) 90 91 92@np_utils.np_doc('multiply') 93def multiply(x1, x2): 94 95 def mul_or_and(x1, x2): 96 if x1.dtype == dtypes.bool: 97 assert x2.dtype == dtypes.bool 98 return math_ops.logical_and(x1, x2) 99 return math_ops.multiply(x1, x2) 100 101 return _bin_op(mul_or_and, x1, x2) 102 103 104@np_utils.np_doc('true_divide') 105def true_divide(x1, x2): # pylint: disable=missing-function-docstring 106 107 def _avoid_float64(x1, x2): 108 if x1.dtype == x2.dtype and x1.dtype in (dtypes.int32, dtypes.int64): 109 x1 = math_ops.cast(x1, dtype=dtypes.float32) 110 x2 = math_ops.cast(x2, dtype=dtypes.float32) 111 return x1, x2 112 113 def f(x1, x2): 114 if x1.dtype == dtypes.bool: 115 assert x2.dtype == dtypes.bool 116 float_ = np_dtypes.default_float_type() 117 x1 = math_ops.cast(x1, float_) 118 x2 = math_ops.cast(x2, float_) 119 if not np_dtypes.is_allow_float64(): 120 # math_ops.truediv in Python3 produces float64 when both inputs are int32 121 # or int64. We want to avoid that when is_allow_float64() is False. 122 x1, x2 = _avoid_float64(x1, x2) 123 return math_ops.truediv(x1, x2) 124 125 return _bin_op(f, x1, x2) 126 127 128@np_utils.np_doc('divide') 129def divide(x1, x2): # pylint: disable=missing-function-docstring 130 return true_divide(x1, x2) 131 132 133@np_utils.np_doc('floor_divide') 134def floor_divide(x1, x2): # pylint: disable=missing-function-docstring 135 136 def f(x1, x2): 137 if x1.dtype == dtypes.bool: 138 assert x2.dtype == dtypes.bool 139 x1 = math_ops.cast(x1, dtypes.int8) 140 x2 = math_ops.cast(x2, dtypes.int8) 141 return math_ops.floordiv(x1, x2) 142 143 return _bin_op(f, x1, x2) 144 145 146@np_utils.np_doc('mod') 147def mod(x1, x2): # pylint: disable=missing-function-docstring 148 149 def f(x1, x2): 150 if x1.dtype == dtypes.bool: 151 assert x2.dtype == dtypes.bool 152 x1 = math_ops.cast(x1, dtypes.int8) 153 x2 = math_ops.cast(x2, dtypes.int8) 154 return math_ops.mod(x1, x2) 155 156 return _bin_op(f, x1, x2) 157 158 159@np_utils.np_doc('remainder') 160def remainder(x1, x2): # pylint: disable=missing-function-docstring 161 return mod(x1, x2) 162 163 164@np_utils.np_doc('divmod') 165def divmod(x1, x2): # pylint: disable=redefined-builtin 166 return floor_divide(x1, x2), mod(x1, x2) 167 168 169@np_utils.np_doc('maximum') 170def maximum(x1, x2): # pylint: disable=missing-function-docstring 171 172 # Fast path for when maximum is used as relu. 173 if isinstance( 174 x2, numbers.Real) and not isinstance(x2, bool) and x2 == 0 and isinstance( 175 x1, np_arrays.ndarray) and x1.dtype != dtypes.bool: 176 return nn_ops.relu(np_array_ops.asarray(x1)) 177 178 def max_or_or(x1, x2): 179 if x1.dtype == dtypes.bool: 180 assert x2.dtype == dtypes.bool 181 return math_ops.logical_or(x1, x2) 182 return math_ops.maximum(x1, x2) 183 184 return _bin_op(max_or_or, x1, x2) 185 186 187@np_utils.np_doc('minimum') 188def minimum(x1, x2): 189 190 def min_or_and(x1, x2): 191 if x1.dtype == dtypes.bool: 192 assert x2.dtype == dtypes.bool 193 return math_ops.logical_and(x1, x2) 194 return math_ops.minimum(x1, x2) 195 196 return _bin_op(min_or_and, x1, x2) 197 198 199@np_utils.np_doc('clip') 200def clip(a, a_min, a_max): # pylint: disable=missing-docstring 201 if a_min is None and a_max is None: 202 raise ValueError('Not more than one of `a_min` and `a_max` may be `None`.') 203 if a_min is None: 204 return minimum(a, a_max) 205 elif a_max is None: 206 return maximum(a, a_min) 207 else: 208 a, a_min, a_max = np_array_ops._promote_dtype(a, a_min, a_max) # pylint: disable=protected-access 209 return clip_ops.clip_by_value(*np_utils.tf_broadcast(a, a_min, a_max)) 210 211 212@np_utils.np_doc('matmul') 213def matmul(x1, x2): # pylint: disable=missing-docstring 214 def f(x1, x2): 215 try: 216 if x1._rank() == 2 and x2._rank() == 2: # pylint: disable=protected-access 217 # Fast path for known ranks. 218 return gen_math_ops.mat_mul(x1, x2) 219 return np_utils.cond( 220 math_ops.equal(np_utils.tf_rank(x2), 1), 221 lambda: math_ops.tensordot(x1, x2, axes=1), 222 lambda: np_utils.cond( # pylint: disable=g-long-lambda 223 math_ops.equal(np_utils.tf_rank(x1), 1), 224 lambda: math_ops.tensordot( # pylint: disable=g-long-lambda 225 x1, x2, axes=[[0], [-2]]), 226 lambda: math_ops.matmul(x1, x2))) 227 except errors.InvalidArgumentError as err: 228 raise ValueError(str(err)).with_traceback(sys.exc_info()[2]) 229 230 return _bin_op(f, x1, x2) 231 232 233# Exported so it can be called from Tensor.__matmul__. NumPy's matmul handles 234# batched matmul as well, so simply including promotion in TF's current 235# __matmul__ implementation was not sufficient. 236setattr(np_arrays.ndarray, '_matmul', matmul) 237 238 239@np_utils.np_doc('tensordot') 240def tensordot(a, b, axes=2): 241 return _bin_op(lambda a, b: math_ops.tensordot(a, b, axes=axes), a, b) 242 243 244@np_utils.np_doc_only('inner') 245def inner(a, b): # pylint: disable=missing-function-docstring 246 247 def f(a, b): 248 return np_utils.cond( 249 np_utils.logical_or( 250 math_ops.equal(array_ops.rank(a), 0), 251 math_ops.equal(array_ops.rank(b), 0)), lambda: a * b, 252 lambda: math_ops.tensordot(a, b, axes=[[-1], [-1]])) 253 254 return _bin_op(f, a, b) 255 256 257@np_utils.np_doc('cross') 258def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): # pylint: disable=missing-docstring 259 260 def f(a, b): # pylint: disable=missing-docstring 261 # We can't assign to captured variable `axisa`, so make a new variable 262 if axis is None: 263 axis_a = axisa 264 axis_b = axisb 265 axis_c = axisc 266 else: 267 axis_a = axis 268 axis_b = axis 269 axis_c = axis 270 if axis_a < 0: 271 axis_a = np_utils.add(axis_a, array_ops.rank(a)) 272 if axis_b < 0: 273 axis_b = np_utils.add(axis_b, array_ops.rank(b)) 274 275 def maybe_move_axis_to_last(a, axis): 276 277 def move_axis_to_last(a, axis): 278 return array_ops.transpose( 279 a, 280 array_ops.concat([ 281 math_ops.range(axis), 282 math_ops.range(axis + 1, array_ops.rank(a)), [axis] 283 ], 284 axis=0)) 285 286 return np_utils.cond(axis == np_utils.subtract(array_ops.rank(a), 1), 287 lambda: a, lambda: move_axis_to_last(a, axis)) 288 289 a = maybe_move_axis_to_last(a, axis_a) 290 b = maybe_move_axis_to_last(b, axis_b) 291 a_dim = np_utils.getitem(array_ops.shape(a), -1) 292 b_dim = np_utils.getitem(array_ops.shape(b), -1) 293 294 def maybe_pad_0(a, size_of_last_dim): 295 296 def pad_0(a): 297 return array_ops.pad( 298 a, 299 array_ops.concat([ 300 array_ops.zeros([array_ops.rank(a) - 1, 2], dtypes.int32), 301 constant_op.constant([[0, 1]], dtypes.int32) 302 ], 303 axis=0)) 304 305 return np_utils.cond( 306 math_ops.equal(size_of_last_dim, 2), lambda: pad_0(a), lambda: a) 307 308 a = maybe_pad_0(a, a_dim) 309 b = maybe_pad_0(b, b_dim) 310 c = math_ops.cross(*np_utils.tf_broadcast(a, b)) 311 if axis_c < 0: 312 axis_c = np_utils.add(axis_c, array_ops.rank(c)) 313 314 def move_last_to_axis(a, axis): 315 r = array_ops.rank(a) 316 return array_ops.transpose( 317 a, 318 array_ops.concat( 319 [math_ops.range(axis), [r - 1], 320 math_ops.range(axis, r - 1)], 321 axis=0)) 322 323 c = np_utils.cond( 324 (a_dim == 2) & (b_dim == 2), 325 lambda: c[..., 2], 326 lambda: np_utils.cond( # pylint: disable=g-long-lambda 327 axis_c == np_utils.subtract(array_ops.rank(c), 1), lambda: c, 328 lambda: move_last_to_axis(c, axis_c))) 329 return c 330 331 return _bin_op(f, a, b) 332 333 334@np_utils.np_doc_only('vdot') 335def vdot(a, b): # pylint: disable=missing-docstring 336 a, b = np_array_ops._promote_dtype(a, b) # pylint: disable=protected-access 337 a = np_array_ops.reshape(a, [-1]) 338 b = np_array_ops.reshape(b, [-1]) 339 if a.dtype == np_dtypes.complex128 or a.dtype == np_dtypes.complex64: 340 a = conj(a) 341 return dot(a, b) 342 343 344@np_utils.np_doc('power') 345def power(x1, x2): 346 return _bin_op(math_ops.pow, x1, x2) 347 348 349@np_utils.np_doc('float_power') 350def float_power(x1, x2): 351 return power(x1, x2) 352 353 354@np_utils.np_doc('arctan2') 355def arctan2(x1, x2): 356 return _bin_op(math_ops.atan2, x1, x2) 357 358 359@np_utils.np_doc('nextafter') 360def nextafter(x1, x2): 361 return _bin_op(math_ops.nextafter, x1, x2) 362 363 364@np_utils.np_doc('heaviside') 365def heaviside(x1, x2): # pylint: disable=missing-function-docstring 366 367 def f(x1, x2): 368 return array_ops.where_v2( 369 x1 < 0, constant_op.constant(0, dtype=x2.dtype), 370 array_ops.where_v2(x1 > 0, constant_op.constant(1, dtype=x2.dtype), x2)) 371 372 y = _bin_op(f, x1, x2) 373 if not np.issubdtype(y.dtype.as_numpy_dtype, np.inexact): 374 y = y.astype(np_dtypes.default_float_type()) 375 return y 376 377 378@np_utils.np_doc('hypot') 379def hypot(x1, x2): 380 return sqrt(square(x1) + square(x2)) 381 382 383@np_utils.np_doc('kron') 384def kron(a, b): # pylint: disable=missing-function-docstring 385 # pylint: disable=protected-access,g-complex-comprehension 386 a, b = np_array_ops._promote_dtype(a, b) 387 t_a = np_utils.cond( 388 a.ndim < b.ndim, 389 lambda: np_array_ops.reshape( # pylint: disable=g-long-lambda 390 a, np_array_ops._pad_left_to(b.ndim, a.shape)), 391 lambda: a) 392 t_b = np_utils.cond( 393 b.ndim < a.ndim, 394 lambda: np_array_ops.reshape( # pylint: disable=g-long-lambda 395 b, np_array_ops._pad_left_to(a.ndim, b.shape)), 396 lambda: b) 397 398 def _make_shape(shape, prepend): 399 ones = array_ops.ones_like(shape) 400 if prepend: 401 shapes = [ones, shape] 402 else: 403 shapes = [shape, ones] 404 return array_ops.reshape(array_ops.stack(shapes, axis=1), [-1]) 405 406 a_shape = array_ops.shape(t_a) 407 b_shape = array_ops.shape(t_b) 408 a_reshaped = np_array_ops.reshape(t_a, _make_shape(a_shape, False)) 409 b_reshaped = np_array_ops.reshape(t_b, _make_shape(b_shape, True)) 410 out_shape = a_shape * b_shape 411 return np_array_ops.reshape(a_reshaped * b_reshaped, out_shape) 412 413 414@np_utils.np_doc('outer') 415def outer(a, b): 416 417 def f(a, b): 418 return array_ops.reshape(a, [-1, 1]) * array_ops.reshape(b, [-1]) 419 420 return _bin_op(f, a, b) 421 422 423# This can also be implemented via tf.reduce_logsumexp 424@np_utils.np_doc('logaddexp') 425def logaddexp(x1, x2): 426 amax = maximum(x1, x2) 427 delta = x1 - x2 428 return np_array_ops.where( 429 isnan(delta), 430 x1 + x2, # NaNs or infinities of the same sign. 431 amax + log1p(exp(-abs(delta)))) 432 433 434@np_utils.np_doc('logaddexp2') 435def logaddexp2(x1, x2): 436 amax = maximum(x1, x2) 437 delta = x1 - x2 438 return np_array_ops.where( 439 isnan(delta), 440 x1 + x2, # NaNs or infinities of the same sign. 441 amax + log1p(exp2(-abs(delta))) / np.log(2)) 442 443 444@np_utils.np_doc('polyval') 445def polyval(p, x): # pylint: disable=missing-function-docstring 446 447 def f(p, x): 448 if p.shape.rank == 0: 449 p = array_ops.reshape(p, [1]) 450 p = array_ops.unstack(p) 451 # TODO(wangpeng): Make tf version take a tensor for p instead of a list. 452 y = math_ops.polyval(p, x) 453 # If the polynomial is 0-order, numpy requires the result to be broadcast to 454 # `x`'s shape. 455 if len(p) == 1: 456 y = array_ops.broadcast_to(y, x.shape) 457 return y 458 459 return _bin_op(f, p, x) 460 461 462@np_utils.np_doc('isclose') 463def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): # pylint: disable=missing-docstring 464 465 def f(a, b): # pylint: disable=missing-docstring 466 dtype = a.dtype 467 if np.issubdtype(dtype.as_numpy_dtype, np.inexact): 468 rtol_ = ops.convert_to_tensor(rtol, dtype.real_dtype) 469 atol_ = ops.convert_to_tensor(atol, dtype.real_dtype) 470 result = (math_ops.abs(a - b) <= atol_ + rtol_ * math_ops.abs(b)) 471 if equal_nan: 472 result = result | (math_ops.is_nan(a) & math_ops.is_nan(b)) 473 return result 474 else: 475 return a == b 476 477 return _bin_op(f, a, b) 478 479 480@np_utils.np_doc('allclose') 481def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): 482 return np_array_ops.all( 483 isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)) 484 485 486def _tf_gcd(x1, x2): # pylint: disable=missing-function-docstring 487 488 def _gcd_cond_fn(_, x2): 489 return math_ops.reduce_any(x2 != 0) 490 491 def _gcd_body_fn(x1, x2): 492 # math_ops.mod will raise an error when any element of x2 is 0. To avoid 493 # that, we change those zeros to ones. Their values don't matter because 494 # they won't be used. 495 x2_safe = array_ops.where_v2(x2 != 0, x2, constant_op.constant(1, x2.dtype)) 496 x1, x2 = (array_ops.where_v2(x2 != 0, x2, x1), 497 array_ops.where_v2(x2 != 0, math_ops.mod(x1, x2_safe), 498 constant_op.constant(0, x2.dtype))) 499 return (array_ops.where_v2(x1 < x2, x2, 500 x1), array_ops.where_v2(x1 < x2, x1, x2)) 501 502 if (not np.issubdtype(x1.dtype.as_numpy_dtype, np.integer) or 503 not np.issubdtype(x2.dtype.as_numpy_dtype, np.integer)): 504 raise ValueError('Arguments to gcd must be integers.') 505 shape = array_ops.broadcast_dynamic_shape( 506 array_ops.shape(x1), array_ops.shape(x2)) 507 x1 = array_ops.broadcast_to(x1, shape) 508 x2 = array_ops.broadcast_to(x2, shape) 509 value, _ = control_flow_ops.while_loop(_gcd_cond_fn, _gcd_body_fn, 510 (math_ops.abs(x1), math_ops.abs(x2))) 511 return value 512 513 514# Note that np.gcd may not be present in some supported versions of numpy. 515@np_utils.np_doc('gcd') 516def gcd(x1, x2): 517 return _bin_op(_tf_gcd, x1, x2) 518 519 520# Note that np.lcm may not be present in some supported versions of numpy. 521@np_utils.np_doc('lcm') 522def lcm(x1, x2): # pylint: disable=missing-function-docstring 523 524 def f(x1, x2): 525 d = _tf_gcd(x1, x2) 526 # Same as the `x2_safe` trick above 527 d_safe = array_ops.where_v2( 528 math_ops.equal(d, 0), constant_op.constant(1, d.dtype), d) 529 return array_ops.where_v2( 530 math_ops.equal(d, 0), constant_op.constant(0, d.dtype), 531 math_ops.abs(x1 * x2) // d_safe) 532 533 return _bin_op(f, x1, x2) 534 535 536def _bitwise_binary_op(tf_fn, x1, x2): # pylint: disable=missing-function-docstring 537 538 def f(x1, x2): 539 is_bool = (x1.dtype == dtypes.bool) 540 if is_bool: 541 assert x2.dtype == dtypes.bool 542 x1 = math_ops.cast(x1, dtypes.int8) 543 x2 = math_ops.cast(x2, dtypes.int8) 544 r = tf_fn(x1, x2) 545 if is_bool: 546 r = math_ops.cast(r, dtypes.bool) 547 return r 548 549 return _bin_op(f, x1, x2) 550 551 552@np_utils.np_doc('bitwise_and') 553def bitwise_and(x1, x2): 554 return _bitwise_binary_op(bitwise_ops.bitwise_and, x1, x2) 555 556 557@np_utils.np_doc('bitwise_or') 558def bitwise_or(x1, x2): 559 return _bitwise_binary_op(bitwise_ops.bitwise_or, x1, x2) 560 561 562@np_utils.np_doc('bitwise_xor') 563def bitwise_xor(x1, x2): 564 return _bitwise_binary_op(bitwise_ops.bitwise_xor, x1, x2) 565 566 567@np_utils.np_doc('bitwise_not', link=np_utils.AliasOf('invert')) 568def bitwise_not(x): 569 570 def f(x): 571 if x.dtype == dtypes.bool: 572 return math_ops.logical_not(x) 573 return bitwise_ops.invert(x) 574 575 return _scalar(f, x) 576 577 578def _scalar(tf_fn, x, promote_to_float=False): 579 """Computes the tf_fn(x) for each element in `x`. 580 581 Args: 582 tf_fn: function that takes a single Tensor argument. 583 x: array_like. Could be an ndarray, a Tensor or any object that can be 584 converted to a Tensor using `ops.convert_to_tensor`. 585 promote_to_float: whether to cast the argument to a float dtype 586 (`np_dtypes.default_float_type`) if it is not already. 587 588 Returns: 589 An ndarray with the same shape as `x`. The default output dtype is 590 determined by `np_dtypes.default_float_type`, unless x is an ndarray with a 591 floating point type, in which case the output type is same as x.dtype. 592 """ 593 x = np_array_ops.asarray(x) 594 if promote_to_float and not np.issubdtype(x.dtype.as_numpy_dtype, np.inexact): 595 x = x.astype(np_dtypes.default_float_type()) 596 return tf_fn(x) 597 598 599@np_utils.np_doc('log') 600def log(x): 601 return _scalar(math_ops.log, x, True) 602 603 604@np_utils.np_doc('exp') 605def exp(x): 606 return _scalar(math_ops.exp, x, True) 607 608 609@np_utils.np_doc('sqrt') 610def sqrt(x): 611 return _scalar(math_ops.sqrt, x, True) 612 613 614@np_utils.np_doc('abs', link=np_utils.AliasOf('absolute')) 615def abs(x): # pylint: disable=redefined-builtin 616 return _scalar(math_ops.abs, x) 617 618 619@np_utils.np_doc('absolute') 620def absolute(x): 621 return abs(x) 622 623 624@np_utils.np_doc('fabs') 625def fabs(x): 626 return abs(x) 627 628 629@np_utils.np_doc('ceil') 630def ceil(x): 631 return _scalar(math_ops.ceil, x, True) 632 633 634@np_utils.np_doc('floor') 635def floor(x): 636 return _scalar(math_ops.floor, x, True) 637 638 639@np_utils.np_doc('conj') 640def conj(x): 641 return _scalar(math_ops.conj, x) 642 643 644@np_utils.np_doc('negative') 645def negative(x): 646 return _scalar(math_ops.negative, x) 647 648 649@np_utils.np_doc('reciprocal') 650def reciprocal(x): 651 return _scalar(math_ops.reciprocal, x) 652 653 654@np_utils.np_doc('signbit') 655def signbit(x): 656 657 def f(x): 658 if x.dtype == dtypes.bool: 659 return array_ops.fill(array_ops.shape(x), False) 660 return x < 0 661 662 return _scalar(f, x) 663 664 665@np_utils.np_doc('sin') 666def sin(x): 667 return _scalar(math_ops.sin, x, True) 668 669 670@np_utils.np_doc('cos') 671def cos(x): 672 return _scalar(math_ops.cos, x, True) 673 674 675@np_utils.np_doc('tan') 676def tan(x): 677 return _scalar(math_ops.tan, x, True) 678 679 680@np_utils.np_doc('sinh') 681def sinh(x): 682 return _scalar(math_ops.sinh, x, True) 683 684 685@np_utils.np_doc('cosh') 686def cosh(x): 687 return _scalar(math_ops.cosh, x, True) 688 689 690@np_utils.np_doc('tanh') 691def tanh(x): 692 return _scalar(math_ops.tanh, x, True) 693 694 695@np_utils.np_doc('arcsin') 696def arcsin(x): 697 return _scalar(math_ops.asin, x, True) 698 699 700@np_utils.np_doc('arccos') 701def arccos(x): 702 return _scalar(math_ops.acos, x, True) 703 704 705@np_utils.np_doc('arctan') 706def arctan(x): 707 return _scalar(math_ops.atan, x, True) 708 709 710@np_utils.np_doc('arcsinh') 711def arcsinh(x): 712 return _scalar(math_ops.asinh, x, True) 713 714 715@np_utils.np_doc('arccosh') 716def arccosh(x): 717 return _scalar(math_ops.acosh, x, True) 718 719 720@np_utils.np_doc('arctanh') 721def arctanh(x): 722 return _scalar(math_ops.atanh, x, True) 723 724 725@np_utils.np_doc('deg2rad') 726def deg2rad(x): 727 728 def f(x): 729 return x * (np.pi / 180.0) 730 731 return _scalar(f, x, True) 732 733 734@np_utils.np_doc('rad2deg') 735def rad2deg(x): 736 return x * (180.0 / np.pi) 737 738 739_tf_float_types = [ 740 dtypes.bfloat16, dtypes.float16, dtypes.float32, dtypes.float64 741] 742 743 744@np_utils.np_doc('angle') 745def angle(z, deg=False): # pylint: disable=missing-function-docstring 746 747 def f(x): 748 if x.dtype in _tf_float_types: 749 # Workaround for b/147515503 750 return array_ops.where_v2(x < 0, np.pi, 0) 751 else: 752 return math_ops.angle(x) 753 754 y = _scalar(f, z, True) 755 if deg: 756 y = rad2deg(y) 757 return y 758 759 760@np_utils.np_doc('cbrt') 761def cbrt(x): 762 763 def f(x): 764 # __pow__ can't handle negative base, so we use `abs` here. 765 rt = math_ops.abs(x)**(1.0 / 3) 766 return array_ops.where_v2(x < 0, -rt, rt) 767 768 return _scalar(f, x, True) 769 770 771@np_utils.np_doc('conjugate', link=np_utils.AliasOf('conj')) 772def conjugate(x): 773 return _scalar(math_ops.conj, x) 774 775 776@np_utils.np_doc('exp2') 777def exp2(x): 778 779 def f(x): 780 return 2**x 781 782 return _scalar(f, x, True) 783 784 785@np_utils.np_doc('expm1') 786def expm1(x): 787 return _scalar(math_ops.expm1, x, True) 788 789 790@np_utils.np_doc('fix') 791def fix(x): 792 793 def f(x): 794 return array_ops.where_v2(x < 0, math_ops.ceil(x), math_ops.floor(x)) 795 796 return _scalar(f, x, True) 797 798 799@np_utils.np_doc('iscomplex') 800def iscomplex(x): 801 return np_array_ops.imag(x) != 0 802 803 804@np_utils.np_doc('isreal') 805def isreal(x): 806 return np_array_ops.imag(x) == 0 807 808 809@np_utils.np_doc('iscomplexobj') 810def iscomplexobj(x): 811 x = np_array_ops.array(x) 812 return np.issubdtype(x.dtype.as_numpy_dtype, np.complexfloating) 813 814 815@np_utils.np_doc('isrealobj') 816def isrealobj(x): 817 return not iscomplexobj(x) 818 819 820@np_utils.np_doc('isnan') 821def isnan(x): 822 return _scalar(math_ops.is_nan, x, True) 823 824 825def _make_nan_reduction(np_fun_name, reduction, init_val): 826 """Helper to generate nan* functions.""" 827 828 @np_utils.np_doc(np_fun_name) 829 def nan_reduction(a, axis=None, dtype=None, keepdims=False): 830 a = np_array_ops.array(a) 831 v = np_array_ops.array(init_val, dtype=a.dtype) 832 return reduction( 833 np_array_ops.where(isnan(a), v, a), 834 axis=axis, 835 dtype=dtype, 836 keepdims=keepdims) 837 838 return nan_reduction 839 840 841nansum = _make_nan_reduction('nansum', np_array_ops.sum, 0) 842nanprod = _make_nan_reduction('nanprod', np_array_ops.prod, 1) 843 844 845@np_utils.np_doc('nanmean') 846def nanmean(a, axis=None, dtype=None, keepdims=None): # pylint: disable=missing-docstring 847 a = np_array_ops.array(a) 848 if np.issubdtype(a.dtype.as_numpy_dtype, np.bool_) or np.issubdtype( 849 a.dtype.as_numpy_dtype, np.integer): 850 return np_array_ops.mean(a, axis=axis, dtype=dtype, keepdims=keepdims) 851 nan_mask = logical_not(isnan(a)) 852 if dtype is None: 853 dtype = a.dtype.as_numpy_dtype 854 normalizer = np_array_ops.sum( 855 nan_mask, axis=axis, dtype=dtype, keepdims=keepdims) 856 return nansum(a, axis=axis, dtype=dtype, keepdims=keepdims) / normalizer 857 858 859@np_utils.np_doc('isfinite') 860def isfinite(x): 861 return _scalar(math_ops.is_finite, x, True) 862 863 864@np_utils.np_doc('isinf') 865def isinf(x): 866 return _scalar(math_ops.is_inf, x, True) 867 868 869@np_utils.np_doc('isneginf') 870def isneginf(x): 871 return x == np_array_ops.full_like(x, -np.inf) 872 873 874@np_utils.np_doc('isposinf') 875def isposinf(x): 876 return x == np_array_ops.full_like(x, np.inf) 877 878 879@np_utils.np_doc('log2') 880def log2(x): 881 return log(x) / np.log(2) 882 883 884@np_utils.np_doc('log10') 885def log10(x): 886 return log(x) / np.log(10) 887 888 889@np_utils.np_doc('log1p') 890def log1p(x): 891 return _scalar(math_ops.log1p, x, True) 892 893 894@np_utils.np_doc('positive') 895def positive(x): 896 return _scalar(lambda x: x, x) 897 898 899@np_utils.np_doc('sinc') 900def sinc(x): 901 902 def f(x): 903 pi_x = x * np.pi 904 return array_ops.where_v2(x == 0, array_ops.ones_like(x), 905 math_ops.sin(pi_x) / pi_x) 906 907 return _scalar(f, x, True) 908 909 910@np_utils.np_doc('square') 911def square(x): 912 return _scalar(math_ops.square, x) 913 914 915@np_utils.np_doc('diff') 916def diff(a, n=1, axis=-1): # pylint: disable=missing-function-docstring 917 918 def f(a): 919 # TODO(agarwal): transpose and reshape to N, H, 1 and do a 1D convolution 920 # TODO(agarwal): avoid depending on static rank. 921 nd = a.shape.rank 922 if nd is None: 923 raise ValueError( 924 'Function `diff` currently requires a known rank for input `a`. ' 925 f'Received: a={a} (unknown rank)') 926 if (axis + nd if axis < 0 else axis) >= nd: 927 raise ValueError( 928 f'Argument `axis` (received axis={axis}) is out of bounds ' 929 f'for input {a} of rank {nd}.') 930 if n < 0: 931 raise ValueError('Argument `order` must be a non-negative integer. ' 932 f'Received: axis={n}') 933 slice1 = [slice(None)] * nd 934 slice2 = [slice(None)] * nd 935 slice1[axis] = slice(1, None) 936 slice2[axis] = slice(None, -1) 937 slice1 = tuple(slice1) 938 slice2 = tuple(slice2) 939 op = math_ops.not_equal if a.dtype == dtypes.bool else math_ops.subtract 940 for _ in range(n): 941 a = op(a[slice1], a[slice2]) 942 return a 943 944 return _scalar(f, a) 945 946 947def _wrap(f, reverse=False): 948 """Wraps binary ops so they can be added as operator overloads on ndarray.""" 949 950 def _f(a, b): 951 if reverse: 952 a, b = b, a 953 954 if getattr(b, '__array_priority__', 955 0) > np_arrays.ndarray.__array_priority__: 956 return NotImplemented 957 958 return f(a, b) 959 960 return _f 961 962 963def _comparison(tf_fun, x1, x2, cast_bool_to_int=False): 964 """Helper function for comparision.""" 965 dtype = np_utils.result_type(x1, x2) 966 # Cast x1 and x2 to the result_type if needed. 967 x1 = np_array_ops.array(x1, dtype=dtype) 968 x2 = np_array_ops.array(x2, dtype=dtype) 969 if cast_bool_to_int and x1.dtype == dtypes.bool: 970 x1 = math_ops.cast(x1, dtypes.int32) 971 x2 = math_ops.cast(x2, dtypes.int32) 972 return tf_fun(x1, x2) 973 974 975@np_utils.np_doc('equal') 976def equal(x1, x2): 977 return _comparison(math_ops.equal, x1, x2) 978 979 980@np_utils.np_doc('not_equal') 981def not_equal(x1, x2): 982 return _comparison(math_ops.not_equal, x1, x2) 983 984 985@np_utils.np_doc('greater') 986def greater(x1, x2): 987 return _comparison(math_ops.greater, x1, x2, True) 988 989 990@np_utils.np_doc('greater_equal') 991def greater_equal(x1, x2): 992 return _comparison(math_ops.greater_equal, x1, x2, True) 993 994 995@np_utils.np_doc('less') 996def less(x1, x2): 997 return _comparison(math_ops.less, x1, x2, True) 998 999 1000@np_utils.np_doc('less_equal') 1001def less_equal(x1, x2): 1002 return _comparison(math_ops.less_equal, x1, x2, True) 1003 1004 1005@np_utils.np_doc('array_equal') 1006def array_equal(a1, a2): # pylint: disable=missing-function-docstring 1007 1008 def f(x1, x2): 1009 return np_utils.cond( 1010 math_ops.equal(array_ops.rank(x1), array_ops.rank(x2)), 1011 lambda: np_utils.cond( # pylint: disable=g-long-lambda 1012 np_utils.reduce_all( 1013 math_ops.equal(array_ops.shape(x1), array_ops.shape(x2)) 1014 ), 1015 lambda: math_ops.reduce_all(math_ops.equal(x1, x2)), 1016 lambda: constant_op.constant(False)), 1017 lambda: constant_op.constant(False)) 1018 1019 return _comparison(f, a1, a2) 1020 1021 1022def _logical_binary_op(tf_fun, x1, x2): 1023 x1 = np_array_ops.array(x1, dtype=np.bool_) 1024 x2 = np_array_ops.array(x2, dtype=np.bool_) 1025 return tf_fun(x1, x2) 1026 1027 1028@np_utils.np_doc('logical_and') 1029def logical_and(x1, x2): 1030 return _logical_binary_op(math_ops.logical_and, x1, x2) 1031 1032 1033@np_utils.np_doc('logical_or') 1034def logical_or(x1, x2): 1035 return _logical_binary_op(math_ops.logical_or, x1, x2) 1036 1037 1038@np_utils.np_doc('logical_xor') 1039def logical_xor(x1, x2): 1040 return _logical_binary_op(math_ops.logical_xor, x1, x2) 1041 1042 1043@np_utils.np_doc('logical_not') 1044def logical_not(x): 1045 x = np_array_ops.array(x, dtype=np.bool_) 1046 return math_ops.logical_not(x) 1047 1048 1049@np_utils.np_doc('linspace') 1050def linspace( # pylint: disable=missing-docstring 1051 start, 1052 stop, 1053 num=50, 1054 endpoint=True, 1055 retstep=False, 1056 dtype=float, 1057 axis=0): 1058 if dtype: 1059 dtype = np_utils.result_type(dtype) 1060 start = np_array_ops.array(start, dtype=dtype) 1061 stop = np_array_ops.array(stop, dtype=dtype) 1062 if num < 0: 1063 raise ValueError( 1064 'Argument `num` (number of samples) must be a non-negative integer. ' 1065 f'Received: num={num}') 1066 step = ops.convert_to_tensor(np.nan) 1067 if endpoint: 1068 result = math_ops.linspace(start, stop, num, axis=axis) 1069 if num > 1: 1070 step = (stop - start) / (num - 1) 1071 else: 1072 # math_ops.linspace does not support endpoint=False so we manually handle it 1073 # here. 1074 if num > 0: 1075 step = ((stop - start) / num) 1076 if num > 1: 1077 new_stop = math_ops.cast(stop, step.dtype) - step 1078 start = math_ops.cast(start, new_stop.dtype) 1079 result = math_ops.linspace(start, new_stop, num, axis=axis) 1080 else: 1081 result = math_ops.linspace(start, stop, num, axis=axis) 1082 if dtype: 1083 if dtype.is_integer: 1084 # Since numpy 1.20, linspace's rounding is towards -inf instead of 0 1085 result = math_ops.floor(result) 1086 result = math_ops.cast(result, dtype) 1087 if retstep: 1088 return (result, step) 1089 else: 1090 return result 1091 1092 1093@np_utils.np_doc('logspace') 1094def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0): 1095 dtype = np_utils.result_type(start, stop, dtype) 1096 result = linspace( 1097 start, stop, num=num, endpoint=endpoint, dtype=dtype, axis=axis) 1098 result = math_ops.pow(math_ops.cast(base, result.dtype), result) 1099 if dtype: 1100 result = math_ops.cast(result, dtype) 1101 return result 1102 1103 1104@np_utils.np_doc('geomspace') 1105def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0): # pylint: disable=missing-docstring 1106 dtype = dtypes.as_dtype(dtype) if dtype else np_utils.result_type( 1107 start, stop, float(num), np_array_ops.zeros((), dtype)) 1108 computation_dtype = np.promote_types(dtype.as_numpy_dtype, np.float32) 1109 start = np_array_ops.asarray(start, dtype=computation_dtype) 1110 stop = np_array_ops.asarray(stop, dtype=computation_dtype) 1111 # follow the numpy geomspace convention for negative and complex endpoints 1112 start_sign = 1 - np_array_ops.sign(np_array_ops.real(start)) 1113 stop_sign = 1 - np_array_ops.sign(np_array_ops.real(stop)) 1114 signflip = 1 - start_sign * stop_sign // 2 1115 res = signflip * logspace( 1116 log10(signflip * start), 1117 log10(signflip * stop), 1118 num, 1119 endpoint=endpoint, 1120 base=10.0, 1121 dtype=computation_dtype, 1122 axis=0) 1123 if axis != 0: 1124 res = np_array_ops.moveaxis(res, 0, axis) 1125 return math_ops.cast(res, dtype) 1126 1127 1128@np_utils.np_doc('ptp') 1129def ptp(a, axis=None, keepdims=None): 1130 return (np_array_ops.amax(a, axis=axis, keepdims=keepdims) - 1131 np_array_ops.amin(a, axis=axis, keepdims=keepdims)) 1132 1133 1134@np_utils.np_doc_only('concatenate') 1135def concatenate(arys, axis=0): 1136 if not isinstance(arys, (list, tuple)): 1137 arys = [arys] 1138 if not arys: 1139 raise ValueError('Need at least one array to concatenate. Received empty ' 1140 f'input: arys={arys}') 1141 dtype = np_utils.result_type(*arys) 1142 arys = [np_array_ops.array(array, dtype=dtype) for array in arys] 1143 return array_ops.concat(arys, axis) 1144 1145 1146@np_utils.np_doc_only('tile') 1147def tile(a, reps): # pylint: disable=missing-function-docstring 1148 a = np_array_ops.array(a) 1149 reps = np_array_ops.array(reps, dtype=dtypes.int32).reshape([-1]) 1150 1151 a_rank = array_ops.rank(a) 1152 reps_size = array_ops.size(reps) 1153 reps = array_ops.pad( 1154 reps, [[math_ops.maximum(a_rank - reps_size, 0), 0]], constant_values=1) 1155 a_shape = array_ops.pad( 1156 array_ops.shape(a), [[math_ops.maximum(reps_size - a_rank, 0), 0]], 1157 constant_values=1) 1158 a = array_ops.reshape(a, a_shape) 1159 1160 return array_ops.tile(a, reps) 1161 1162 1163@np_utils.np_doc('count_nonzero') 1164def count_nonzero(a, axis=None): 1165 return math_ops.count_nonzero(np_array_ops.array(a), axis) 1166 1167 1168@np_utils.np_doc('argsort') 1169def argsort(a, axis=-1, kind='quicksort', order=None): # pylint: disable=missing-docstring 1170 # TODO(nareshmodi): make string tensors also work. 1171 if kind not in ('quicksort', 'stable'): 1172 raise ValueError( 1173 'Invalid value for argument `kind`. ' 1174 'Only kind="quicksort" and kind="stable" are supported. ' 1175 f'Received: kind={kind}') 1176 if order is not None: 1177 raise ValueError('The `order` argument is not supported. Pass order=None') 1178 stable = (kind == 'stable') 1179 1180 a = np_array_ops.array(a) 1181 1182 def _argsort(a, axis, stable): 1183 if axis is None: 1184 a = array_ops.reshape(a, [-1]) 1185 axis = 0 1186 1187 return sort_ops.argsort(a, axis, stable=stable) 1188 1189 tf_ans = np_utils.cond( 1190 math_ops.equal(array_ops.rank(a), 0), lambda: constant_op.constant([0]), 1191 lambda: _argsort(a, axis, stable)) 1192 1193 return np_array_ops.array(tf_ans, dtype=np.intp) 1194 1195 1196@np_utils.np_doc('sort') 1197def sort(a, axis=-1, kind='quicksort', order=None): # pylint: disable=missing-docstring 1198 if kind != 'quicksort': 1199 raise ValueError( 1200 'Invalid value for argument `kind`. ' 1201 'Only kind="quicksort" is supported. ' 1202 f'Received: kind={kind}') 1203 if order is not None: 1204 raise ValueError('The `order` argument is not supported. Pass order=None') 1205 1206 a = np_array_ops.array(a) 1207 1208 if axis is None: 1209 return sort_ops.sort(array_ops.reshape(a, [-1]), 0) 1210 else: 1211 return sort_ops.sort(a, axis) 1212 1213 1214def _argminmax(fn, a, axis=None): 1215 a = np_array_ops.array(a) 1216 if axis is None: 1217 # When axis is None numpy flattens the array. 1218 a_t = array_ops.reshape(a, [-1]) 1219 else: 1220 a_t = np_array_ops.atleast_1d(a) 1221 return fn(input=a_t, axis=axis) 1222 1223 1224@np_utils.np_doc('argmax') 1225def argmax(a, axis=None): 1226 return _argminmax(math_ops.argmax, a, axis) 1227 1228 1229@np_utils.np_doc('argmin') 1230def argmin(a, axis=None): 1231 return _argminmax(math_ops.argmin, a, axis) 1232 1233 1234@np_utils.np_doc('append') 1235def append(arr, values, axis=None): 1236 if axis is None: 1237 return concatenate([np_array_ops.ravel(arr), np_array_ops.ravel(values)], 0) 1238 else: 1239 return concatenate([arr, values], axis=axis) 1240 1241 1242@np_utils.np_doc('average') 1243def average(a, axis=None, weights=None, returned=False): # pylint: disable=missing-docstring 1244 if axis is not None and not isinstance(axis, int): 1245 # TODO(wangpeng): Support tuple of ints as `axis` 1246 raise ValueError('Argument `axis` must be an integer. ' 1247 f'Received axis={axis} (of type {type(axis)})') 1248 a = np_array_ops.array(a) 1249 if weights is None: # Treat all weights as 1 1250 if not np.issubdtype(a.dtype.as_numpy_dtype, np.inexact): 1251 a = a.astype( 1252 np_utils.result_type(a.dtype, np_dtypes.default_float_type())) 1253 avg = math_ops.reduce_mean(a, axis=axis) 1254 if returned: 1255 if axis is None: 1256 weights_sum = array_ops.size(a) 1257 else: 1258 weights_sum = array_ops.shape(a)[axis] 1259 weights_sum = math_ops.cast(weights_sum, a.dtype) 1260 else: 1261 if np.issubdtype(a.dtype.as_numpy_dtype, np.inexact): 1262 out_dtype = np_utils.result_type(a.dtype, weights) 1263 else: 1264 out_dtype = np_utils.result_type(a.dtype, weights, 1265 np_dtypes.default_float_type()) 1266 a = np_array_ops.array(a, out_dtype) 1267 weights = np_array_ops.array(weights, out_dtype) 1268 1269 def rank_equal_case(): 1270 control_flow_ops.Assert( 1271 math_ops.reduce_all(array_ops.shape(a) == array_ops.shape(weights)), 1272 [array_ops.shape(a), array_ops.shape(weights)]) 1273 weights_sum = math_ops.reduce_sum(weights, axis=axis) 1274 avg = math_ops.reduce_sum(a * weights, axis=axis) / weights_sum 1275 return avg, weights_sum 1276 1277 if axis is None: 1278 avg, weights_sum = rank_equal_case() 1279 else: 1280 1281 def rank_not_equal_case(): 1282 control_flow_ops.Assert( 1283 array_ops.rank(weights) == 1, [array_ops.rank(weights)]) 1284 weights_sum = math_ops.reduce_sum(weights) 1285 axes = ops.convert_to_tensor([[axis], [0]]) 1286 avg = math_ops.tensordot(a, weights, axes) / weights_sum 1287 return avg, weights_sum 1288 1289 # We condition on rank rather than shape equality, because if we do the 1290 # latter, when the shapes are partially unknown but the ranks are known 1291 # and different, np_utils.cond will run shape checking on the true branch, 1292 # which will raise a shape-checking error. 1293 avg, weights_sum = np_utils.cond( 1294 math_ops.equal(array_ops.rank(a), array_ops.rank(weights)), 1295 rank_equal_case, rank_not_equal_case) 1296 1297 avg = np_array_ops.array(avg) 1298 if returned: 1299 weights_sum = np_array_ops.broadcast_to(weights_sum, array_ops.shape(avg)) 1300 return avg, weights_sum 1301 return avg 1302 1303 1304@np_utils.np_doc('trace') 1305def trace(a, offset=0, axis1=0, axis2=1, dtype=None): # pylint: disable=missing-docstring 1306 if dtype: 1307 dtype = np_utils.result_type(dtype) 1308 a = np_array_ops.asarray(a, dtype) 1309 1310 if offset == 0: 1311 a_shape = a.shape 1312 if a_shape.rank is not None: 1313 rank = len(a_shape) 1314 if (axis1 == -2 or axis1 == rank - 2) and (axis2 == -1 or 1315 axis2 == rank - 1): 1316 return math_ops.trace(a) 1317 1318 a = np_array_ops.diagonal(a, offset, axis1, axis2) 1319 return np_array_ops.sum(a, -1, dtype) 1320 1321 1322@np_utils.np_doc('meshgrid') 1323def meshgrid(*xi, **kwargs): 1324 """This currently requires copy=True and sparse=False.""" 1325 sparse = kwargs.get('sparse', False) 1326 if sparse: 1327 raise ValueError( 1328 'Function `meshgrid` does not support returning sparse arrays yet. ' 1329 f'Received: sparse={sparse}') 1330 1331 copy = kwargs.get('copy', True) 1332 if not copy: 1333 raise ValueError('Function `meshgrid` only supports copy=True. ' 1334 f'Received: copy={copy}') 1335 1336 indexing = kwargs.get('indexing', 'xy') 1337 1338 xi = [np_array_ops.asarray(arg) for arg in xi] 1339 kwargs = {'indexing': indexing} 1340 1341 outputs = array_ops.meshgrid(*xi, **kwargs) 1342 1343 return outputs 1344 1345 1346# Uses np_doc_only here because np.einsum (in 1.16) doesn't have argument 1347# `subscripts`, even though the doc says it has. 1348@np_utils.np_doc_only('einsum') 1349def einsum(subscripts, *operands, **kwargs): # pylint: disable=missing-docstring 1350 casting = kwargs.get('casting', 'safe') 1351 optimize = kwargs.get('optimize', False) 1352 if casting == 'safe': 1353 operands = np_array_ops._promote_dtype(*operands) # pylint: disable=protected-access 1354 elif casting == 'no': 1355 operands = [np_array_ops.asarray(x) for x in operands] 1356 else: 1357 raise ValueError( 1358 'Invalid value for argument `casting`. ' 1359 f'Expected casting="safe" or casting="no". Received: casting={casting}') 1360 if not optimize: 1361 # TF doesn't have a "no optimization" option. 1362 # TODO(wangpeng): Print a warning that np and tf use different 1363 # optimizations. 1364 tf_optimize = 'greedy' 1365 elif optimize == True: # pylint: disable=singleton-comparison,g-explicit-bool-comparison 1366 tf_optimize = 'greedy' 1367 elif optimize == 'greedy': 1368 tf_optimize = 'greedy' 1369 elif optimize == 'optimal': 1370 tf_optimize = 'optimal' 1371 else: 1372 raise ValueError( 1373 'Invalid value for argument `optimize`. ' 1374 'Expected one of {True, "greedy", "optimal"}. ' 1375 f'Received: optimize={optimize}') 1376 1377 res = special_math_ops.einsum(subscripts, *operands, optimize=tf_optimize) 1378 return res 1379 1380 1381def _tensor_t(self): 1382 """Returns a Tensor which is the transpose of this Tensor.""" 1383 return self.transpose() 1384 1385 1386def _tensor_ndim(self): 1387 """Returns the rank of the Tensor.""" 1388 return self.shape.ndims 1389 1390 1391def _tensor_pos(self): 1392 """Returns self, for unary operator `+`.""" 1393 return self 1394 1395 1396def _tensor_size(self): 1397 """Returns the number of elements in this Tensor, if fully known.""" 1398 if not self.shape.is_fully_defined(): 1399 return None 1400 return np.prod(self.shape.as_list()) 1401 1402 1403def _tensor_tolist(self): 1404 if isinstance(self, ops.EagerTensor): 1405 return self._numpy().tolist() # pylint: disable=protected-access 1406 1407 raise ValueError('Symbolic Tensors do not support the tolist API.') 1408 1409 1410def enable_numpy_methods_on_tensor(): 1411 """Adds additional NumPy methods on tf.Tensor class.""" 1412 t = property(_tensor_t) 1413 setattr(ops.Tensor, 'T', t) 1414 1415 ndim = property(_tensor_ndim) 1416 setattr(ops.Tensor, 'ndim', ndim) 1417 1418 size = property(_tensor_size) 1419 setattr(ops.Tensor, 'size', size) 1420 1421 setattr(ops.Tensor, '__pos__', _tensor_pos) 1422 setattr(ops.Tensor, 'tolist', _tensor_tolist) 1423 1424 # TODO(b/178540516): Make a custom `setattr` that changes the method's 1425 # docstring to the TF one. 1426 setattr(ops.Tensor, 'transpose', np_array_ops.transpose) 1427 setattr(ops.Tensor, 'reshape', np_array_ops._reshape_method_wrapper) # pylint: disable=protected-access 1428 setattr(ops.Tensor, 'ravel', np_array_ops.ravel) 1429 setattr(ops.Tensor, 'clip', clip) 1430 setattr(ops.Tensor, 'astype', math_ops.cast) 1431 setattr(ops.Tensor, '__round__', np_array_ops.around) 1432 setattr(ops.Tensor, 'max', np_array_ops.amax) 1433 setattr(ops.Tensor, 'mean', np_array_ops.mean) 1434 setattr(ops.Tensor, 'min', np_array_ops.amin) 1435 1436 # TODO(wangpeng): Remove `data` when all uses of it are removed 1437 data = property(lambda self: self) 1438 setattr(ops.Tensor, 'data', data) 1439