# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Inplace operations.
"""
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import deprecation


def _inplace_helper(x, i, v, op):
  """Applies an inplace op on (x, i, v).

  op is one of gen_array_ops.inplace_update, gen_array_ops.inplace_add, or
  gen_array_ops.inplace_sub; these are the ops the alias_inplace_* wrappers
  in this module pass in. Each op works row-wise: it applies v to the rows of
  x selected by a vector of indices i.

  If i is None, x and v must be the same shape. Computes
    x op v;
  If i is a scalar, x has a rank 1 higher than v's. Computes
    x[i, :] op v;
  Otherwise, x and v must have the same rank. Computes
    x[i, :] op v;

  Args:
    x: A Tensor.
    i: None, a scalar or a vector.
    v: A Tensor. Converted to x's dtype.
    op: gen_array_ops.inplace_update, gen_array_ops.inplace_add, or
      gen_array_ops.inplace_sub.

  Returns:
    The tensor produced by op. When i is None it is reshaped back to x's
    shape.

  """
  x = ops.convert_to_tensor(x)
  v = ops.convert_to_tensor(v, x.dtype)
  if i is None:
    # Full tensor: view x and v each as a single row, apply op to row 0,
    # then restore x's original shape.
    return array_ops.reshape(
        op(array_ops.reshape(x, [1, -1]), [0], array_ops.reshape(v, [1, -1])),
        array_ops.shape(x))
  i = math_ops.cast(i, dtypes.int32)
  # If i's rank is not known statically, ndims is None and this check fails.
  # i is then passed through unchanged, as in the vector case below.
  if i.get_shape().ndims == 0:
    # Single 0-dim update: promote the scalar index to a length-1 vector and
    # give v a leading axis of size 1 so it matches the row-wise op contract.
    return op(x, array_ops.reshape(i, [1]), array_ops.expand_dims(v, 0))
  return op(x, i, v)


@deprecation.deprecated(
    None,
    ('Prefer tf.tensor_scatter_nd_update, which offers the same functionality '
     'with well-defined read-write semantics.'))
def alias_inplace_update(x, i, v):
  """Applies an inplace update on input x at index i with value v. Aliases x.

  If i is None, x and v must be the same shape. Computes
    x = v;
  If i is a scalar, x has a rank 1 higher than v's. Computes
    x[i, :] = v;
  Otherwise, x and v must have the same rank. Computes
    x[i, :] = v;

  Args:
    x: A Tensor.
    i: None, a scalar or a vector.
    v: A Tensor.

  Returns:
    Returns x.

  """
  return _inplace_helper(x, i, v, gen_array_ops.inplace_update)


@deprecation.deprecated(
    None,
    ('Prefer tf.tensor_scatter_nd_add, which offers the same functionality '
     'with well-defined read-write semantics.'))
def alias_inplace_add(x, i, v):
  """Applies an inplace add on input x at index i with value v. Aliases x.

  If i is None, x and v must be the same shape. Computes
    x += v;
  If i is a scalar, x has a rank 1 higher than v's. Computes
    x[i, :] += v;
  Otherwise, x and v must have the same rank. Computes
    x[i, :] += v;

  Args:
    x: A Tensor.
    i: None, a scalar or a vector.
    v: A Tensor.

  Returns:
    Returns x.

  """
  return _inplace_helper(x, i, v, gen_array_ops.inplace_add)


@deprecation.deprecated(
    None,
    ('Prefer tf.tensor_scatter_nd_sub, which offers the same functionality '
     'with well-defined read-write semantics.'))
def alias_inplace_sub(x, i, v):
  """Applies an inplace sub on input x at index i with value v. Aliases x.

  If i is None, x and v must be the same shape. Computes
    x -= v;
  If i is a scalar, x has a rank 1 higher than v's. Computes
    x[i, :] -= v;
  Otherwise, x and v must have the same rank. Computes
    x[i, :] -= v;

  Args:
    x: A Tensor.
    i: None, a scalar or a vector.
    v: A Tensor.

  Returns:
    Returns x.

  """
  return _inplace_helper(x, i, v, gen_array_ops.inplace_sub)


def empty_like(x, init=None):
  """Returns a non-initialized tensor with the same shape and dtype as x.

  Args:
    x: A Tensor.
    init: Initialize the returned tensor with the default value of
      x.dtype(), if True. Otherwise, do not initialize. Defaults to
      None. Forwarded unchanged to gen_array_ops.empty.

  Returns:
    A tensor y, whose dtype and shape are the same as those of x.
    y is guaranteed not to be an alias of x. Upon return, y may contain
    arbitrary data.

  """
  x = ops.convert_to_tensor(x)
  return gen_array_ops.empty(array_ops.shape(x), x.dtype, init=init)


@deprecation.deprecated(
    None,
    ('Prefer tf.tensor_scatter_nd_update, which offers the same functionality '
     'with well-defined read-write semantics.'))
def inplace_update(x, i, v):
  """Applies an inplace update on input x at index i with value v.

  Note that this function is not actually inplace - it allocates
  a copy of x.  The utility is not avoiding memory copies but rather
  specifying a sparse update.

  If i is None, x and v must be the same shape. Computes
    y = x; y = v;
  If i is a scalar, x has a rank 1 higher than v's. Computes
    y = x; y[i, :] = v;
  Otherwise, x and v must have the same rank. Computes
    y = x; y[i, :] = v;

  Args:
    x: A Tensor.
    i: None, a scalar or a vector.
    v: A Tensor.

  Returns:
    Returns y, which is guaranteed not to be an alias of x.

  """
  # deep_copy first so the aliasing update below never touches the caller's x.
  return alias_inplace_update(gen_array_ops.deep_copy(x), i, v)


@deprecation.deprecated(
    None,
    ('Prefer tf.tensor_scatter_nd_add, which offers the same functionality '
     'with well-defined read-write semantics.'))
def inplace_add(x, i, v):
  """Applies an inplace add on input x at index i with value v.

  Note that this function is not actually inplace - it allocates
  a copy of x.  The utility is not avoiding memory copies but rather
  specifying a sparse update.

  If i is None, x and v must be the same shape. Computes
    y = x; y += v;
  If i is a scalar, x has a rank 1 higher than v's. Computes
    y = x; y[i, :] += v;
  Otherwise, x and v must have the same rank. Computes
    y = x; y[i, :] += v;

  Args:
    x: A Tensor.
    i: None, a scalar or a vector.
    v: A Tensor.

  Returns:
    Returns y, which is guaranteed not to be an alias of x.

  """
  # deep_copy first so the aliasing add below never touches the caller's x.
  return alias_inplace_add(gen_array_ops.deep_copy(x), i, v)


@deprecation.deprecated(
    None,
    ('Prefer tf.tensor_scatter_nd_sub, which offers the same functionality '
     'with well-defined read-write semantics.'))
def inplace_sub(x, i, v):
  """Applies an inplace sub on input x at index i with value v.

  Note that this function is not actually inplace - it allocates
  a copy of x.  The utility is not avoiding memory copies but rather
  specifying a sparse update.

  If i is None, x and v must be the same shape. Computes
    y = x; y -= v;
  If i is a scalar, x has a rank 1 higher than v's. Computes
    y = x; y[i, :] -= v;
  Otherwise, x and v must have the same rank. Computes
    y = x; y[i, :] -= v;

  Args:
    x: A Tensor.
    i: None, a scalar or a vector.
    v: A Tensor.

  Returns:
    Returns y, which is guaranteed not to be an alias of x.

  """
  # deep_copy first so the aliasing sub below never touches the caller's x.
  return alias_inplace_sub(gen_array_ops.deep_copy(x), i, v)

# Module-level alias for the generated Empty op.
empty = gen_array_ops.empty