# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""for_loop and pfor ops."""
# pylint: disable=g-direct-tensorflow-import

import functools

from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.parallel_for.pfor import PFor
from tensorflow.python.ops.parallel_for.pfor import PForConfig
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util import variable_utils
from tensorflow.python.util.tf_export import tf_export


def for_loop(loop_fn, loop_fn_dtypes, iters, parallel_iterations=None):
  """Runs `loop_fn` `iters` times and stacks the outputs.

  Runs `loop_fn` `iters` times, with input values from 0 to `iters - 1`, and
  stacks corresponding outputs of the different runs.

  Args:
    loop_fn: A function that takes an int32 scalar tf.Tensor object representing
      the iteration number, and returns a possibly nested structure of tensor
      objects. The shape of these outputs should not depend on the input.
    loop_fn_dtypes: dtypes for the outputs of `loop_fn`.
    iters: Number of iterations for which to run `loop_fn`.
    parallel_iterations: The number of iterations that can be dispatched in
      parallel. This knob can be used to control the total memory usage.

  Returns:
    Returns a nested structure of stacked output tensor objects with the same
    nested structure as the output of `loop_fn`. Returns `None` when there are
    no outputs to stack.

  Raises:
    ValueError: If the number of flattened outputs of `loop_fn` does not match
      the number of flattened entries in `loop_fn_dtypes`.
  """

  flat_loop_fn_dtypes = nest.flatten(loop_fn_dtypes)
  # Filled in by `while_body` when it is traced; records which flattened
  # outputs of `loop_fn` were `None` so they can be skipped when stacking.
  is_none_list = []

  def while_body(i, *ta_list):
    """Body of while loop."""
    fn_conv = autograph.tf_convert(loop_fn, autograph_ctx.control_status_ctx())
    fn_output = nest.flatten(fn_conv(i))
    if len(fn_output) != len(flat_loop_fn_dtypes):
      raise ValueError(
          f"Number of expected outputs {len(flat_loop_fn_dtypes)}, does not "
          f"match the number of actual outputs {len(fn_output)} from loop_fn: "
          f"{loop_fn} with output {fn_output}.")
    outputs = []
    del is_none_list[:]
    is_none_list.extend(x is None for x in fn_output)
    for out, ta in zip(fn_output, ta_list):
      # TODO(agarwal): support returning Operation objects from loop_fn.
      if out is not None:
        # Add a leading axis of size 1 so that `ta.concat()` below joins the
        # per-iteration values along a new first dimension.
        ta = ta.write(i, array_ops.expand_dims(out, 0))
      outputs.append(ta)
    return tuple([i + 1] + outputs)

  if parallel_iterations is not None:
    extra_args = {"parallel_iterations": parallel_iterations}
  else:
    extra_args = {}
  # The first loop variable is the iteration counter; drop it from the result.
  ta_list = control_flow_ops.while_loop(
      lambda i, *ta: i < iters,
      while_body,
      [0] + [tensor_array_ops.TensorArray(dtype.base_dtype, iters)
             for dtype in flat_loop_fn_dtypes],
      **extra_args)[1:]

  # TODO(rachelim): enable this for sparse tensors

  output = [None if is_none else ta.concat()
            for ta, is_none in zip(ta_list, is_none_list)]
  assert len(output) in (0, len(flat_loop_fn_dtypes))
  if not output:
    # This may happen for the case where iters == 0.
    return None
  else:
    return nest.pack_sequence_as(loop_fn_dtypes, output)


def _flatten_first_two_dims(x):
  """Flattens the first two dimensions of x into a single dimension."""
  old_shape = array_ops.shape(x)
  new_shape = array_ops.concat([[old_shape[0] * old_shape[1]], old_shape[2:]],
                               axis=0)
  return array_ops.reshape(x, new_shape)


PFOR_CONFIG_ARG = "pfor_config"


def _is_under_xla_context():
  """Check if we are currently inside an XLA compile context."""
  g = ops.get_default_graph()
  while g is not None:
    control_flow_context = g._get_control_flow_context()  # pylint: disable=protected-access
    while control_flow_context is not None:
      if control_flow_context.IsXLAContext():
        return True
      else:
        control_flow_context = control_flow_context.outer_context
    # If g is a FuncGraph, get its outer_graph.
    g = getattr(g, "outer_graph", None)
  return False


