# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for TPU."""

import contextlib

from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.tpu import tpu


def enclosing_tpu_context():
  """Returns the TPUReplicateContext, which exists inside a tpu.rewrite().

  Returns:
    The first `tpu.TPUReplicateContext` found by
    `enclosing_tpu_context_and_graph()`, or `None` if there is none.
  """
  return enclosing_tpu_context_and_graph()[0]


def enclosing_tpu_context_and_graph():
  """Returns the enclosing TPUReplicateContext and the graph that owns it.

  The search starts at the default graph and walks outward through that
  graph's chain of control flow contexts (via `outer_context`). If no
  `tpu.TPUReplicateContext` is found, the search moves on to the graph's
  `outer_graph` attribute (when present) and repeats.

  Returns:
    A `(ctx, graph)` tuple, where `ctx` is the first `tpu.TPUReplicateContext`
    found and `graph` is the graph whose control flow context chain contains
    it. Returns `(None, None)` when no such context exists.
  """
  graph = ops.get_default_graph()
  while graph is not None:
    ctx = graph._get_control_flow_context()  # pylint: disable=protected-access
    while ctx is not None:
      if isinstance(ctx, tpu.TPUReplicateContext):
        return ctx, graph
      ctx = ctx.outer_context
    # This may be a FuncGraph due to defuns or v2 control flow. We need to
    # find the original graph with the XLAControlFlowContext.
    graph = getattr(graph, "outer_graph", None)
  return None, None


@contextlib.contextmanager
def outside_or_skip_tpu_context():
  """Returns a context manager that skips the enclosing TPU context, if any.

  When a `tpu.TPUReplicateContext` encloses the current scope, the control
  flow context of the graph that owns it is temporarily replaced by that TPU
  context's `outer_context`. The saved context is restored on exit, including
  when the body raises. When there is no enclosing TPU context, this is a
  no-op.
  """
  ctx, graph = enclosing_tpu_context_and_graph()
  if ctx is None:
    yield
    return
  saved_context = graph._get_control_flow_context()  # pylint: disable=protected-access
  graph._set_control_flow_context(ctx.outer_context)  # pylint: disable=protected-access
  try:
    yield
  finally:
    # Restore even if the body raises. Without this, the graph would keep
    # `ctx.outer_context` as its control flow context after the `with` exits.
    graph._set_control_flow_context(saved_context)  # pylint: disable=protected-access


@contextlib.contextmanager
def _maybe_enter_graph(tensor):
  """Makes `tensor.graph` the default graph only when nothing else applies.

  Nothing is entered when executing eagerly, when `tensor` is an
  `ops.EagerTensor`, or when a default graph is already set.

  Args:
    tensor: the tensor (or handle) whose `graph` may be entered.
  """
  # Note: might have an eager tensor but not be executing eagerly when
  # building functions.
  if (context.executing_eagerly() or isinstance(tensor, ops.EagerTensor) or
      ops.has_default_graph()):
    yield
  else:
    with tensor.graph.as_default():
      yield


@contextlib.contextmanager
def _maybe_on_device(var):
  """Opens an `ops.device(var.device)` scope for packed variables only.

  Args:
    var: a variable; a device scope is entered only if it is a
      `packed.PackedVarAndDevice`.
  """
  # Add a device scope for packed variables.
  if isinstance(var, packed.PackedVarAndDevice):
    with ops.device(var.device):
      yield
  else:
    yield


def make_raw_assign_fn(raw_assign_fn, use_handle=True):
  """Wrap `raw_assign_fn` with the proper graph context and device scope.

  Args:
    raw_assign_fn: the function to be wrapped.
    use_handle: if True, the `raw_assign_fn` will be applied to the handle of a
      variable; otherwise it will be applied to the variable itself.

  Returns:
    The wrapped function.
  """

  def assign_fn(var, value, use_locking=False, name=None, read_value=True):
    """Assigns `value` to `var` through `raw_assign_fn`.

    Args:
      var: the variable to assign to.
      value: the new value; converted to a tensor of dtype `var.dtype`.
      use_locking: ignored.
      name: optional op name forwarded to `raw_assign_fn`.
      read_value: if True, return a read of the variable that is
        control-dependent on the assignment; otherwise return the assign op.

    Returns:
      The post-assignment read of `var` if `read_value`, else the op returned
      by `raw_assign_fn`.
    """
    del use_locking  # Unused.

    handle = var.handle if use_handle else var
    with _maybe_enter_graph(handle), _maybe_on_device(var):
      op = raw_assign_fn(
          handle, ops.convert_to_tensor(value, dtype=var.dtype), name=name)
      with ops.control_dependencies([op]):
        if read_value:
          return var._read_variable_op() if use_handle else var.read_value()  # pylint: disable=protected-access
        else:
          return op

  return assign_fn


def make_raw_scatter_xxx_fn(raw_scatter_xxx_fn):
  """Wrap `raw_scatter_xxx_fn` so that it can be called w/ and w/o packed handle.

  Args:
    raw_scatter_xxx_fn: the function to be wrapped; it is called as
      `raw_scatter_xxx_fn(handle, indices, values, name=name)`.

  Returns:
    The wrapped function.
  """

  def scatter_xxx_fn(var, sparse_delta, use_locking=False, name=None):
    """Applies `raw_scatter_xxx_fn` to `var.handle` and returns a fresh read.

    Args:
      var: the variable to update.
      sparse_delta: an object exposing `indices` and `values` (presumably a
        `tf.IndexedSlices`; confirm against callers). `values` is converted
        to `var.dtype`.
      use_locking: ignored.
      name: optional op name forwarded to `raw_scatter_xxx_fn`.

    Returns:
      A read of `var` that is control-dependent on the scatter op.
    """
    del use_locking  # Unused.

    handle = var.handle
    with _maybe_enter_graph(handle), _maybe_on_device(var):
      op = raw_scatter_xxx_fn(
          handle,
          sparse_delta.indices,
          ops.convert_to_tensor(sparse_delta.values, var.dtype),
          name=name)
      with ops.control_dependencies([op]):
        return var._read_variable_op()  # pylint: disable=protected-access

  return scatter_xxx_fn