xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/tpu_util.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utility functions for TPU."""
16
17import contextlib
18
19from tensorflow.python.distribute import packed_distributed_variable as packed
20from tensorflow.python.eager import context
21from tensorflow.python.framework import ops
22from tensorflow.python.tpu import tpu
23
24
def enclosing_tpu_context():
  """Looks up the TPUReplicateContext that exists inside a tpu.rewrite().

  Returns:
    The enclosing `tpu.TPUReplicateContext`, or None if the current code is
    not nested inside one.
  """
  tpu_ctx, _ = enclosing_tpu_context_and_graph()
  return tpu_ctx
28
29
def enclosing_tpu_context_and_graph():
  """Finds the enclosing TPUReplicateContext and the graph that owns it.

  Starting at the default graph, the control flow context chain of each graph
  is walked outward through `outer_context`. When no
  `tpu.TPUReplicateContext` is found there, the search moves on to the
  graph's `outer_graph` attribute, if the graph has one.

  Returns:
    A `(context, graph)` tuple for the first `tpu.TPUReplicateContext` found
    (such a context exists inside a `tpu.rewrite()`), or `(None, None)` when
    there is none.
  """

  def _context_chain(flow_ctx):
    # Yields `flow_ctx` followed by each successive outer context.
    while flow_ctx is not None:
      yield flow_ctx
      flow_ctx = flow_ctx.outer_context

  current_graph = ops.get_default_graph()
  while current_graph is not None:
    flow_chain = _context_chain(
        current_graph._get_control_flow_context())  # pylint: disable=protected-access
    for flow_ctx in flow_chain:
      if isinstance(flow_ctx, tpu.TPUReplicateContext):
        return flow_ctx, current_graph
    # The current graph may be a FuncGraph (from defuns or v2 control flow),
    # so keep searching outward for the original graph that carries the
    # XLAControlFlowContext.
    current_graph = getattr(current_graph, "outer_graph", None)
  return None, None
43
44
@contextlib.contextmanager
def outside_or_skip_tpu_context():
  """Context manager that temporarily steps outside the enclosing TPU context.

  If the current code is nested inside a `tpu.TPUReplicateContext`, the
  control flow context of the graph owning that TPU context is switched to
  the TPU context's `outer_context` for the duration of the `with` block. The
  previous control flow context is restored on exit, including when the block
  raises. Without an enclosing TPU replicate context this is a no-op.

  Yields:
    None.
  """
  ctx, graph = enclosing_tpu_context_and_graph()
  if ctx is None:
    yield
    return

  saved_context = graph._get_control_flow_context()  # pylint: disable=protected-access
  graph._set_control_flow_context(ctx.outer_context)  # pylint: disable=protected-access
  try:
    yield
  finally:
    # contextlib re-raises an exception from the `with` body at the `yield`;
    # without `finally` the graph would be left pointing at
    # `ctx.outer_context` instead of its original control flow context.
    graph._set_control_flow_context(saved_context)  # pylint: disable=protected-access
56
57
@contextlib.contextmanager
def _maybe_enter_graph(tensor):
  """Makes `tensor.graph` the default graph only when nothing else applies.

  No graph is entered when executing eagerly, when `tensor` is an
  `ops.EagerTensor`, or when `ops.has_default_graph()` reports a default
  graph. An eager tensor can show up even without eager execution (for
  example while building functions), hence the separate check.

  Args:
    tensor: the tensor (or handle) whose `graph` may be entered.

  Yields:
    None.
  """
  already_covered = (
      context.executing_eagerly() or
      isinstance(tensor, ops.EagerTensor) or
      ops.has_default_graph())
  if already_covered:
    yield
    return
  with tensor.graph.as_default():
    yield
68
69
@contextlib.contextmanager
def _maybe_on_device(var):
  """Opens a device scope for packed variables; does nothing otherwise.

  Args:
    var: the variable being updated. When it is a
      `packed.PackedVarAndDevice`, ops created inside the context are placed
      under `ops.device(var.device)`.

  Yields:
    None.
  """
  if not isinstance(var, packed.PackedVarAndDevice):
    yield
    return
  with ops.device(var.device):
    yield
78
79
def make_raw_assign_fn(raw_assign_fn, use_handle=True):
  """Builds an assign function that runs in the right graph and device scope.

  Args:
    raw_assign_fn: the low-level assignment function being wrapped. It is
      called as `raw_assign_fn(target, value_tensor, name=name)`.
    use_handle: if True, `raw_assign_fn` receives the variable's `handle`;
      otherwise it receives the variable object itself.

  Returns:
    The wrapped function, with signature
    `(var, value, use_locking=False, name=None, read_value=True)`.
  """

  def assign_fn(var, value, use_locking=False, name=None, read_value=True):
    """Assigns `value` to `var` through `raw_assign_fn`.

    Args:
      var: the variable to update.
      value: the new value; converted to a tensor with `var.dtype`.
      use_locking: ignored; accepted only for signature compatibility.
      name: optional op name forwarded to `raw_assign_fn`.
      read_value: if True, return a read of the variable that is
        control-dependent on the assignment; otherwise return the
        assignment op.

    Returns:
      The post-assignment read, or the assignment op if `read_value` is
      False.
    """
    del use_locking  # Unused.

    target = var.handle if use_handle else var
    with _maybe_enter_graph(target), _maybe_on_device(var):
      new_value = ops.convert_to_tensor(value, dtype=var.dtype)
      assign_op = raw_assign_fn(target, new_value, name=name)
      with ops.control_dependencies([assign_op]):
        if not read_value:
          return assign_op
        if use_handle:
          return var._read_variable_op()  # pylint: disable=protected-access
        return var.read_value()

  return assign_fn
106
107
def make_raw_scatter_xxx_fn(raw_scatter_xxx_fn):
  """Wraps `raw_scatter_xxx_fn` so it works with and without packed handles.

  Args:
    raw_scatter_xxx_fn: the low-level scatter function being wrapped. It is
      called as `raw_scatter_xxx_fn(handle, indices, updates, name=name)`.

  Returns:
    The wrapped function, with signature
    `(var, sparse_delta, use_locking=False, name=None)`.
  """

  def scatter_xxx_fn(var, sparse_delta, use_locking=False, name=None):
    """Applies `sparse_delta` to `var` through `raw_scatter_xxx_fn`.

    Args:
      var: the variable to update.
      sparse_delta: an object exposing `indices` and `values`; the values are
        converted to a tensor with `var.dtype`.
      use_locking: ignored; accepted only for signature compatibility.
      name: optional op name forwarded to `raw_scatter_xxx_fn`.

    Returns:
      A read of the variable that is control-dependent on the scatter op.
    """
    del use_locking  # Unused.

    var_handle = var.handle
    with _maybe_enter_graph(var_handle), _maybe_on_device(var):
      indices = sparse_delta.indices
      updates = ops.convert_to_tensor(sparse_delta.values, var.dtype)
      scatter_op = raw_scatter_xxx_fn(var_handle, indices, updates, name=name)
      with ops.control_dependencies([scatter_op]):
        return var._read_variable_op()  # pylint: disable=protected-access

  return scatter_xxx_fn
125