def pfor(loop_fn,
         iters,
         fallback_to_while_loop=True,
         parallel_iterations=None,
         warn=False):
  """Equivalent to running `loop_fn` `iters` times and stacking the outputs.

  `pfor` has functionality similar to `for_loop`, i.e. running `loop_fn` `iters`
  times, with input from 0 to `iters - 1`, and stacking corresponding output of
  each iteration. However the implementation does not use a `tf.while_loop`.
  Instead it adds new operations to the graph that collectively compute the same
  value as what running `loop_fn` in a loop would compute.


  This is an experimental feature and currently has a lot of limitations:
    - There should be no data dependency between the different iterations. For
      example, a future iteration should not depend on a value or side-effect of
      a previous iteration.
    - Stateful kernels may mostly not be supported since these often imply a
      data dependency or ordering of the iterations. We do support a limited set
      of such stateful kernels though (like RandomFoo, Variable operations like
      reads, etc).
    - Conversion works only on a limited set of kernels for which a converter
      has been registered.
    - `loop_fn` has limited support for control flow operations. `tf.cond` in
      particular is not supported.
    - `loop_fn` should return nested structure of Tensors or Operations. However
      if an Operation is returned, it should have zero outputs.
    - The shape and dtype of `loop_fn` outputs should not depend on the input
      to loop_fn.

  Args:
    loop_fn: A function that takes an int32 scalar tf.Tensor object representing
      the iteration number, and optionally a keyword argument `pfor_config` set
      to a PForConfig object. It returns a possibly nested structure of Tensor
      or Operation objects. Note that if setting `parallel_iterations` argument
      to something other than None, `loop_fn` may be called more than once
      during graph construction. So it may need to avoid mutating global state.
    iters: Number of iterations for which to run `loop_fn`.
    fallback_to_while_loop: If true, on failing to vectorize an operation, pfor
      fallbacks to using a `tf.while_loop` to dispatch the iterations.
    parallel_iterations: A knob to control how many iterations are vectorized
      and dispatched in parallel. The default value of None corresponds to
      vectorizing all the iterations.  If `parallel_iterations` is smaller than
      `iters`, then chunks of at most that many iterations are dispatched in
      sequence. This knob can be used to control the total memory usage.
    warn: Whether or not to warn when falling back to while loops.

  Returns:
    Returns a nested structure of stacked tensor objects with the same nested
    structure as the output of `loop_fn`.
  Raises:
    ValueError: If parallel_iterations is not None and not an integer > 1.
  """
  def f():
    """Runs the graph-mode implementation with the arguments given to `pfor`."""
    return _pfor_impl(
        loop_fn,
        iters,
        fallback_to_while_loop=fallback_to_while_loop,
        parallel_iterations=parallel_iterations,
        warn=warn)
  # Note that we wrap into a tf.function if in eager execution mode or under
  # XLA compilation. The latter is so that we don't compile operations like
  # tf.placeholder that are created by the loop body.
  functions_run_eagerly = None
  if context.executing_eagerly() or _is_under_xla_context():
    functions_run_eagerly = def_function.functions_run_eagerly()
    if functions_run_eagerly:
      logging.warning(
          "It looks like tf.function behavior was disabled, perhaps using "
          "tf.config.run_functions_eagerly. Vectorization "
          "primitives (e.g. tf.vectorized_map) require tf.function to work. "
          "These primitives will override the disable.")
    def_function.run_functions_eagerly(False)
    f = def_function.function(f)

  try:
    return f()
  finally:
    # Restore the caller's setting even when `f` raises; otherwise a failed
    # vectorization would leave the global flag permanently overridden.
    if functions_run_eagerly is not None:
      def_function.run_functions_eagerly(functions_run_eagerly)


def _should_expand_composite(value):
  """Returns True if `value` is a CompositeTensor to be split into components.

  `SparseTensor` and `IndexedSlices` instances are excluded.

  Args:
    value: Any object, typically an input to or output of a loop body.

  Returns:
    A Python bool.
  """
  return (isinstance(value, composite_tensor.CompositeTensor)
          # Leave sparse tensors to be converted by `PFor._convert_sparse`.
          and not isinstance(value, sparse_tensor.SparseTensor)
          and not isinstance(value, indexed_slices.IndexedSlices))


