# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""for_loop and pfor ops."""
# pylint: disable=g-direct-tensorflow-import

import functools

from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.parallel_for.pfor import PFor
from tensorflow.python.ops.parallel_for.pfor import PForConfig
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util import variable_utils
from tensorflow.python.util.tf_export import tf_export


def for_loop(loop_fn, loop_fn_dtypes, iters, parallel_iterations=None):
  """Runs `loop_fn` `iters` times and stacks the outputs.

  Runs `loop_fn` `iters` times, with input values from 0 to `iters - 1`, and
  stacks corresponding outputs of the different runs.

  Args:
    loop_fn: A function that takes an int32 scalar tf.Tensor object representing
      the iteration number, and returns a possibly nested structure of tensor
      objects. The shape of these outputs should not depend on the input.
    loop_fn_dtypes: dtypes for the outputs of `loop_fn`.
    iters: Number of iterations for which to run `loop_fn`.
    parallel_iterations: The number of iterations that can be dispatched in
      parallel. This knob can be used to control the total memory usage.

  Returns:
    Returns a nested structure of stacked output tensor objects with the same
    nested structure as the output of `loop_fn`.
  """

  flat_loop_fn_dtypes = nest.flatten(loop_fn_dtypes)
  is_none_list = []

  def while_body(i, *ta_list):
    """Body of while loop."""
    fn_conv = autograph.tf_convert(loop_fn, autograph_ctx.control_status_ctx())
    fn_output = nest.flatten(fn_conv(i))
    if len(fn_output) != len(flat_loop_fn_dtypes):
      raise ValueError(
          f"Number of expected outputs {len(flat_loop_fn_dtypes)} does not "
          f"match the number of actual outputs {len(fn_output)} from loop_fn: "
          f"{loop_fn} with output {fn_output}.")
    outputs = []
    del is_none_list[:]
    is_none_list.extend(x is None for x in fn_output)
    for out, ta in zip(fn_output, ta_list):
      # TODO(agarwal): support returning Operation objects from loop_fn.
      if out is not None:
        # out may be a ref tensor, wrap it in identity to get a non-ref tensor.
        ta = ta.write(i, array_ops.expand_dims(out, 0))
      outputs.append(ta)
    return tuple([i + 1] + outputs)

  if parallel_iterations is not None:
    extra_args = {"parallel_iterations": parallel_iterations}
  else:
    extra_args = {}
  ta_list = control_flow_ops.while_loop(
      lambda i, *ta: i < iters,
      while_body,
      [0] + [tensor_array_ops.TensorArray(dtype.base_dtype, iters)
             for dtype in flat_loop_fn_dtypes],
      **extra_args)[1:]

  # TODO(rachelim): enable this for sparse tensors

  output = [None if is_none else ta.concat()
            for ta, is_none in zip(ta_list, is_none_list)]
  assert len(output) in (0, len(flat_loop_fn_dtypes))
  if not output:
    # This may happen for the case where iters == 0.
    return None
  else:
    return nest.pack_sequence_as(loop_fn_dtypes, output)

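
# A minimal usage sketch (illustrative only, not part of the module API): each
# iteration returns the square of its index as a float32 scalar, and `for_loop`
# stacks the per-iteration scalars into a single tensor of shape [5].
def _example_for_loop_usage():
  from tensorflow.python.framework import dtypes  # local import for the sketch
  squares = for_loop(
      lambda i: math_ops.cast(i * i, dtypes.float32), dtypes.float32, 5)
  return squares  # Evaluates to [0., 1., 4., 9., 16.].
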

def _flatten_first_two_dims(x):
  """Flattens the first two dimensions of x into a single dimension."""
  old_shape = array_ops.shape(x)
  new_shape = array_ops.concat([[old_shape[0] * old_shape[1]], old_shape[2:]],
                               axis=0)
  return array_ops.reshape(x, new_shape)

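
# A short sketch (illustrative only): flattening a [2, 3, 4] tensor with
# `_flatten_first_two_dims` merges the first two axes, yielding values of shape
# [6, 4]. Shapes are computed dynamically, so the static shape may be unknown.
def _example_flatten_first_two_dims():
  x = array_ops.ones([2, 3, 4])
  return _flatten_first_two_dims(x)  # Value has shape [6, 4].
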

PFOR_CONFIG_ARG = "pfor_config"


def _is_under_xla_context():
  """Check if we are currently inside an XLA compile context."""
  g = ops.get_default_graph()
  while g is not None:
    control_flow_context = g._get_control_flow_context()  # pylint: disable=protected-access
    while control_flow_context is not None:
      if control_flow_context.IsXLAContext():
        return True
      else:
        control_flow_context = control_flow_context.outer_context
    # If g is a FuncGraph, get its outer_graph.
    g = getattr(g, "outer_graph", None)
  return False


def pfor(loop_fn,
         iters,
         fallback_to_while_loop=True,
         parallel_iterations=None,
         warn=False):
  """Equivalent to running `loop_fn` `iters` times and stacking the outputs.

  `pfor` has functionality similar to `for_loop`, i.e. running `loop_fn` `iters`
  times, with input from 0 to `iters - 1`, and stacking corresponding output of
  each iteration. However, the implementation does not use a `tf.while_loop`.
  Instead it adds new operations to the graph that collectively compute the same
  value as what running `loop_fn` in a loop would compute.

  This is an experimental feature and currently has a lot of limitations:
    - There should be no data dependency between the different iterations. For
      example, a future iteration should not depend on a value or side-effect of
      a previous iteration.
    - Stateful kernels are mostly not supported since these often imply a data
      dependency or ordering of the iterations. We do support a limited set of
      such stateful kernels though (like RandomFoo, Variable operations like
      reads, etc).
    - Conversion works only on a limited set of kernels for which a converter
      has been registered.
    - `loop_fn` has limited support for control flow operations. `tf.cond` in
      particular is not supported.
    - `loop_fn` should return a nested structure of Tensors or Operations.
      However, if an Operation is returned, it should have zero outputs.
    - The shape and dtype of `loop_fn` outputs should not depend on the input
      to `loop_fn`.

  Args:
    loop_fn: A function that takes an int32 scalar tf.Tensor object representing
      the iteration number, and optionally a keyword argument `pfor_config` set
      to a PForConfig object. It returns a possibly nested structure of Tensor
      or Operation objects. Note that if the `parallel_iterations` argument is
      set to something other than None, `loop_fn` may be called more than once
      during graph construction, so it may need to avoid mutating global state.
    iters: Number of iterations for which to run `loop_fn`.
    fallback_to_while_loop: If true, on failing to vectorize an operation, pfor
      falls back to using a `tf.while_loop` to dispatch the iterations.
    parallel_iterations: A knob to control how many iterations are vectorized
      and dispatched in parallel. The default value of None corresponds to
      vectorizing all the iterations. If `parallel_iterations` is smaller than
      `iters`, then chunks of at most that many iterations are dispatched in
      sequence. This knob can be used to control the total memory usage.
    warn: Whether or not to warn when falling back to while loops.

  Returns:
    Returns a nested structure of stacked tensor objects with the same nested
    structure as the output of `loop_fn`.

  Raises:
    ValueError: If parallel_iterations is not None and not an integer > 1.
  """
  def f():
    return _pfor_impl(
        loop_fn,
        iters,
        fallback_to_while_loop=fallback_to_while_loop,
        parallel_iterations=parallel_iterations,
        warn=warn)
  # Note that we wrap into a tf.function if in eager execution mode or under
  # XLA compilation. The latter is so that we don't compile operations like
  # tf.placeholder that are created by the loop body.
  functions_run_eagerly = None
  if context.executing_eagerly() or _is_under_xla_context():
    functions_run_eagerly = def_function.functions_run_eagerly()
    if functions_run_eagerly:
      logging.warning(
          "It looks like tf.function behavior was disabled, perhaps using "
          "tf.config.run_functions_eagerly. Vectorization "
          "primitives (e.g. tf.vectorized_map) require tf.function to work. "
          "These primitives will override the disable.")
      def_function.run_functions_eagerly(False)
    f = def_function.function(f)

  outputs = f()
  if functions_run_eagerly is not None:
    def_function.run_functions_eagerly(functions_run_eagerly)
  return outputs

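
# A minimal sketch (illustrative only): `pfor` with a loop body that gathers one
# row of a matrix and squares it. Outputs across iterations are stacked, so the
# result below has shape [4, 3], matching `x`.
def _example_pfor_usage():
  x = array_ops.reshape(math_ops.range(12.0), [4, 3])
  return pfor(lambda i: math_ops.square(array_ops.gather(x, i)), 4)
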

def _should_expand_composite(value):
  return (isinstance(value, composite_tensor.CompositeTensor)
          # Leave sparse tensors to be converted by `PFor._convert_sparse`.
          and not isinstance(value, sparse_tensor.SparseTensor)
          and not isinstance(value, indexed_slices.IndexedSlices))


# pylint: disable=protected-access
def _composite_to_tensors(value, is_batched=False):
  """Converts a CompositeTensor into a list of stackable tensors."""
  if _should_expand_composite(value):
    spec = value._type_spec
    if not isinstance(spec, type_spec.BatchableTypeSpec):
      raise ValueError(f"CompositeTensor instance {value} returned from "
                       "parallel_for or vectorized_map loop body must provide "
                       f"a `BatchableTypeSpec` (saw: {spec}).")
    if is_batched:
      return spec._to_batched_tensor_list(value)
    return spec._to_tensor_list(value)
  return value
# pylint: enable=protected-access


# pylint: disable=protected-access
def _composite_from_tensors(stacked_tensors,
                            preconverted_value,
                            batch_size):
  """Converts a list of stacked tensors to a batch CompositeTensor."""
  if _should_expand_composite(preconverted_value):
    batch_type_spec = preconverted_value._type_spec._batch(batch_size)
    return batch_type_spec._from_compatible_tensor_list(stacked_tensors)
  return stacked_tensors
# pylint: enable=protected-access


def _loop_fn_has_config(loop_fn):
  """Test if `loop_fn` has a `pfor_config` argument."""
  if tf_inspect.isfunction(loop_fn):
    argspec = tf_inspect.getargspec(loop_fn)
    return PFOR_CONFIG_ARG in argspec.args
  elif isinstance(loop_fn, functools.partial):
    fn = loop_fn.func
    argspec = tf_inspect.getargspec(fn)
    return (PFOR_CONFIG_ARG in argspec.args and
            PFOR_CONFIG_ARG not in loop_fn.keywords)
  else:
    loop_class = tf_decorator.unwrap(loop_fn)[1]
    if not hasattr(loop_class, "__call__"):
      raise ValueError("`loop_fn` object did not have a __call__ method")
    argspec = tf_inspect.getargspec(loop_class.__call__)
    return PFOR_CONFIG_ARG in argspec.args

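
# A tiny sketch (illustrative only): `_loop_fn_has_config` inspects the callable
# signature, so a loop body that declares a `pfor_config` argument is detected
# while a plain one is not.
def _example_loop_fn_has_config():
  plain = lambda i: i
  with_config = lambda i, pfor_config: i
  # Returns (False, True).
  return _loop_fn_has_config(plain), _loop_fn_has_config(with_config)
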

def _pfor_impl(loop_fn,
               iters,
               fallback_to_while_loop,
               parallel_iterations=None,
               pfor_config=None,
               warn=False):
  """Implementation of pfor."""
  assert not context.executing_eagerly()
  loop_fn_has_config = _loop_fn_has_config(loop_fn)
  existing_ops = set(ops.get_default_graph().get_operations())
  iters_value = tensor_util.constant_value(iters)
  # Run the loop body
  with ops.name_scope("loop_body"):
    loop_var = array_ops.placeholder_with_default(0, shape=[])
    if loop_fn_has_config:
      if pfor_config is None:
        pfor_config = PForConfig()
        pfor_config._set_iters(iters)  # pylint: disable=protected-access
      loop_fn_outputs = loop_fn(loop_var, **{PFOR_CONFIG_ARG: pfor_config})
    else:
      assert pfor_config is None
      f = autograph.tf_convert(loop_fn, autograph_ctx.control_status_ctx())
      loop_fn_outputs = f(loop_var)
    loop_fn_output_tensors = nest.map_structure(_composite_to_tensors,
                                                loop_fn_outputs)

  # Convert outputs to Tensor if needed.
  tmp_loop_fn_outputs = []
  for loop_fn_output in nest.flatten(loop_fn_output_tensors):
    if (loop_fn_output is not None and not isinstance(
        loop_fn_output,
        (ops.Operation, ops.Tensor, sparse_tensor.SparseTensor))):
      if isinstance(loop_fn_output, indexed_slices.IndexedSlices):
        logging.warn("Converting %s to a dense representation may make it slow."
                     " Alternatively, output the indices and values of the"
                     " IndexedSlices separately, and handle the vectorized"
                     " outputs directly." % loop_fn_output)
        loop_fn_output = ops.convert_to_tensor(loop_fn_output)
      else:
        loop_fn_output = ops.convert_to_tensor(loop_fn_output)
    tmp_loop_fn_outputs.append(loop_fn_output)
  loop_fn_output_tensors = nest.pack_sequence_as(loop_fn_output_tensors,
                                                 tmp_loop_fn_outputs)

  new_ops = set(ops.get_default_graph().get_operations()) - existing_ops
  iters = ops.convert_to_tensor(iters)
  if parallel_iterations is not None:
    if parallel_iterations < 1:
      raise ValueError(
          "Argument `parallel_iterations` must be None or a positive integer. "
          f"Received: {parallel_iterations}.")
    if parallel_iterations == 1:
      raise ValueError(
          "Found `parallel_iterations == 1`. Use `for_loop` instead.")
    if iters_value is not None and iters_value < parallel_iterations:
      parallel_iterations = None
  if parallel_iterations is None:
    with ops.name_scope("pfor"):
      converter = PFor(
          loop_var,
          iters,
          new_ops,
          fallback_to_while_loop=fallback_to_while_loop,
          pfor_config=pfor_config,
          warn=warn)
      flattened_output_tensors = []
      for loop_fn_output in nest.flatten(loop_fn_output_tensors):
        output = converter.convert(loop_fn_output)
        flattened_output_tensors.append(output)
  else:
    if pfor_config is not None and pfor_config._has_reductions():  # pylint: disable=protected-access
      raise ValueError("Setting `parallel_iterations` currently unsupported if "
                       "reductions across iterations are performed.")
    num_tiled_iterations = iters // parallel_iterations
    num_remaining_iterations = iters % parallel_iterations
    # TODO(agarwal): Avoid calling loop_fn twice. Generate the loop body inside
    # a tf.function and extract the graph from there to vectorize it.
    with ops.name_scope("pfor_untiled"):
      converter = PFor(loop_var, num_remaining_iterations, new_ops,
                       fallback_to_while_loop=fallback_to_while_loop,
                       pfor_config=pfor_config)
      remaining_output_tensors = []
      flattened_output_tensors = nest.flatten(loop_fn_output_tensors)
      for loop_fn_output in flattened_output_tensors:
        output = converter.convert(loop_fn_output)
        remaining_output_tensors.append(output)

    with ops.name_scope("pfor_tiled"):
      loop_fn_dtypes = [ops.convert_to_tensor(x).dtype
                        for x in flattened_output_tensors]

      def tiled_loop_body(j):
        offset = j * parallel_iterations + num_remaining_iterations

        def tiled_loop_fn(i, pfor_config=None):
          if loop_fn_has_config:
            loop_fn_outputs = loop_fn(i + offset, pfor_config=pfor_config)
          else:
            loop_fn_outputs = loop_fn(i + offset)
          return nest.flatten(
              # Stacking across iterations requires explicit Tensors.
              nest.map_structure(_composite_to_tensors, loop_fn_outputs))

        return _pfor_impl(
            tiled_loop_fn,
            parallel_iterations,
            fallback_to_while_loop=fallback_to_while_loop,
            pfor_config=pfor_config)

      tiled_output_tensors = for_loop(
          tiled_loop_body, loop_fn_dtypes,
          num_tiled_iterations, parallel_iterations=1)
      tiled_output_tensors = [
          _flatten_first_two_dims(y) for y in tiled_output_tensors]

    with ops.name_scope("pfor"):
      if iters_value is None or iters_value % parallel_iterations:
        output_tensors = control_flow_ops.cond(
            math_ops.equal(num_remaining_iterations, 0),
            lambda: tiled_output_tensors,
            lambda: [array_ops.concat([x, y], axis=0)  # pylint: disable=g-long-lambda
                     for x, y in zip(remaining_output_tensors,
                                     tiled_output_tensors)])
      else:
        output_tensors = tiled_output_tensors
      flattened_output_tensors = nest.flatten(output_tensors)

      for output, original_output in zip(flattened_output_tensors,
                                         nest.flatten(loop_fn_output_tensors)):
        # Restore any shape information lost from tiling.
        # TODO(b/174254748): this may not be correct for stacked `variant`s.
        output.set_shape(
            tensor_shape.TensorShape([iters_value]).concatenate(
                original_output.shape))

  return nest.map_structure_up_to(
      loop_fn_outputs,
      functools.partial(_composite_from_tensors, batch_size=iters_value),
      nest.pack_sequence_as(loop_fn_output_tensors,
                            flattened_output_tensors),
      loop_fn_outputs)


def _broadcasting_gather(x, i):
  """Wrapper for gather that implicitly broadcasts unit dimensions."""
  static_first_dim = tensor_shape.dimension_value(x.shape[0])
  if static_first_dim == 1:
    i = 0
  elif static_first_dim is None:
    i = array_ops.where_v2(array_ops.shape(x)[0] > 1, i, 0)
  result = array_ops.gather(x, i)
  return result

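
# A small sketch (illustrative only): when the leading dimension is statically 1
# the index is ignored and the single entry is broadcast, while a regular batch
# is gathered at the requested index.
def _example_broadcasting_gather():
  singleton = array_ops.reshape(math_ops.range(3.0), [1, 3])
  batched = array_ops.reshape(math_ops.range(15.0), [5, 3])
  a = _broadcasting_gather(singleton, 4)  # Returns singleton[0], ignoring i=4.
  b = _broadcasting_gather(batched, 4)    # Returns batched[4].
  return a, b
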

# pylint: disable=protected-access
def _gather_from_tensor_or_composite(x, i):
  """Wrapper for gather that handles CompositeTensors."""
  if _should_expand_composite(x):
    spec = x._type_spec
    gathered_tensors = [_broadcasting_gather(t, i)
                        for t in spec._to_batched_tensor_list(x)]
    return spec._unbatch()._from_compatible_tensor_list(gathered_tensors)
  return _broadcasting_gather(x, i)
# pylint: enable=protected-access


@tf_export("vectorized_map")
def vectorized_map(fn, elems, fallback_to_while_loop=True, warn=True):
  """Parallel map on the list of tensors unpacked from `elems` on dimension 0.

  This method works similarly to `tf.map_fn` but is optimized to run much
  faster, possibly with a much larger memory footprint. The speedups are
  obtained by vectorization (see [Auto-Vectorizing TensorFlow Graphs: Jacobians,
  Auto-Batching and Beyond](https://arxiv.org/pdf/1903.04243.pdf)). The idea
  behind vectorization is to semantically launch all the invocations of `fn` in
  parallel and fuse corresponding operations across all these invocations. This
  fusion is done statically at graph generation time and the generated code is
  often similar in performance to a manually fused version.

  Because `tf.vectorized_map` fully parallelizes the batch, this method will
  generally be significantly faster than using `tf.map_fn`, especially in eager
  mode. However, this is an experimental feature and currently has a lot of
  limitations:
    - There should be no data dependency between the different semantic
      invocations of `fn`, i.e. it should be safe to map the elements of the
      inputs in any order.
    - Stateful kernels are mostly not supported since these often imply a data
      dependency. We do support a limited set of such stateful kernels though
      (like RandomFoo, Variable operations like reads, etc).
    - `fn` has limited support for control flow operations.
    - `fn` should return a nested structure of Tensors or Operations. However,
      if an Operation is returned, it should have zero outputs.
    - The shape and dtype of any intermediate or output tensors in the
      computation of `fn` should not depend on the input to `fn`.

  Examples:
  ```python
  def outer_product(a):
    return tf.tensordot(a, a, 0)

  batch_size = 100
  a = tf.ones((batch_size, 32, 32))
  c = tf.vectorized_map(outer_product, a)
  assert c.shape == (batch_size, 32, 32, 32, 32)
  ```

  ```python
  # Computing per-example gradients

  batch_size = 10
  num_features = 32
  layer = tf.keras.layers.Dense(1)

  def model_fn(arg):
    with tf.GradientTape() as g:
      inp, label = arg
      inp = tf.expand_dims(inp, 0)
      label = tf.expand_dims(label, 0)
      prediction = layer(inp)
      loss = tf.nn.l2_loss(label - prediction)
    return g.gradient(loss, (layer.kernel, layer.bias))

  inputs = tf.random.uniform([batch_size, num_features])
  labels = tf.random.uniform([batch_size, 1])
  per_example_gradients = tf.vectorized_map(model_fn, (inputs, labels))
  assert per_example_gradients[0].shape == (batch_size, num_features, 1)
  assert per_example_gradients[1].shape == (batch_size, 1)
  ```

  Args:
    fn: The callable to be performed. It accepts one argument, which will have
      the same (possibly nested) structure as `elems`, and returns a possibly
      nested structure of Tensors and Operations, which may be different than
      the structure of `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be mapped over by `fn`. The first dimensions of all
      elements must broadcast to a consistent value; equivalently, each
      element tensor must have first dimension of either `B` or `1`, for some
      common batch size `B >= 1`.
    fallback_to_while_loop: If true, on failing to vectorize an operation,
      the unsupported op is wrapped in a tf.while_loop to execute the map
      iterations. Note that this fallback only happens for unsupported ops and
      other parts of `fn` are still vectorized. If false, on encountering an
      unsupported op, a ValueError is thrown. Note that the fallbacks can result
      in slowdowns since vectorization often yields a speedup of one to two
      orders of magnitude.
    warn: If set to `False`, this will suppress any warnings due to operation
      conversions in the provided `fn` falling back to while loops.

  Returns:
    A tensor or (possibly nested) sequence of tensors. Each tensor packs the
    results of applying fn to tensors unpacked from elems along the first
    dimension, from first to last.

    Although they are less common as user-visible inputs and outputs, note that
    tensors of type `tf.variant` which represent tensor lists (for example from
    `tf.raw_ops.TensorListFromTensor`) are vectorized by stacking the list
    contents rather than the variant itself, and so the container tensor will
    have a scalar shape when returned rather than the usual stacked shape. This
    improves the performance of control flow gradient vectorization.

  Raises:
    ValueError: If vectorization fails and fallback_to_while_loop is False.
  """
  elems = variable_utils.convert_variables_to_tensors(elems)
  elems = nest.map_structure(ops.convert_to_tensor,
                             elems,
                             expand_composites=True)

  def loop_fn(i):
    gathered_elems = nest.map_structure(
        lambda x: _gather_from_tensor_or_composite(x, i), elems)
    return fn(gathered_elems)

  # Extract batch size from the maximum first dimension of any element.
  flat_elems = nest.flatten(
      nest.map_structure(
          functools.partial(_composite_to_tensors,
                            is_batched=True),
          elems))
  def _get_shape(x):
    if x.shape.rank is None:
      return None
    return x.shape.as_list()[0]
  static_first_dims = [_get_shape(elem) for elem in flat_elems]
  if any(s is None for s in static_first_dims):
    batch_size = math_ops.reduce_max(
        [array_ops.shape(elem)[0] for elem in flat_elems])
  else:
    batch_size = max(static_first_dims)

  return pfor(
      loop_fn,
      batch_size,
      fallback_to_while_loop=fallback_to_while_loop,
      warn=warn)
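

# A final sketch (illustrative only): entries of `elems` may have a leading
# dimension of either the common batch size or 1; unit dimensions are broadcast
# against the batch. Here `bias` has leading dimension 1 and is implicitly
# reused for every row of `x`, so the result has shape [8, 3].
def _example_vectorized_map_broadcast():
  x = array_ops.ones([8, 3])
  bias = array_ops.ones([1, 3])
  return vectorized_map(lambda args: args[0] + args[1], (x, bias))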