xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/numpy_ops/np_math_ops.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Mathematical operations."""
16# pylint: disable=g-direct-tensorflow-import
17
18import numbers
19import sys
20
21import numpy as np
22
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import errors
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import bitwise_ops
29from tensorflow.python.ops import clip_ops
30from tensorflow.python.ops import control_flow_ops
31from tensorflow.python.ops import gen_math_ops
32from tensorflow.python.ops import math_ops
33from tensorflow.python.ops import nn_ops
34from tensorflow.python.ops import sort_ops
35from tensorflow.python.ops import special_math_ops
36from tensorflow.python.ops.numpy_ops import np_array_ops
37from tensorflow.python.ops.numpy_ops import np_arrays
38from tensorflow.python.ops.numpy_ops import np_dtypes
39from tensorflow.python.ops.numpy_ops import np_export
40from tensorflow.python.ops.numpy_ops import np_utils
41
42
43pi = np_export.np_export_constant(__name__, 'pi', np.pi)
44e = np_export.np_export_constant(__name__, 'e', np.e)
45inf = np_export.np_export_constant(__name__, 'inf', np.inf)
46
47
48@np_utils.np_doc_only('dot')
49def dot(a, b):  # pylint: disable=missing-docstring
50
51  def f(a, b):  # pylint: disable=missing-docstring
52    return np_utils.cond(
53        np_utils.logical_or(
54            math_ops.equal(array_ops.rank(a), 0),
55            math_ops.equal(array_ops.rank(b), 0)),
56        lambda: a * b,
57        lambda: np_utils.cond(  # pylint: disable=g-long-lambda
58            math_ops.equal(array_ops.rank(b), 1),
59            lambda: math_ops.tensordot(a, b, axes=[[-1], [-1]]),
60            lambda: math_ops.tensordot(a, b, axes=[[-1], [-2]])))
61
62  return _bin_op(f, a, b)
63
64
65# TODO(wangpeng): Make element-wise ops `ufunc`s
66def _bin_op(tf_fun, a, b, promote=True):
67  if promote:
68    a, b = np_array_ops._promote_dtype_binary(a, b)  # pylint: disable=protected-access
69  else:
70    a = np_array_ops.array(a)
71    b = np_array_ops.array(b)
72  return tf_fun(a, b)
73
74
75@np_utils.np_doc('add')
76def add(x1, x2):
77
78  def add_or_or(x1, x2):
79    if x1.dtype == dtypes.bool:
80      assert x2.dtype == dtypes.bool
81      return math_ops.logical_or(x1, x2)
82    return math_ops.add(x1, x2)
83
84  return _bin_op(add_or_or, x1, x2)
85
86
87@np_utils.np_doc('subtract')
88def subtract(x1, x2):
89  return _bin_op(math_ops.subtract, x1, x2)
90
91
92@np_utils.np_doc('multiply')
93def multiply(x1, x2):
94
95  def mul_or_and(x1, x2):
96    if x1.dtype == dtypes.bool:
97      assert x2.dtype == dtypes.bool
98      return math_ops.logical_and(x1, x2)
99    return math_ops.multiply(x1, x2)
100
101  return _bin_op(mul_or_and, x1, x2)
102
103
104@np_utils.np_doc('true_divide')
105def true_divide(x1, x2):  # pylint: disable=missing-function-docstring
106
107  def _avoid_float64(x1, x2):
108    if x1.dtype == x2.dtype and x1.dtype in (dtypes.int32, dtypes.int64):
109      x1 = math_ops.cast(x1, dtype=dtypes.float32)
110      x2 = math_ops.cast(x2, dtype=dtypes.float32)
111    return x1, x2
112
113  def f(x1, x2):
114    if x1.dtype == dtypes.bool:
115      assert x2.dtype == dtypes.bool
116      float_ = np_dtypes.default_float_type()
117      x1 = math_ops.cast(x1, float_)
118      x2 = math_ops.cast(x2, float_)
119    if not np_dtypes.is_allow_float64():
120      # math_ops.truediv in Python3 produces float64 when both inputs are int32
121      # or int64. We want to avoid that when is_allow_float64() is False.
122      x1, x2 = _avoid_float64(x1, x2)
123    return math_ops.truediv(x1, x2)
124
125  return _bin_op(f, x1, x2)
126
127
128@np_utils.np_doc('divide')
129def divide(x1, x2):  # pylint: disable=missing-function-docstring
130  return true_divide(x1, x2)
131
132
133@np_utils.np_doc('floor_divide')
134def floor_divide(x1, x2):  # pylint: disable=missing-function-docstring
135
136  def f(x1, x2):
137    if x1.dtype == dtypes.bool:
138      assert x2.dtype == dtypes.bool
139      x1 = math_ops.cast(x1, dtypes.int8)
140      x2 = math_ops.cast(x2, dtypes.int8)
141    return math_ops.floordiv(x1, x2)
142
143  return _bin_op(f, x1, x2)
144
145
146@np_utils.np_doc('mod')
147def mod(x1, x2):  # pylint: disable=missing-function-docstring
148
149  def f(x1, x2):
150    if x1.dtype == dtypes.bool:
151      assert x2.dtype == dtypes.bool
152      x1 = math_ops.cast(x1, dtypes.int8)
153      x2 = math_ops.cast(x2, dtypes.int8)
154    return math_ops.mod(x1, x2)
155
156  return _bin_op(f, x1, x2)
157
158
159@np_utils.np_doc('remainder')
160def remainder(x1, x2):  # pylint: disable=missing-function-docstring
161  return mod(x1, x2)
162
163
164@np_utils.np_doc('divmod')
165def divmod(x1, x2):  # pylint: disable=redefined-builtin
166  return floor_divide(x1, x2), mod(x1, x2)
167
168
169@np_utils.np_doc('maximum')
170def maximum(x1, x2):  # pylint: disable=missing-function-docstring
171
172  # Fast path for when maximum is used as relu.
173  if isinstance(
174      x2, numbers.Real) and not isinstance(x2, bool) and x2 == 0 and isinstance(
175          x1, np_arrays.ndarray) and x1.dtype != dtypes.bool:
176    return nn_ops.relu(np_array_ops.asarray(x1))
177
178  def max_or_or(x1, x2):
179    if x1.dtype == dtypes.bool:
180      assert x2.dtype == dtypes.bool
181      return math_ops.logical_or(x1, x2)
182    return math_ops.maximum(x1, x2)
183
184  return _bin_op(max_or_or, x1, x2)
185
186
187@np_utils.np_doc('minimum')
188def minimum(x1, x2):
189
190  def min_or_and(x1, x2):
191    if x1.dtype == dtypes.bool:
192      assert x2.dtype == dtypes.bool
193      return math_ops.logical_and(x1, x2)
194    return math_ops.minimum(x1, x2)
195
196  return _bin_op(min_or_and, x1, x2)
197
198
199@np_utils.np_doc('clip')
200def clip(a, a_min, a_max):  # pylint: disable=missing-docstring
201  if a_min is None and a_max is None:
202    raise ValueError('Not more than one of `a_min` and `a_max` may be `None`.')
203  if a_min is None:
204    return minimum(a, a_max)
205  elif a_max is None:
206    return maximum(a, a_min)
207  else:
208    a, a_min, a_max = np_array_ops._promote_dtype(a, a_min, a_max)  # pylint: disable=protected-access
209    return clip_ops.clip_by_value(*np_utils.tf_broadcast(a, a_min, a_max))
210
211
212@np_utils.np_doc('matmul')
213def matmul(x1, x2):  # pylint: disable=missing-docstring
214  def f(x1, x2):
215    try:
216      if x1._rank() == 2 and x2._rank() == 2:  # pylint: disable=protected-access
217        # Fast path for known ranks.
218        return gen_math_ops.mat_mul(x1, x2)
219      return np_utils.cond(
220          math_ops.equal(np_utils.tf_rank(x2), 1),
221          lambda: math_ops.tensordot(x1, x2, axes=1),
222          lambda: np_utils.cond(  # pylint: disable=g-long-lambda
223              math_ops.equal(np_utils.tf_rank(x1), 1),
224              lambda: math_ops.tensordot(  # pylint: disable=g-long-lambda
225                  x1, x2, axes=[[0], [-2]]),
226              lambda: math_ops.matmul(x1, x2)))
227    except errors.InvalidArgumentError as err:
228      raise ValueError(str(err)).with_traceback(sys.exc_info()[2])
229
230  return _bin_op(f, x1, x2)
231
232
233# Exported so it can be called from Tensor.__matmul__. NumPy's matmul handles
234# batched matmul as well, so simply including promotion in TF's current
235# __matmul__ implementation was not sufficient.
236setattr(np_arrays.ndarray, '_matmul', matmul)
237
238
239@np_utils.np_doc('tensordot')
240def tensordot(a, b, axes=2):
241  return _bin_op(lambda a, b: math_ops.tensordot(a, b, axes=axes), a, b)
242
243
244@np_utils.np_doc_only('inner')
245def inner(a, b):  # pylint: disable=missing-function-docstring
246
247  def f(a, b):
248    return np_utils.cond(
249        np_utils.logical_or(
250            math_ops.equal(array_ops.rank(a), 0),
251            math_ops.equal(array_ops.rank(b), 0)), lambda: a * b,
252        lambda: math_ops.tensordot(a, b, axes=[[-1], [-1]]))
253
254  return _bin_op(f, a, b)
255
256
257@np_utils.np_doc('cross')
258def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):  # pylint: disable=missing-docstring
259
260  def f(a, b):  # pylint: disable=missing-docstring
261    # We can't assign to captured variable `axisa`, so make a new variable
262    if axis is None:
263      axis_a = axisa
264      axis_b = axisb
265      axis_c = axisc
266    else:
267      axis_a = axis
268      axis_b = axis
269      axis_c = axis
270    if axis_a < 0:
271      axis_a = np_utils.add(axis_a, array_ops.rank(a))
272    if axis_b < 0:
273      axis_b = np_utils.add(axis_b, array_ops.rank(b))
274
275    def maybe_move_axis_to_last(a, axis):
276
277      def move_axis_to_last(a, axis):
278        return array_ops.transpose(
279            a,
280            array_ops.concat([
281                math_ops.range(axis),
282                math_ops.range(axis + 1, array_ops.rank(a)), [axis]
283            ],
284                             axis=0))
285
286      return np_utils.cond(axis == np_utils.subtract(array_ops.rank(a), 1),
287                           lambda: a, lambda: move_axis_to_last(a, axis))
288
289    a = maybe_move_axis_to_last(a, axis_a)
290    b = maybe_move_axis_to_last(b, axis_b)
291    a_dim = np_utils.getitem(array_ops.shape(a), -1)
292    b_dim = np_utils.getitem(array_ops.shape(b), -1)
293
294    def maybe_pad_0(a, size_of_last_dim):
295
296      def pad_0(a):
297        return array_ops.pad(
298            a,
299            array_ops.concat([
300                array_ops.zeros([array_ops.rank(a) - 1, 2], dtypes.int32),
301                constant_op.constant([[0, 1]], dtypes.int32)
302            ],
303                             axis=0))
304
305      return np_utils.cond(
306          math_ops.equal(size_of_last_dim, 2), lambda: pad_0(a), lambda: a)
307
308    a = maybe_pad_0(a, a_dim)
309    b = maybe_pad_0(b, b_dim)
310    c = math_ops.cross(*np_utils.tf_broadcast(a, b))
311    if axis_c < 0:
312      axis_c = np_utils.add(axis_c, array_ops.rank(c))
313
314    def move_last_to_axis(a, axis):
315      r = array_ops.rank(a)
316      return array_ops.transpose(
317          a,
318          array_ops.concat(
319              [math_ops.range(axis), [r - 1],
320               math_ops.range(axis, r - 1)],
321              axis=0))
322
323    c = np_utils.cond(
324        (a_dim == 2) & (b_dim == 2),
325        lambda: c[..., 2],
326        lambda: np_utils.cond(  # pylint: disable=g-long-lambda
327            axis_c == np_utils.subtract(array_ops.rank(c), 1), lambda: c,
328            lambda: move_last_to_axis(c, axis_c)))
329    return c
330
331  return _bin_op(f, a, b)
332
333
334@np_utils.np_doc_only('vdot')
335def vdot(a, b):  # pylint: disable=missing-docstring
336  a, b = np_array_ops._promote_dtype(a, b)  # pylint: disable=protected-access
337  a = np_array_ops.reshape(a, [-1])
338  b = np_array_ops.reshape(b, [-1])
339  if a.dtype == np_dtypes.complex128 or a.dtype == np_dtypes.complex64:
340    a = conj(a)
341  return dot(a, b)
342
343
344@np_utils.np_doc('power')
345def power(x1, x2):
346  return _bin_op(math_ops.pow, x1, x2)
347
348
349@np_utils.np_doc('float_power')
350def float_power(x1, x2):
351  return power(x1, x2)
352
353
354@np_utils.np_doc('arctan2')
355def arctan2(x1, x2):
356  return _bin_op(math_ops.atan2, x1, x2)
357
358
359@np_utils.np_doc('nextafter')
360def nextafter(x1, x2):
361  return _bin_op(math_ops.nextafter, x1, x2)
362
363
364@np_utils.np_doc('heaviside')
365def heaviside(x1, x2):  # pylint: disable=missing-function-docstring
366
367  def f(x1, x2):
368    return array_ops.where_v2(
369        x1 < 0, constant_op.constant(0, dtype=x2.dtype),
370        array_ops.where_v2(x1 > 0, constant_op.constant(1, dtype=x2.dtype), x2))
371
372  y = _bin_op(f, x1, x2)
373  if not np.issubdtype(y.dtype.as_numpy_dtype, np.inexact):
374    y = y.astype(np_dtypes.default_float_type())
375  return y
376
377
378@np_utils.np_doc('hypot')
379def hypot(x1, x2):
380  return sqrt(square(x1) + square(x2))
381
382
383@np_utils.np_doc('kron')
384def kron(a, b):  # pylint: disable=missing-function-docstring
385  # pylint: disable=protected-access,g-complex-comprehension
386  a, b = np_array_ops._promote_dtype(a, b)
387  t_a = np_utils.cond(
388      a.ndim < b.ndim,
389      lambda: np_array_ops.reshape(  # pylint: disable=g-long-lambda
390          a, np_array_ops._pad_left_to(b.ndim, a.shape)),
391      lambda: a)
392  t_b = np_utils.cond(
393      b.ndim < a.ndim,
394      lambda: np_array_ops.reshape(  # pylint: disable=g-long-lambda
395          b, np_array_ops._pad_left_to(a.ndim, b.shape)),
396      lambda: b)
397
398  def _make_shape(shape, prepend):
399    ones = array_ops.ones_like(shape)
400    if prepend:
401      shapes = [ones, shape]
402    else:
403      shapes = [shape, ones]
404    return array_ops.reshape(array_ops.stack(shapes, axis=1), [-1])
405
406  a_shape = array_ops.shape(t_a)
407  b_shape = array_ops.shape(t_b)
408  a_reshaped = np_array_ops.reshape(t_a, _make_shape(a_shape, False))
409  b_reshaped = np_array_ops.reshape(t_b, _make_shape(b_shape, True))
410  out_shape = a_shape * b_shape
411  return np_array_ops.reshape(a_reshaped * b_reshaped, out_shape)
412
413
414@np_utils.np_doc('outer')
415def outer(a, b):
416
417  def f(a, b):
418    return array_ops.reshape(a, [-1, 1]) * array_ops.reshape(b, [-1])
419
420  return _bin_op(f, a, b)
421
422
423# This can also be implemented via tf.reduce_logsumexp
424@np_utils.np_doc('logaddexp')
425def logaddexp(x1, x2):
426  amax = maximum(x1, x2)
427  delta = x1 - x2
428  return np_array_ops.where(
429      isnan(delta),
430      x1 + x2,  # NaNs or infinities of the same sign.
431      amax + log1p(exp(-abs(delta))))
432
433
434@np_utils.np_doc('logaddexp2')
435def logaddexp2(x1, x2):
436  amax = maximum(x1, x2)
437  delta = x1 - x2
438  return np_array_ops.where(
439      isnan(delta),
440      x1 + x2,  # NaNs or infinities of the same sign.
441      amax + log1p(exp2(-abs(delta))) / np.log(2))
442
443
444@np_utils.np_doc('polyval')
445def polyval(p, x):  # pylint: disable=missing-function-docstring
446
447  def f(p, x):
448    if p.shape.rank == 0:
449      p = array_ops.reshape(p, [1])
450    p = array_ops.unstack(p)
451    # TODO(wangpeng): Make tf version take a tensor for p instead of a list.
452    y = math_ops.polyval(p, x)
453    # If the polynomial is 0-order, numpy requires the result to be broadcast to
454    # `x`'s shape.
455    if len(p) == 1:
456      y = array_ops.broadcast_to(y, x.shape)
457    return y
458
459  return _bin_op(f, p, x)
460
461
462@np_utils.np_doc('isclose')
463def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):  # pylint: disable=missing-docstring
464
465  def f(a, b):  # pylint: disable=missing-docstring
466    dtype = a.dtype
467    if np.issubdtype(dtype.as_numpy_dtype, np.inexact):
468      rtol_ = ops.convert_to_tensor(rtol, dtype.real_dtype)
469      atol_ = ops.convert_to_tensor(atol, dtype.real_dtype)
470      result = (math_ops.abs(a - b) <= atol_ + rtol_ * math_ops.abs(b))
471      if equal_nan:
472        result = result | (math_ops.is_nan(a) & math_ops.is_nan(b))
473      return result
474    else:
475      return a == b
476
477  return _bin_op(f, a, b)
478
479
480@np_utils.np_doc('allclose')
481def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
482  return np_array_ops.all(
483      isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan))
484
485
486def _tf_gcd(x1, x2):  # pylint: disable=missing-function-docstring
487
488  def _gcd_cond_fn(_, x2):
489    return math_ops.reduce_any(x2 != 0)
490
491  def _gcd_body_fn(x1, x2):
492    # math_ops.mod will raise an error when any element of x2 is 0. To avoid
493    # that, we change those zeros to ones. Their values don't matter because
494    # they won't be used.
495    x2_safe = array_ops.where_v2(x2 != 0, x2, constant_op.constant(1, x2.dtype))
496    x1, x2 = (array_ops.where_v2(x2 != 0, x2, x1),
497              array_ops.where_v2(x2 != 0, math_ops.mod(x1, x2_safe),
498                                 constant_op.constant(0, x2.dtype)))
499    return (array_ops.where_v2(x1 < x2, x2,
500                               x1), array_ops.where_v2(x1 < x2, x1, x2))
501
502  if (not np.issubdtype(x1.dtype.as_numpy_dtype, np.integer) or
503      not np.issubdtype(x2.dtype.as_numpy_dtype, np.integer)):
504    raise ValueError('Arguments to gcd must be integers.')
505  shape = array_ops.broadcast_dynamic_shape(
506      array_ops.shape(x1), array_ops.shape(x2))
507  x1 = array_ops.broadcast_to(x1, shape)
508  x2 = array_ops.broadcast_to(x2, shape)
509  value, _ = control_flow_ops.while_loop(_gcd_cond_fn, _gcd_body_fn,
510                                         (math_ops.abs(x1), math_ops.abs(x2)))
511  return value
512
513
514# Note that np.gcd may not be present in some supported versions of numpy.
515@np_utils.np_doc('gcd')
516def gcd(x1, x2):
517  return _bin_op(_tf_gcd, x1, x2)
518
519
520# Note that np.lcm may not be present in some supported versions of numpy.
521@np_utils.np_doc('lcm')
522def lcm(x1, x2):  # pylint: disable=missing-function-docstring
523
524  def f(x1, x2):
525    d = _tf_gcd(x1, x2)
526    # Same as the `x2_safe` trick above
527    d_safe = array_ops.where_v2(
528        math_ops.equal(d, 0), constant_op.constant(1, d.dtype), d)
529    return array_ops.where_v2(
530        math_ops.equal(d, 0), constant_op.constant(0, d.dtype),
531        math_ops.abs(x1 * x2) // d_safe)
532
533  return _bin_op(f, x1, x2)
534
535
536def _bitwise_binary_op(tf_fn, x1, x2):  # pylint: disable=missing-function-docstring
537
538  def f(x1, x2):
539    is_bool = (x1.dtype == dtypes.bool)
540    if is_bool:
541      assert x2.dtype == dtypes.bool
542      x1 = math_ops.cast(x1, dtypes.int8)
543      x2 = math_ops.cast(x2, dtypes.int8)
544    r = tf_fn(x1, x2)
545    if is_bool:
546      r = math_ops.cast(r, dtypes.bool)
547    return r
548
549  return _bin_op(f, x1, x2)
550
551
552@np_utils.np_doc('bitwise_and')
553def bitwise_and(x1, x2):
554  return _bitwise_binary_op(bitwise_ops.bitwise_and, x1, x2)
555
556
557@np_utils.np_doc('bitwise_or')
558def bitwise_or(x1, x2):
559  return _bitwise_binary_op(bitwise_ops.bitwise_or, x1, x2)
560
561
562@np_utils.np_doc('bitwise_xor')
563def bitwise_xor(x1, x2):
564  return _bitwise_binary_op(bitwise_ops.bitwise_xor, x1, x2)
565
566
567@np_utils.np_doc('bitwise_not', link=np_utils.AliasOf('invert'))
568def bitwise_not(x):
569
570  def f(x):
571    if x.dtype == dtypes.bool:
572      return math_ops.logical_not(x)
573    return bitwise_ops.invert(x)
574
575  return _scalar(f, x)
576
577
578def _scalar(tf_fn, x, promote_to_float=False):
579  """Computes the tf_fn(x) for each element in `x`.
580
581  Args:
582    tf_fn: function that takes a single Tensor argument.
583    x: array_like. Could be an ndarray, a Tensor or any object that can be
584      converted to a Tensor using `ops.convert_to_tensor`.
585    promote_to_float: whether to cast the argument to a float dtype
586      (`np_dtypes.default_float_type`) if it is not already.
587
588  Returns:
589    An ndarray with the same shape as `x`. The default output dtype is
590    determined by `np_dtypes.default_float_type`, unless x is an ndarray with a
591    floating point type, in which case the output type is same as x.dtype.
592  """
593  x = np_array_ops.asarray(x)
594  if promote_to_float and not np.issubdtype(x.dtype.as_numpy_dtype, np.inexact):
595    x = x.astype(np_dtypes.default_float_type())
596  return tf_fn(x)
597
598
599@np_utils.np_doc('log')
600def log(x):
601  return _scalar(math_ops.log, x, True)
602
603
604@np_utils.np_doc('exp')
605def exp(x):
606  return _scalar(math_ops.exp, x, True)
607
608
609@np_utils.np_doc('sqrt')
610def sqrt(x):
611  return _scalar(math_ops.sqrt, x, True)
612
613
614@np_utils.np_doc('abs', link=np_utils.AliasOf('absolute'))
615def abs(x):  # pylint: disable=redefined-builtin
616  return _scalar(math_ops.abs, x)
617
618
619@np_utils.np_doc('absolute')
620def absolute(x):
621  return abs(x)
622
623
624@np_utils.np_doc('fabs')
625def fabs(x):
626  return abs(x)
627
628
629@np_utils.np_doc('ceil')
630def ceil(x):
631  return _scalar(math_ops.ceil, x, True)
632
633
634@np_utils.np_doc('floor')
635def floor(x):
636  return _scalar(math_ops.floor, x, True)
637
638
639@np_utils.np_doc('conj')
640def conj(x):
641  return _scalar(math_ops.conj, x)
642
643
644@np_utils.np_doc('negative')
645def negative(x):
646  return _scalar(math_ops.negative, x)
647
648
649@np_utils.np_doc('reciprocal')
650def reciprocal(x):
651  return _scalar(math_ops.reciprocal, x)
652
653
654@np_utils.np_doc('signbit')
655def signbit(x):
656
657  def f(x):
658    if x.dtype == dtypes.bool:
659      return array_ops.fill(array_ops.shape(x), False)
660    return x < 0
661
662  return _scalar(f, x)
663
664
665@np_utils.np_doc('sin')
666def sin(x):
667  return _scalar(math_ops.sin, x, True)
668
669
670@np_utils.np_doc('cos')
671def cos(x):
672  return _scalar(math_ops.cos, x, True)
673
674
675@np_utils.np_doc('tan')
676def tan(x):
677  return _scalar(math_ops.tan, x, True)
678
679
680@np_utils.np_doc('sinh')
681def sinh(x):
682  return _scalar(math_ops.sinh, x, True)
683
684
685@np_utils.np_doc('cosh')
686def cosh(x):
687  return _scalar(math_ops.cosh, x, True)
688
689
690@np_utils.np_doc('tanh')
691def tanh(x):
692  return _scalar(math_ops.tanh, x, True)
693
694
695@np_utils.np_doc('arcsin')
696def arcsin(x):
697  return _scalar(math_ops.asin, x, True)
698
699
700@np_utils.np_doc('arccos')
701def arccos(x):
702  return _scalar(math_ops.acos, x, True)
703
704
705@np_utils.np_doc('arctan')
706def arctan(x):
707  return _scalar(math_ops.atan, x, True)
708
709
710@np_utils.np_doc('arcsinh')
711def arcsinh(x):
712  return _scalar(math_ops.asinh, x, True)
713
714
715@np_utils.np_doc('arccosh')
716def arccosh(x):
717  return _scalar(math_ops.acosh, x, True)
718
719
720@np_utils.np_doc('arctanh')
721def arctanh(x):
722  return _scalar(math_ops.atanh, x, True)
723
724
725@np_utils.np_doc('deg2rad')
726def deg2rad(x):
727
728  def f(x):
729    return x * (np.pi / 180.0)
730
731  return _scalar(f, x, True)
732
733
734@np_utils.np_doc('rad2deg')
735def rad2deg(x):
736  return x * (180.0 / np.pi)
737
738
739_tf_float_types = [
740    dtypes.bfloat16, dtypes.float16, dtypes.float32, dtypes.float64
741]
742
743
744@np_utils.np_doc('angle')
745def angle(z, deg=False):  # pylint: disable=missing-function-docstring
746
747  def f(x):
748    if x.dtype in _tf_float_types:
749      # Workaround for b/147515503
750      return array_ops.where_v2(x < 0, np.pi, 0)
751    else:
752      return math_ops.angle(x)
753
754  y = _scalar(f, z, True)
755  if deg:
756    y = rad2deg(y)
757  return y
758
759
760@np_utils.np_doc('cbrt')
761def cbrt(x):
762
763  def f(x):
764    # __pow__ can't handle negative base, so we use `abs` here.
765    rt = math_ops.abs(x)**(1.0 / 3)
766    return array_ops.where_v2(x < 0, -rt, rt)
767
768  return _scalar(f, x, True)
769
770
771@np_utils.np_doc('conjugate', link=np_utils.AliasOf('conj'))
772def conjugate(x):
773  return _scalar(math_ops.conj, x)
774
775
776@np_utils.np_doc('exp2')
777def exp2(x):
778
779  def f(x):
780    return 2**x
781
782  return _scalar(f, x, True)
783
784
785@np_utils.np_doc('expm1')
786def expm1(x):
787  return _scalar(math_ops.expm1, x, True)
788
789
790@np_utils.np_doc('fix')
791def fix(x):
792
793  def f(x):
794    return array_ops.where_v2(x < 0, math_ops.ceil(x), math_ops.floor(x))
795
796  return _scalar(f, x, True)
797
798
799@np_utils.np_doc('iscomplex')
800def iscomplex(x):
801  return np_array_ops.imag(x) != 0
802
803
804@np_utils.np_doc('isreal')
805def isreal(x):
806  return np_array_ops.imag(x) == 0
807
808
809@np_utils.np_doc('iscomplexobj')
810def iscomplexobj(x):
811  x = np_array_ops.array(x)
812  return np.issubdtype(x.dtype.as_numpy_dtype, np.complexfloating)
813
814
815@np_utils.np_doc('isrealobj')
816def isrealobj(x):
817  return not iscomplexobj(x)
818
819
820@np_utils.np_doc('isnan')
821def isnan(x):
822  return _scalar(math_ops.is_nan, x, True)
823
824
825def _make_nan_reduction(np_fun_name, reduction, init_val):
826  """Helper to generate nan* functions."""
827
828  @np_utils.np_doc(np_fun_name)
829  def nan_reduction(a, axis=None, dtype=None, keepdims=False):
830    a = np_array_ops.array(a)
831    v = np_array_ops.array(init_val, dtype=a.dtype)
832    return reduction(
833        np_array_ops.where(isnan(a), v, a),
834        axis=axis,
835        dtype=dtype,
836        keepdims=keepdims)
837
838  return nan_reduction
839
840
841nansum = _make_nan_reduction('nansum', np_array_ops.sum, 0)
842nanprod = _make_nan_reduction('nanprod', np_array_ops.prod, 1)
843
844
845@np_utils.np_doc('nanmean')
846def nanmean(a, axis=None, dtype=None, keepdims=None):  # pylint: disable=missing-docstring
847  a = np_array_ops.array(a)
848  if np.issubdtype(a.dtype.as_numpy_dtype, np.bool_) or np.issubdtype(
849      a.dtype.as_numpy_dtype, np.integer):
850    return np_array_ops.mean(a, axis=axis, dtype=dtype, keepdims=keepdims)
851  nan_mask = logical_not(isnan(a))
852  if dtype is None:
853    dtype = a.dtype.as_numpy_dtype
854  normalizer = np_array_ops.sum(
855      nan_mask, axis=axis, dtype=dtype, keepdims=keepdims)
856  return nansum(a, axis=axis, dtype=dtype, keepdims=keepdims) / normalizer
857
858
859@np_utils.np_doc('isfinite')
860def isfinite(x):
861  return _scalar(math_ops.is_finite, x, True)
862
863
864@np_utils.np_doc('isinf')
865def isinf(x):
866  return _scalar(math_ops.is_inf, x, True)
867
868
869@np_utils.np_doc('isneginf')
870def isneginf(x):
871  return x == np_array_ops.full_like(x, -np.inf)
872
873
874@np_utils.np_doc('isposinf')
875def isposinf(x):
876  return x == np_array_ops.full_like(x, np.inf)
877
878
879@np_utils.np_doc('log2')
880def log2(x):
881  return log(x) / np.log(2)
882
883
884@np_utils.np_doc('log10')
885def log10(x):
886  return log(x) / np.log(10)
887
888
889@np_utils.np_doc('log1p')
890def log1p(x):
891  return _scalar(math_ops.log1p, x, True)
892
893
894@np_utils.np_doc('positive')
895def positive(x):
896  return _scalar(lambda x: x, x)
897
898
899@np_utils.np_doc('sinc')
900def sinc(x):
901
902  def f(x):
903    pi_x = x * np.pi
904    return array_ops.where_v2(x == 0, array_ops.ones_like(x),
905                              math_ops.sin(pi_x) / pi_x)
906
907  return _scalar(f, x, True)
908
909
910@np_utils.np_doc('square')
911def square(x):
912  return _scalar(math_ops.square, x)
913
914
915@np_utils.np_doc('diff')
916def diff(a, n=1, axis=-1):  # pylint: disable=missing-function-docstring
917
918  def f(a):
919    # TODO(agarwal): transpose and reshape to N, H, 1 and do a 1D convolution
920    # TODO(agarwal): avoid depending on static rank.
921    nd = a.shape.rank
922    if nd is None:
923      raise ValueError(
924          'Function `diff` currently requires a known rank for input `a`. '
925          f'Received: a={a} (unknown rank)')
926    if (axis + nd if axis < 0 else axis) >= nd:
927      raise ValueError(
928          f'Argument `axis` (received axis={axis}) is out of bounds '
929          f'for input {a} of rank {nd}.')
930    if n < 0:
931      raise ValueError('Argument `order` must be a non-negative integer. '
932                       f'Received: axis={n}')
933    slice1 = [slice(None)] * nd
934    slice2 = [slice(None)] * nd
935    slice1[axis] = slice(1, None)
936    slice2[axis] = slice(None, -1)
937    slice1 = tuple(slice1)
938    slice2 = tuple(slice2)
939    op = math_ops.not_equal if a.dtype == dtypes.bool else math_ops.subtract
940    for _ in range(n):
941      a = op(a[slice1], a[slice2])
942    return a
943
944  return _scalar(f, a)
945
946
947def _wrap(f, reverse=False):
948  """Wraps binary ops so they can be added as operator overloads on ndarray."""
949
950  def _f(a, b):
951    if reverse:
952      a, b = b, a
953
954    if getattr(b, '__array_priority__',
955               0) > np_arrays.ndarray.__array_priority__:
956      return NotImplemented
957
958    return f(a, b)
959
960  return _f
961
962
963def _comparison(tf_fun, x1, x2, cast_bool_to_int=False):
964  """Helper function for comparision."""
965  dtype = np_utils.result_type(x1, x2)
966  # Cast x1 and x2 to the result_type if needed.
967  x1 = np_array_ops.array(x1, dtype=dtype)
968  x2 = np_array_ops.array(x2, dtype=dtype)
969  if cast_bool_to_int and x1.dtype == dtypes.bool:
970    x1 = math_ops.cast(x1, dtypes.int32)
971    x2 = math_ops.cast(x2, dtypes.int32)
972  return tf_fun(x1, x2)
973
974
975@np_utils.np_doc('equal')
976def equal(x1, x2):
977  return _comparison(math_ops.equal, x1, x2)
978
979
980@np_utils.np_doc('not_equal')
981def not_equal(x1, x2):
982  return _comparison(math_ops.not_equal, x1, x2)
983
984
985@np_utils.np_doc('greater')
986def greater(x1, x2):
987  return _comparison(math_ops.greater, x1, x2, True)
988
989
990@np_utils.np_doc('greater_equal')
991def greater_equal(x1, x2):
992  return _comparison(math_ops.greater_equal, x1, x2, True)
993
994
995@np_utils.np_doc('less')
996def less(x1, x2):
997  return _comparison(math_ops.less, x1, x2, True)
998
999
1000@np_utils.np_doc('less_equal')
1001def less_equal(x1, x2):
1002  return _comparison(math_ops.less_equal, x1, x2, True)
1003
1004
1005@np_utils.np_doc('array_equal')
1006def array_equal(a1, a2):  # pylint: disable=missing-function-docstring
1007
1008  def f(x1, x2):
1009    return np_utils.cond(
1010        math_ops.equal(array_ops.rank(x1), array_ops.rank(x2)),
1011        lambda: np_utils.cond(  # pylint: disable=g-long-lambda
1012            np_utils.reduce_all(
1013                math_ops.equal(array_ops.shape(x1), array_ops.shape(x2))
1014            ),
1015            lambda: math_ops.reduce_all(math_ops.equal(x1, x2)),
1016            lambda: constant_op.constant(False)),
1017        lambda: constant_op.constant(False))
1018
1019  return _comparison(f, a1, a2)
1020
1021
1022def _logical_binary_op(tf_fun, x1, x2):
1023  x1 = np_array_ops.array(x1, dtype=np.bool_)
1024  x2 = np_array_ops.array(x2, dtype=np.bool_)
1025  return tf_fun(x1, x2)
1026
1027
1028@np_utils.np_doc('logical_and')
1029def logical_and(x1, x2):
1030  return _logical_binary_op(math_ops.logical_and, x1, x2)
1031
1032
1033@np_utils.np_doc('logical_or')
1034def logical_or(x1, x2):
1035  return _logical_binary_op(math_ops.logical_or, x1, x2)
1036
1037
1038@np_utils.np_doc('logical_xor')
1039def logical_xor(x1, x2):
1040  return _logical_binary_op(math_ops.logical_xor, x1, x2)
1041
1042
1043@np_utils.np_doc('logical_not')
1044def logical_not(x):
1045  x = np_array_ops.array(x, dtype=np.bool_)
1046  return math_ops.logical_not(x)
1047
1048
1049@np_utils.np_doc('linspace')
1050def linspace(  # pylint: disable=missing-docstring
1051    start,
1052    stop,
1053    num=50,
1054    endpoint=True,
1055    retstep=False,
1056    dtype=float,
1057    axis=0):
1058  if dtype:
1059    dtype = np_utils.result_type(dtype)
1060  start = np_array_ops.array(start, dtype=dtype)
1061  stop = np_array_ops.array(stop, dtype=dtype)
1062  if num < 0:
1063    raise ValueError(
1064        'Argument `num` (number of samples) must be a non-negative integer. '
1065        f'Received: num={num}')
1066  step = ops.convert_to_tensor(np.nan)
1067  if endpoint:
1068    result = math_ops.linspace(start, stop, num, axis=axis)
1069    if num > 1:
1070      step = (stop - start) / (num - 1)
1071  else:
1072    # math_ops.linspace does not support endpoint=False so we manually handle it
1073    # here.
1074    if num > 0:
1075      step = ((stop - start) / num)
1076    if num > 1:
1077      new_stop = math_ops.cast(stop, step.dtype) - step
1078      start = math_ops.cast(start, new_stop.dtype)
1079      result = math_ops.linspace(start, new_stop, num, axis=axis)
1080    else:
1081      result = math_ops.linspace(start, stop, num, axis=axis)
1082  if dtype:
1083    if dtype.is_integer:
1084      # Since numpy 1.20, linspace's rounding is towards -inf instead of 0
1085      result = math_ops.floor(result)
1086    result = math_ops.cast(result, dtype)
1087  if retstep:
1088    return (result, step)
1089  else:
1090    return result
1091
1092
@np_utils.np_doc('logspace')
def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0):
  # The output dtype is derived from the endpoints and the requested dtype.
  # The exponents are evenly spaced via `linspace`, then `base` is raised to them.
  dtype = np_utils.result_type(start, stop, dtype)
  exponents = linspace(
      start, stop, num=num, endpoint=endpoint, dtype=dtype, axis=axis)
  powers = math_ops.pow(math_ops.cast(base, exponents.dtype), exponents)
  return math_ops.cast(powers, dtype) if dtype else powers
1102
1103
@np_utils.np_doc('geomspace')
def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):  # pylint: disable=missing-docstring
  if dtype:
    dtype = dtypes.as_dtype(dtype)
  else:
    dtype = np_utils.result_type(
        start, stop, float(num), np_array_ops.zeros((), dtype))
  # Computation runs in at least float32 precision, whatever the output dtype.
  work_dtype = np.promote_types(dtype.as_numpy_dtype, np.float32)
  start = np_array_ops.asarray(start, dtype=work_dtype)
  stop = np_array_ops.asarray(stop, dtype=work_dtype)
  # follow the numpy geomspace convention for negative and complex endpoints:
  # `signflip` is -1 when both real parts are negative and 1 when at least
  # one is positive. The sequence is built on `signflip * endpoint` and then
  # multiplied back by `signflip`.
  start_sign = 1 - np_array_ops.sign(np_array_ops.real(start))
  stop_sign = 1 - np_array_ops.sign(np_array_ops.real(stop))
  signflip = 1 - start_sign * stop_sign // 2
  log_start = log10(signflip * start)
  log_stop = log10(signflip * stop)
  # Samples are laid out along axis 0 first and moved to `axis` afterwards.
  res = signflip * logspace(
      log_start,
      log_stop,
      num,
      endpoint=endpoint,
      base=10.0,
      dtype=work_dtype,
      axis=0)
  if axis != 0:
    res = np_array_ops.moveaxis(res, 0, axis)
  return math_ops.cast(res, dtype)
1126
1127
@np_utils.np_doc('ptp')
def ptp(a, axis=None, keepdims=None):
  # Peak-to-peak: the maximum minus the minimum along `axis`.
  highest = np_array_ops.amax(a, axis=axis, keepdims=keepdims)
  lowest = np_array_ops.amin(a, axis=axis, keepdims=keepdims)
  return highest - lowest
1132
1133
@np_utils.np_doc_only('concatenate')
def concatenate(arys, axis=0):
  # Anything that is not a list/tuple is treated as a one-element sequence.
  if not isinstance(arys, (list, tuple)):
    arys = [arys]
  if not arys:
    raise ValueError('Need at least one array to concatenate. Received empty '
                     f'input: arys={arys}')
  # Cast every input to the common result dtype before joining.
  common_dtype = np_utils.result_type(*arys)
  converted = [np_array_ops.array(item, dtype=common_dtype) for item in arys]
  return array_ops.concat(converted, axis)
1144
1145
@np_utils.np_doc_only('tile')
def tile(a, reps):  # pylint: disable=missing-function-docstring
  a = np_array_ops.array(a)
  reps = np_array_ops.array(reps, dtype=dtypes.int32).reshape([-1])

  def _left_pad_with_ones(vec, amount):
    # Prepends max(amount, 0) ones to the 1-D tensor `vec`.
    return array_ops.pad(
        vec, [[math_ops.maximum(amount, 0), 0]], constant_values=1)

  # Make `reps` and the shape of `a` the same length by left-padding the
  # shorter one with ones. Then reshape `a` to its padded shape and tile.
  a_rank = array_ops.rank(a)
  reps_len = array_ops.size(reps)
  reps = _left_pad_with_ones(reps, a_rank - reps_len)
  padded_shape = _left_pad_with_ones(array_ops.shape(a), reps_len - a_rank)
  return array_ops.tile(array_ops.reshape(a, padded_shape), reps)
1161
1162
@np_utils.np_doc('count_nonzero')
def count_nonzero(a, axis=None):
  # Counts the non-zero entries of `a`, along `axis` when one is given.
  arr = np_array_ops.array(a)
  return math_ops.count_nonzero(arr, axis)
1166
1167
@np_utils.np_doc('argsort')
def argsort(a, axis=-1, kind='quicksort', order=None):  # pylint: disable=missing-docstring
  # TODO(nareshmodi): make string tensors also work.
  if kind not in ('quicksort', 'stable'):
    raise ValueError(
        'Invalid value for argument `kind`. '
        'Only kind="quicksort" and kind="stable" are supported. '
        f'Received: kind={kind}')
  if order is not None:
    raise ValueError('The `order` argument is not supported. Pass order=None')
  use_stable = kind == 'stable'

  a = np_array_ops.array(a)

  def _sorted_indices():
    # axis=None sorts the flattened input along its only axis.
    if axis is None:
      return sort_ops.argsort(
          array_ops.reshape(a, [-1]), 0, stable=use_stable)
    return sort_ops.argsort(a, axis, stable=use_stable)

  # A rank-0 input is answered with the constant [0] instead of being sorted.
  indices = np_utils.cond(
      math_ops.equal(array_ops.rank(a), 0),
      lambda: constant_op.constant([0]),
      _sorted_indices)

  return np_array_ops.array(indices, dtype=np.intp)
1194
1195
@np_utils.np_doc('sort')
def sort(a, axis=-1, kind='quicksort', order=None):  # pylint: disable=missing-docstring
  if kind != 'quicksort':
    raise ValueError(
        'Invalid value for argument `kind`. '
        'Only kind="quicksort" is supported. '
        f'Received: kind={kind}')
  if order is not None:
    raise ValueError('The `order` argument is not supported. Pass order=None')

  arr = np_array_ops.array(a)
  # axis=None means: flatten, then sort along the single remaining axis.
  if axis is None:
    arr, axis = array_ops.reshape(arr, [-1]), 0
  return sort_ops.sort(arr, axis)
1212
1213
def _argminmax(fn, a, axis=None):
  """Shared implementation of `argmax` / `argmin`.

  Args:
    fn: `math_ops.argmax` or `math_ops.argmin`, called as
      `fn(input=..., axis=...)`.
    a: array-like input.
    axis: axis to reduce along. None reduces over the flattened input.

  Returns:
    Whatever `fn` returns for the prepared input.
  """
  arr = np_array_ops.array(a)
  # When axis is None numpy flattens the array. Otherwise promote to at
  # least 1-D so a rank-0 input still has an axis to reduce over.
  prepared = (array_ops.reshape(arr, [-1]) if axis is None
              else np_array_ops.atleast_1d(arr))
  return fn(input=prepared, axis=axis)
1222
1223
@np_utils.np_doc('argmax')
def argmax(a, axis=None):
  # Index of the maximum; with axis=None, index into the flattened input.
  return _argminmax(fn=math_ops.argmax, a=a, axis=axis)
1227
1228
@np_utils.np_doc('argmin')
def argmin(a, axis=None):
  # Index of the minimum; with axis=None, index into the flattened input.
  return _argminmax(fn=math_ops.argmin, a=a, axis=axis)
1232
1233
@np_utils.np_doc('append')
def append(arr, values, axis=None):
  # With an explicit axis, join along it. Without one, both inputs are
  # flattened and joined end to end.
  if axis is not None:
    return concatenate([arr, values], axis=axis)
  flat_pieces = [np_array_ops.ravel(arr), np_array_ops.ravel(values)]
  return concatenate(flat_pieces, 0)
1240
1241
@np_utils.np_doc('average')
def average(a, axis=None, weights=None, returned=False):  # pylint: disable=missing-docstring
  # Computes the (optionally weighted) mean of `a` over `axis` (all elements
  # when axis is None). When `returned` is true, it also returns the sum of
  # the weights, broadcast to the shape of the mean.
  # Raises ValueError if `axis` is neither None nor an int.
  if axis is not None and not isinstance(axis, int):
    # TODO(wangpeng): Support tuple of ints as `axis`
    raise ValueError('Argument `axis` must be an integer. '
                     f'Received axis={axis} (of type {type(axis)})')
  a = np_array_ops.array(a)
  if weights is None:  # Treat all weights as 1
    # Non-inexact (integer/bool) inputs are promoted to a floating dtype
    # before averaging.
    # NOTE(review): `a.astype` presumably resolves to the `Tensor.astype`
    # installed by enable_numpy_methods_on_tensor — confirm.
    if not np.issubdtype(a.dtype.as_numpy_dtype, np.inexact):
      a = a.astype(
          np_utils.result_type(a.dtype, np_dtypes.default_float_type()))
    avg = math_ops.reduce_mean(a, axis=axis)
    if returned:
      # With implicit unit weights, the weight sum is the element count:
      # the whole array for axis=None, else the length of `axis`.
      if axis is None:
        weights_sum = array_ops.size(a)
      else:
        weights_sum = array_ops.shape(a)[axis]
      weights_sum = math_ops.cast(weights_sum, a.dtype)
  else:
    # The output dtype promotes `a` with `weights`. The default float type is
    # added to the promotion when `a` is not already inexact.
    if np.issubdtype(a.dtype.as_numpy_dtype, np.inexact):
      out_dtype = np_utils.result_type(a.dtype, weights)
    else:
      out_dtype = np_utils.result_type(a.dtype, weights,
                                       np_dtypes.default_float_type())
    a = np_array_ops.array(a, out_dtype)
    weights = np_array_ops.array(weights, out_dtype)

    # `weights` has the same shape as `a` (checked by the Assert): weighted
    # sum over `axis` divided by the weight sum over `axis`.
    def rank_equal_case():
      control_flow_ops.Assert(
          math_ops.reduce_all(array_ops.shape(a) == array_ops.shape(weights)),
          [array_ops.shape(a), array_ops.shape(weights)])
      weights_sum = math_ops.reduce_sum(weights, axis=axis)
      avg = math_ops.reduce_sum(a * weights, axis=axis) / weights_sum
      return avg, weights_sum

    # Reducing over all elements requires `weights` to match `a`'s shape.
    if axis is None:
      avg, weights_sum = rank_equal_case()
    else:

      # `weights` is 1-D (checked by the Assert). It is contracted against
      # `a` along `axis` with tensordot and normalized by its total.
      def rank_not_equal_case():
        control_flow_ops.Assert(
            array_ops.rank(weights) == 1, [array_ops.rank(weights)])
        weights_sum = math_ops.reduce_sum(weights)
        axes = ops.convert_to_tensor([[axis], [0]])
        avg = math_ops.tensordot(a, weights, axes) / weights_sum
        return avg, weights_sum

      # We condition on rank rather than shape equality, because if we do the
      # latter, when the shapes are partially unknown but the ranks are known
      # and different, np_utils.cond will run shape checking on the true branch,
      # which will raise a shape-checking error.
      avg, weights_sum = np_utils.cond(
          math_ops.equal(array_ops.rank(a), array_ops.rank(weights)),
          rank_equal_case, rank_not_equal_case)

  avg = np_array_ops.array(avg)
  if returned:
    # The weight sum is broadcast to the same shape as the average.
    weights_sum = np_array_ops.broadcast_to(weights_sum, array_ops.shape(avg))
    return avg, weights_sum
  return avg
1302
1303
@np_utils.np_doc('trace')
def trace(a, offset=0, axis1=0, axis2=1, dtype=None):  # pylint: disable=missing-docstring
  if dtype:
    dtype = np_utils.result_type(dtype)
  a = np_array_ops.asarray(a, dtype)

  # Fast path: delegate to math_ops.trace when the main diagonal (offset 0)
  # lies in the last two axes of a tensor with statically known rank.
  static_shape = a.shape
  if offset == 0 and static_shape.rank is not None:
    rank = len(static_shape)
    if axis1 in (-2, rank - 2) and axis2 in (-1, rank - 1):
      return math_ops.trace(a)

  # General case: extract the requested diagonal, then sum it.
  diag = np_array_ops.diagonal(a, offset, axis1, axis2)
  return np_array_ops.sum(diag, -1, dtype)
1320
1321
@np_utils.np_doc('meshgrid')
def meshgrid(*xi, **kwargs):
  """This currently requires copy=True and sparse=False."""
  sparse = kwargs.get('sparse', False)
  if sparse:
    raise ValueError(
        'Function `meshgrid` does not support returning sparse arrays yet. '
        f'Received: sparse={sparse}')

  copy = kwargs.get('copy', True)
  if not copy:
    raise ValueError('Function `meshgrid` only supports copy=True. '
                     f'Received: copy={copy}')

  arrays = [np_array_ops.asarray(arg) for arg in xi]
  # Only `indexing` is forwarded to TF. Any other keyword arguments are
  # ignored.
  return array_ops.meshgrid(*arrays, indexing=kwargs.get('indexing', 'xy'))
1344
1345
# Uses np_doc_only here because np.einsum (in 1.16) doesn't have argument
# `subscripts`, even though the doc says it has.
@np_utils.np_doc_only('einsum')
def einsum(subscripts, *operands, **kwargs):  # pylint: disable=missing-docstring
  casting = kwargs.get('casting', 'safe')
  optimize = kwargs.get('optimize', False)

  # 'safe' promotes all operands to a common dtype. 'no' only converts them
  # to arrays.
  if casting == 'safe':
    operands = np_array_ops._promote_dtype(*operands)  # pylint: disable=protected-access
  elif casting == 'no':
    operands = [np_array_ops.asarray(x) for x in operands]
  else:
    raise ValueError(
        'Invalid value for argument `casting`. '
        f'Expected casting="safe" or casting="no". Received: casting={casting}')

  # TF doesn't have a "no optimization" option, so falsy values map to
  # 'greedy' as well.
  # TODO(wangpeng): Print a warning that np and tf use different
  #   optimizations.
  if (not optimize or
      optimize == True or  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
      optimize == 'greedy'):
    tf_optimize = 'greedy'
  elif optimize == 'optimal':
    tf_optimize = 'optimal'
  else:
    raise ValueError(
        'Invalid value for argument `optimize`. '
        'Expected one of {True, "greedy", "optimal"}. '
        f'Received: optimize={optimize}')

  return special_math_ops.einsum(subscripts, *operands, optimize=tf_optimize)
1379
1380
1381def _tensor_t(self):
1382  """Returns a Tensor which is the transpose of this Tensor."""
1383  return self.transpose()
1384
1385
1386def _tensor_ndim(self):
1387  """Returns the rank of the Tensor."""
1388  return self.shape.ndims
1389
1390
1391def _tensor_pos(self):
1392  """Returns self, for unary operator `+`."""
1393  return self
1394
1395
1396def _tensor_size(self):
1397  """Returns the number of elements in this Tensor, if fully known."""
1398  if not self.shape.is_fully_defined():
1399    return None
1400  return np.prod(self.shape.as_list())
1401
1402
def _tensor_tolist(self):
  """Returns the tensor's value as nested Python lists (eager tensors only).

  Raises:
    ValueError: if `self` is not an `ops.EagerTensor`.
  """
  if not isinstance(self, ops.EagerTensor):
    raise ValueError('Symbolic Tensors do not support the tolist API.')
  return self._numpy().tolist()  # pylint: disable=protected-access
1408
1409
def enable_numpy_methods_on_tensor():
  """Monkey-patches NumPy-style properties and methods onto `ops.Tensor`."""
  # Read-only properties, installed in this order.
  for attr_name, getter in (('T', _tensor_t),
                            ('ndim', _tensor_ndim),
                            ('size', _tensor_size)):
    setattr(ops.Tensor, attr_name, property(getter))

  # Plain methods/operators, installed in this order.
  # TODO(b/178540516): Make a custom `setattr` that changes the method's
  #   docstring to the TF one.
  methods = (
      ('__pos__', _tensor_pos),
      ('tolist', _tensor_tolist),
      ('transpose', np_array_ops.transpose),
      ('reshape', np_array_ops._reshape_method_wrapper),  # pylint: disable=protected-access
      ('ravel', np_array_ops.ravel),
      ('clip', clip),
      ('astype', math_ops.cast),
      ('__round__', np_array_ops.around),
      ('max', np_array_ops.amax),
      ('mean', np_array_ops.mean),
      ('min', np_array_ops.amin),
  )
  for attr_name, member in methods:
    setattr(ops.Tensor, attr_name, member)

  # TODO(wangpeng): Remove `data` when all uses of it are removed
  setattr(ops.Tensor, 'data', property(lambda self: self))
1439