# pylint: disable=protected-access
def _composite_to_tensors(value, is_batched=False):
  """Converts a CompositeTensor into a list of stackable tensors.

  Args:
    value: The value to convert. Values for which `_should_expand_composite`
      is False are returned unchanged.
    is_batched: If True, the spec's `_to_batched_tensor_list` is used instead
      of `_to_tensor_list`.

  Returns:
    A list of component tensors, or `value` itself if it is not expanded.

  Raises:
    ValueError: If `value` should be expanded but its type spec is not a
      `BatchableTypeSpec`.
  """
  if _should_expand_composite(value):
    spec = value._type_spec
    if not isinstance(spec, type_spec.BatchableTypeSpec):
      raise ValueError(f"CompositeTensor instance {value} returned from "
                       "parallel_for or vectorized_map loop body must provide "
                       f"a `BatchableTypeSpec` (saw: {spec}).")
    if is_batched:
      return spec._to_batched_tensor_list(value)
    return spec._to_tensor_list(value)
  return value
# pylint: enable=protected-access


# pylint: disable=protected-access
def _composite_from_tensors(stacked_tensors,
                            preconverted_value,
                            batch_size):
  """Converts a list of stacked tensors to a batch CompositeTensor.

  Args:
    stacked_tensors: The stacked component tensors (or a single stacked value
      when `preconverted_value` was not expanded).
    preconverted_value: The per-iteration value the components came from; its
      type spec is batched to rebuild the composite.
    batch_size: Batch size passed to the type spec's `_batch`.

  Returns:
    The rebuilt batched composite, or `stacked_tensors` unchanged if
    `preconverted_value` is not an expandable composite.
  """
  if _should_expand_composite(preconverted_value):
    batch_type_spec = preconverted_value._type_spec._batch(batch_size)
    return batch_type_spec._from_compatible_tensor_list(stacked_tensors)
  return stacked_tensors
# pylint: enable=protected-access


def _loop_fn_has_config(loop_fn):
  """Test if `loop_fn` has a `pfor_config` argument.

  Args:
    loop_fn: A function, a `functools.partial`, or an object with a `__call__`
      method.

  Returns:
    True if `loop_fn` accepts an argument named `PFOR_CONFIG_ARG`. For a
    `functools.partial` the argument must also not already be bound as a
    keyword.

  Raises:
    ValueError: If `loop_fn` is neither a function nor a partial and has no
      `__call__` attribute.
  """
  if tf_inspect.isfunction(loop_fn):
    argspec = tf_inspect.getargspec(loop_fn)
    return PFOR_CONFIG_ARG in argspec.args
  elif isinstance(loop_fn, functools.partial):
    fn = loop_fn.func
    argspec = tf_inspect.getargspec(fn)
    return (PFOR_CONFIG_ARG in argspec.args and
            PFOR_CONFIG_ARG not in loop_fn.keywords)
  else:
    loop_class = tf_decorator.unwrap(loop_fn)[1]
    if not hasattr(loop_class, "__call__"):
      raise ValueError("`loop_fn` object did not have a __call__ method")
    argspec = tf_inspect.getargspec(loop_class.__call__)
    return PFOR_CONFIG_ARG in argspec.args


