xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/inplace_ops.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Inplace operations.
17"""
18from tensorflow.python.framework import dtypes
19from tensorflow.python.framework import ops
20from tensorflow.python.ops import array_ops
21from tensorflow.python.ops import gen_array_ops
22from tensorflow.python.ops import math_ops
23from tensorflow.python.util import deprecation
24
25
def _inplace_helper(x, i, v, op):
  """Applies an inplace op on (x, i, v).

  op is one of gen_array_ops.inplace_update, gen_array_ops.inplace_add, or
  gen_array_ops.inplace_sub (the raw ops used by alias_inplace_update,
  alias_inplace_add and alias_inplace_sub respectively).

  If i is None, x and v must be the same shape. Computes
    x op v;
  If i is a scalar, x has a rank 1 higher than v's. Computes
    x[i, :] op v;
  Otherwise, x and v must have the same rank. Computes
    x[i, :] op v;

  Args:
    x: A Tensor (or a value convertible to one).
    i: None, a scalar or a vector of row indices. Non-None values are cast
      to int32 before being passed to op.
    v: A Tensor (or a value convertible to one). It is converted using x's
      dtype.
    op: One of gen_array_ops.inplace_update, gen_array_ops.inplace_add or
      gen_array_ops.inplace_sub, invoked as op(x, indices, values).

  Returns:
    The output of op (the alias_inplace_* wrappers document this as aliasing
    x). When i is None, that output is reshaped back to x's shape.
  """
  x = ops.convert_to_tensor(x)
  v = ops.convert_to_tensor(v, x.dtype)
  if i is None:
    # Full-tensor case: view x and v as a single row each, apply op to
    # row 0, then restore x's original shape.
    return array_ops.reshape(
        op(array_ops.reshape(x, [1, -1]), [0], array_ops.reshape(v, [1, -1])),
        array_ops.shape(x))
  i = math_ops.cast(i, dtypes.int32)
  if i.get_shape().ndims == 0:
    # Scalar index: promote i to a length-1 vector and v to a one-row batch
    # so op receives rank-matched arguments.
    return op(x, array_ops.reshape(i, [1]), array_ops.expand_dims(v, 0))
  # Vector index (and also an index of statically unknown rank, since
  # ndims is None then): passed to op unchanged.
  return op(x, i, v)
61
62
@deprecation.deprecated(
    None,
    ('Prefer tf.tensor_scatter_nd_update, which offers the same functionality '
     'with well-defined read-write semantics.'))
def alias_inplace_update(x, i, v):
  """Writes v into x at index i; the result aliases x.

  Index handling:
    * i is None: x and v have the same shape; computes x = v.
    * i is a scalar: x's rank is one more than v's; computes x[i, :] = v.
    * otherwise: x and v have the same rank; computes x[i, :] = v.

  Args:
    x: A Tensor.
    i: None, a scalar or a vector.
    v: A Tensor.

  Returns:
    Returns x.
  """
  return _inplace_helper(x=x, i=i, v=v, op=gen_array_ops.inplace_update)
87
88
@deprecation.deprecated(
    None,
    ('Prefer tf.tensor_scatter_nd_add, which offers the same functionality '
     'with well-defined read-write semantics.'))
def alias_inplace_add(x, i, v):
  """Adds v into x at index i; the result aliases x.

  Index handling:
    * i is None: x and v have the same shape; computes x += v.
    * i is a scalar: x's rank is one more than v's; computes x[i, :] += v.
    * otherwise: x and v have the same rank; computes x[i, :] += v.

  Args:
    x: A Tensor.
    i: None, a scalar or a vector.
    v: A Tensor.

  Returns:
    Returns x.
  """
  return _inplace_helper(x=x, i=i, v=v, op=gen_array_ops.inplace_add)
113
114
@deprecation.deprecated(
    None,
    ('Prefer tf.tensor_scatter_nd_sub, which offers the same functionality '
     'with well-defined read-write semantics.'))
def alias_inplace_sub(x, i, v):
  """Subtracts v from x at index i; the result aliases x.

  Index handling:
    * i is None: x and v have the same shape; computes x -= v.
    * i is a scalar: x's rank is one more than v's; computes x[i, :] -= v.
    * otherwise: x and v have the same rank; computes x[i, :] -= v.

  Args:
    x: A Tensor.
    i: None, a scalar or a vector.
    v: A Tensor.

  Returns:
    Returns x.
  """
  return _inplace_helper(x=x, i=i, v=v, op=gen_array_ops.inplace_sub)
139
140
def empty_like(x, init=None):
  """Returns a tensor with the same shape and dtype as x.

  The contents are uninitialized unless init is True.

  Args:
    x: A Tensor (or a value convertible to one).
    init: If True, initialize the returned tensor with the default value of
      x's dtype. Otherwise (False, or None, the default), do not initialize.

  Returns:
    A tensor y, whose dtype and shape are the same as those of x.
    y is guaranteed not to be an alias of x. Unless init is True, y may
    contain arbitrary data upon return.
  """
  x = ops.convert_to_tensor(x)
  # The shape is passed as a tensor (array_ops.shape) rather than read from
  # x's static shape.
  return gen_array_ops.empty(array_ops.shape(x), x.dtype, init=init)
158
159
@deprecation.deprecated(
    None,
    ('Prefer tf.tensor_scatter_nd_update, which offers the same functionality '
     'with well-defined read-write semantics.'))
def inplace_update(x, i, v):
  """Applies an inplace update on input x at index i with value v.

  Note that this function is not actually inplace - it allocates
  a copy of x.  The utility is not avoiding memory copies but rather
  specifying a sparse update.

  If i is None, x and v must be the same shape. Computes
    y = x; y = v;
  If i is a scalar, x has a rank 1 higher than v's. Computes
    y = x; y[i, :] = v;
  Otherwise, x and v must have the same rank. Computes
    y = x; y[i, :] = v;

  Args:
    x: A Tensor.
    i: None, a scalar or a vector.
    v: A Tensor.

  Returns:
    Returns y, which is guaranteed not to be an alias of x.
  """
  # Copy first so the aliasing op writes into a fresh buffer, not into x.
  # Call the shared helper directly rather than the deprecated
  # alias_inplace_update wrapper (same computation). Going through the
  # wrapper would trigger its deprecation decorator too, emitting a second
  # warning about a function the caller never used.
  return _inplace_helper(
      gen_array_ops.deep_copy(x), i, v, gen_array_ops.inplace_update)
188
189
@deprecation.deprecated(
    None,
    ('Prefer tf.tensor_scatter_nd_add, which offers the same functionality '
     'with well-defined read-write semantics.'))
def inplace_add(x, i, v):
  """Applies an inplace add on input x at index i with value v.

  Note that this function is not actually inplace - it allocates
  a copy of x.  The utility is not avoiding memory copies but rather
  specifying a sparse update.

  If i is None, x and v must be the same shape. Computes
    y = x; y += v;
  If i is a scalar, x has a rank 1 higher than v's. Computes
    y = x; y[i, :] += v;
  Otherwise, x and v must have the same rank. Computes
    y = x; y[i, :] += v;

  Args:
    x: A Tensor.
    i: None, a scalar or a vector.
    v: A Tensor.

  Returns:
    Returns y, which is guaranteed not to be an alias of x.
  """
  # Copy first so the aliasing op writes into a fresh buffer, not into x.
  # Call the shared helper directly rather than the deprecated
  # alias_inplace_add wrapper (same computation). Going through the wrapper
  # would trigger its deprecation decorator too, emitting a second warning
  # about a function the caller never used.
  return _inplace_helper(
      gen_array_ops.deep_copy(x), i, v, gen_array_ops.inplace_add)
218
219
@deprecation.deprecated(
    None,
    ('Prefer tf.tensor_scatter_nd_sub, which offers the same functionality '
     'with well-defined read-write semantics.'))
def inplace_sub(x, i, v):
  """Applies an inplace sub on input x at index i with value v.

  Note that this function is not actually inplace - it allocates
  a copy of x.  The utility is not avoiding memory copies but rather
  specifying a sparse update.

  If i is None, x and v must be the same shape. Computes
    y = x; y -= v;
  If i is a scalar, x has a rank 1 higher than v's. Computes
    y = x; y[i, :] -= v;
  Otherwise, x and v must have the same rank. Computes
    y = x; y[i, :] -= v;

  Args:
    x: A Tensor.
    i: None, a scalar or a vector.
    v: A Tensor.

  Returns:
    Returns y, which is guaranteed not to be an alias of x.
  """
  # Copy first so the aliasing op writes into a fresh buffer, not into x.
  # Call the shared helper directly rather than the deprecated
  # alias_inplace_sub wrapper (same computation). Going through the wrapper
  # would trigger its deprecation decorator too, emitting a second warning
  # about a function the caller never used.
  return _inplace_helper(
      gen_array_ops.deep_copy(x), i, v, gen_array_ops.inplace_sub)
248
# Public re-export of the generated Empty op wrapper. Unlike empty_like, it
# takes an explicit shape and dtype (plus an optional init flag) rather than
# a template tensor.
empty = gen_array_ops.empty
250