def _pfor_impl(loop_fn,
               iters,
               fallback_to_while_loop,
               parallel_iterations=None,
               pfor_config=None,
               warn=False):
  """Implementation of pfor.

  Must be called in graph mode. Traces `loop_fn` once on a placeholder loop
  variable, then either vectorizes all iterations at once or, when
  `parallel_iterations` applies, vectorizes tiles of that size and runs the
  tiles sequentially through `for_loop`.

  Args:
    loop_fn: See `pfor`.
    iters: Number of iterations; a Python integer or a tensor.
    fallback_to_while_loop: See `pfor`.
    parallel_iterations: See `pfor`. Ignored when the static value of `iters`
      is known and smaller than it.
    pfor_config: Optional `PForConfig` to pass to `loop_fn`. Must be None if
      `loop_fn` does not take a `pfor_config` argument.
    warn: Whether or not to warn when falling back to while loops.

  Returns:
    A nested structure of stacked outputs with the same structure as the
    output of `loop_fn`.

  Raises:
    ValueError: If `parallel_iterations` is less than 1 or equal to 1, or if
      tiling is requested while `pfor_config` has reductions.
  """
  assert not context.executing_eagerly()
  loop_fn_has_config = _loop_fn_has_config(loop_fn)
  # Snapshot the graph so that the ops added by the loop body can be identified.
  existing_ops = set(ops.get_default_graph().get_operations())
  iters_value = tensor_util.constant_value(iters)
  # Run the loop body
  with ops.name_scope("loop_body"):
    loop_var = array_ops.placeholder_with_default(0, shape=[])
    if loop_fn_has_config:
      if pfor_config is None:
        pfor_config = PForConfig()
        pfor_config._set_iters(iters)  # pylint: disable=protected-access
      loop_fn_outputs = loop_fn(loop_var, **{PFOR_CONFIG_ARG: pfor_config})
    else:
      assert pfor_config is None
      f = autograph.tf_convert(loop_fn, autograph_ctx.control_status_ctx())
      loop_fn_outputs = f(loop_var)
    loop_fn_output_tensors = nest.map_structure(_composite_to_tensors,
                                                loop_fn_outputs)

  # Convert outputs to Tensor if needed.
  tmp_loop_fn_outputs = []
  for loop_fn_output in nest.flatten(loop_fn_output_tensors):
    if (loop_fn_output is not None and not isinstance(
        loop_fn_output,
        (ops.Operation, ops.Tensor, sparse_tensor.SparseTensor))):
      if isinstance(loop_fn_output, indexed_slices.IndexedSlices):
        logging.warning(
            "Converting %s to a dense representation may make it slow."
            " Alternatively, output the indices and values of the"
            " IndexedSlices separately, and handle the vectorized"
            " outputs directly.", loop_fn_output)
      loop_fn_output = ops.convert_to_tensor(loop_fn_output)
    tmp_loop_fn_outputs.append(loop_fn_output)
  loop_fn_output_tensors = nest.pack_sequence_as(loop_fn_output_tensors,
                                                 tmp_loop_fn_outputs)

  new_ops = set(ops.get_default_graph().get_operations()) - existing_ops
  iters = ops.convert_to_tensor(iters)
  if parallel_iterations is not None:
    if parallel_iterations < 1:
      raise ValueError(
          "Argument `parallel_iterations` must be None or a positive integer. "
          f"Received: {parallel_iterations}.")
    if parallel_iterations == 1:
      raise ValueError(
          "Found `parallel_iterations == 1`. Use `for_loop` instead.")
    if iters_value is not None and iters_value < parallel_iterations:
      # Fewer iterations than one tile: vectorize everything in a single shot.
      parallel_iterations = None
  if parallel_iterations is None:
    with ops.name_scope("pfor"):
      converter = PFor(
          loop_var,
          iters,
          new_ops,
          fallback_to_while_loop=fallback_to_while_loop,
          pfor_config=pfor_config,
          warn=warn)
      flattened_output_tensors = []
      for loop_fn_output in nest.flatten(loop_fn_output_tensors):
        output = converter.convert(loop_fn_output)
        flattened_output_tensors.append(output)
  else:
    if pfor_config is not None and pfor_config._has_reductions():  # pylint: disable=protected-access
      raise ValueError("Setting `parallel_iterations` currently unsupported if "
                       "reductions across iterations are performed.")
    num_tiled_iterations = iters // parallel_iterations
    num_remaining_iterations = iters % parallel_iterations
    # TODO(agarwal): Avoid calling loop_fn twice. Generate the loop body inside
    # a tf.function and extract the graph from there to vectorize it.
    with ops.name_scope("pfor_untiled"):
      # Vectorizes the first `iters % parallel_iterations` iterations, i.e.
      # those that do not fill a whole tile.
      converter = PFor(loop_var, num_remaining_iterations, new_ops,
                       fallback_to_while_loop=fallback_to_while_loop,
                       pfor_config=pfor_config,
                       warn=warn)
      remaining_output_tensors = []
      flattened_output_tensors = nest.flatten(loop_fn_output_tensors)
      for loop_fn_output in flattened_output_tensors:
        output = converter.convert(loop_fn_output)
        remaining_output_tensors.append(output)

    with ops.name_scope("pfor_tiled"):
      loop_fn_dtypes = [ops.convert_to_tensor(x).dtype
                        for x in flattened_output_tensors]

      def tiled_loop_body(j):
        """Vectorizes the `j`-th tile of `parallel_iterations` iterations."""
        # Tiles start after the untiled remainder handled above.
        offset = j * parallel_iterations + num_remaining_iterations

        def tiled_loop_fn(i, pfor_config=None):
          """Calls `loop_fn` on the `i`-th iteration of the current tile."""
          if loop_fn_has_config:
            loop_fn_outputs = loop_fn(i + offset, pfor_config=pfor_config)
          else:
            loop_fn_outputs = loop_fn(i + offset)
          return nest.flatten(
              # Stacking across iterations requires explicit Tensors.
              nest.map_structure(_composite_to_tensors, loop_fn_outputs))

        return _pfor_impl(
            tiled_loop_fn,
            parallel_iterations,
            fallback_to_while_loop=fallback_to_while_loop,
            pfor_config=pfor_config,
            warn=warn)

      tiled_output_tensors = for_loop(
          tiled_loop_body, loop_fn_dtypes,
          num_tiled_iterations, parallel_iterations=1)
      tiled_output_tensors = [
          _flatten_first_two_dims(y) for y in tiled_output_tensors]

    with ops.name_scope("pfor"):
      if iters_value is None or iters_value % parallel_iterations:
        output_tensors = control_flow_ops.cond(
            math_ops.equal(num_remaining_iterations, 0),
            lambda: tiled_output_tensors,
            lambda: [array_ops.concat([x, y], axis=0)  # pylint: disable=g-long-lambda
                     for x, y in zip(remaining_output_tensors,
                                     tiled_output_tensors)])
      else:
        output_tensors = tiled_output_tensors
      flattened_output_tensors = nest.flatten(output_tensors)

      for output, original_output in zip(flattened_output_tensors,
                                         nest.flatten(loop_fn_output_tensors)):
        # Restore any shape information lost from tiling.
        # TODO(b/174254748): this may not be correct for stacked `variant`s.
        output.set_shape(
            tensor_shape.TensorShape([iters_value]).concatenate(
                original_output.shape))

  return nest.map_structure_up_to(
      loop_fn_outputs,
      functools.partial(_composite_from_tensors, batch_size=iters_value),
      nest.pack_sequence_as(loop_fn_output_tensors,
                            flattened_output_tensors),
      loop_fn_outputs)


def _broadcasting_gather(x, i):
  """Wrapper for gather that implicitly broadcasts unit dimensions.

  Args:
    x: Tensor to gather from along its first dimension.
    i: Index to gather.

  Returns:
    `x[i]`, or `x[0]` when the first dimension of `x` has size 1.
  """
  static_first_dim = tensor_shape.dimension_value(x.shape[0])
  if static_first_dim == 1:
    i = 0
  elif static_first_dim is None:
    # First dimension is only known at runtime; pick the index dynamically.
    i = array_ops.where_v2(array_ops.shape(x)[0] > 1, i, 0)
  return array_ops.gather(x, i)


# pylint: disable=protected-access
def _gather_from_tensor_or_composite(x, i):
  """Wrapper for gather that handles CompositeTensors.

  Args:
    x: A tensor, or a CompositeTensor that is expanded into batched components.
    i: Index to gather.

  Returns:
    The gathered tensor, or a composite rebuilt from the gathered components
    using the unbatched type spec.
  """
  if _should_expand_composite(x):
    spec = x._type_spec
    gathered_tensors = [_broadcasting_gather(t, i)
                        for t in spec._to_batched_tensor_list(x)]
    return spec._unbatch()._from_compatible_tensor_list(gathered_tensors)
  return _broadcasting_gather(x, i)
# pylint: enable=protected-access


@tf_export("vectorized_map")
def vectorized_map(fn, elems, fallback_to_while_loop=True, warn=True):
  """Parallel map on the list of tensors unpacked from `elems` on dimension 0.

  This method works similar to `tf.map_fn` but is optimized to run much faster,
  possibly with a much larger memory footprint. The speedups are obtained by
  vectorization (see [Auto-Vectorizing TensorFlow Graphs: Jacobians,
  Auto-Batching and Beyond](https://arxiv.org/pdf/1903.04243.pdf)). The idea
  behind vectorization is to semantically launch all the invocations of `fn` in
  parallel and fuse corresponding operations across all these invocations. This
  fusion is done statically at graph generation time and the generated code is
  often similar in performance to a manually fused version.

  Because `tf.vectorized_map` fully parallelizes the batch, this method will
  generally be significantly faster than using `tf.map_fn`, especially in eager
  mode. However this is an experimental feature and currently has a lot of
  limitations:
    - There should be no data dependency between the different semantic
      invocations of `fn`, i.e. it should be safe to map the elements of the
      inputs in any order.
    - Stateful kernels may mostly not be supported since these often imply a
      data dependency. We do support a limited set of such stateful kernels
      though (like RandomFoo, Variable operations like reads, etc).
    - `fn` has limited support for control flow operations.
    - `fn` should return nested structure of Tensors or Operations. However
      if an Operation is returned, it should have zero outputs.
    - The shape and dtype of any intermediate or output tensors in the
      computation of `fn` should not depend on the input to `fn`.

  Examples:
  ```python
  def outer_product(a):
    return tf.tensordot(a, a, 0)

  batch_size = 100
  a = tf.ones((batch_size, 32, 32))
  c = tf.vectorized_map(outer_product, a)
  assert c.shape == (batch_size, 32, 32, 32, 32)
  ```

  ```python
  # Computing per-example gradients

  batch_size = 10
  num_features = 32
  layer = tf.keras.layers.Dense(1)

  def model_fn(arg):
    with tf.GradientTape() as g:
      inp, label = arg
      inp = tf.expand_dims(inp, 0)
      label = tf.expand_dims(label, 0)
      prediction = layer(inp)
      loss = tf.nn.l2_loss(label - prediction)
    return g.gradient(loss, (layer.kernel, layer.bias))

  inputs = tf.random.uniform([batch_size, num_features])
  labels = tf.random.uniform([batch_size, 1])
  per_example_gradients = tf.vectorized_map(model_fn, (inputs, labels))
  assert per_example_gradients[0].shape == (batch_size, num_features, 1)
  assert per_example_gradients[1].shape == (batch_size, 1)
  ```

  Args:
    fn: The callable to be performed. It accepts one argument, which will have
      the same (possibly nested) structure as `elems`, and returns a possibly
      nested structure of Tensors and Operations, which may be different than
      the structure of `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be mapped over by `fn`. The first dimensions of all
      elements must broadcast to a consistent value; equivalently, each
      element tensor must have first dimension of either `B` or `1`, for some
      common batch size `B >= 1`.
    fallback_to_while_loop: If true, on failing to vectorize an operation,
      the unsupported op is wrapped in a tf.while_loop to execute the map
      iterations. Note that this fallback only happens for unsupported ops and
      other parts of `fn` are still vectorized. If false, on encountering an
      unsupported op, a ValueError is thrown. Note that the fallbacks can result
      in slowdowns since vectorization often yields speedup of one to two orders
      of magnitude.
    warn: If set to `false`, this will suppress any warnings due to operation
      conversions in the provided `fn` falling back to while loops.

  Returns:
    A tensor or (possibly nested) sequence of tensors. Each tensor packs the
    results of applying fn to tensors unpacked from elems along the first
    dimension, from first to last.

    Although they are less common as user-visible inputs and outputs, note that
    tensors of type `tf.variant` which represent tensor lists (for example from
    `tf.raw_ops.TensorListFromTensor`) are vectorized by stacking the list
    contents rather than the variant itself, and so the container tensor will
    have a scalar shape when returned rather than the usual stacked shape. This
    improves the performance of control flow gradient vectorization.

  Raises:
    ValueError: If vectorization fails and fallback_to_while_loop is False.
  """
  elems = variable_utils.convert_variables_to_tensors(elems)
  elems = nest.map_structure(ops.convert_to_tensor,
                             elems,
                             expand_composites=True)

  def loop_fn(i):
    """Applies `fn` to the `i`-th slice of every element of `elems`."""
    gathered_elems = nest.map_structure(
        lambda x: _gather_from_tensor_or_composite(x, i), elems)
    return fn(gathered_elems)

  # Extract batch size from the maximum first dimension of any element.
  flat_elems = nest.flatten(
      nest.map_structure(
          functools.partial(_composite_to_tensors,
                            is_batched=True),
          elems))

  def _get_shape(x):
    """Returns the static first dimension of `x`, or None if rank is unknown."""
    if x.shape.rank is None:
      return None
    return x.shape.as_list()[0]

  static_first_dims = [_get_shape(elem) for elem in flat_elems]
  if any(s is None for s in static_first_dims):
    # At least one first dimension is not statically known; compute the batch
    # size at runtime instead.
    batch_size = math_ops.reduce_max(
        [array_ops.shape(elem)[0] for elem in flat_elems])
  else:
    batch_size = max(static_first_dims)

  return pfor(
      loop_fn,
      batch_size,
      fallback_to_while_loop=fallback_to_while_loop,
      warn=warn)