xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/device.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Class to represent a device."""
17
18from tensorflow.python import tf2
19from tensorflow.python.framework import device_spec
20
# Select the DeviceSpec flavor matching the active API version: DeviceSpecV2
# when TF2 behavior is enabled, DeviceSpecV1 otherwise.
DeviceSpec = (
    device_spec.DeviceSpecV2 if tf2.enabled() else device_spec.DeviceSpecV1)
25
26
def check_valid(spec):
  """Validates a device specification string.

  The string is parsed with `DeviceSpec.from_string` and the parsed result is
  discarded; the function is called only for the error raised on bad input.

  Args:
    spec: a device specification string, e.g. "/job:worker/device:GPU:0".

  Raises:
    Whatever `DeviceSpec.from_string` raises when `spec` cannot be parsed.
  """
  # Only the parse matters; the resulting DeviceSpec is intentionally unused.
  _ = DeviceSpec.from_string(spec)
38
39
def is_device_spec(obj):
  """Returns True if `obj` is a DeviceSpec of either API version.

  DeviceSpecV2 is the common base class, so one isinstance check against it
  covers every DeviceSpec flavor without callers needing to know that.
  """
  base_class = device_spec.DeviceSpecV2
  return isinstance(obj, base_class)
43
44
def canonical_name(device):
  """Returns the canonical string form of a device.

  Args:
    device: None, a `DeviceSpec`, or a device name string.

  Returns:
    "" when `device` is None; otherwise the `to_string()` of the spec, where a
    string input is first parsed with `DeviceSpec.from_string`.
  """
  if device is None:
    return ""
  spec = device if is_device_spec(device) else DeviceSpec.from_string(device)
  return spec.to_string()
54
55
# Performance caches (module-level, so shared by every caller in the process;
# nothing in this module evicts entries).
# Maps the `spec` argument of `merge_device` -> its MergeDevice wrapper.
_cached_mergers = {}
# Maps (wrapped spec, node device string) -> merged device string; populated by
# `MergeDevice.shortcut_string_merge`.
_string_merge_cache = {}


def merge_device(spec):
  """Returns a device function that merges device specifications.

  This can be used to merge partial specifications of devices. The
  innermost setting for a device field takes precedence. For example:

    with tf.device(merge_device("/device:GPU:0")):
      # Nodes created here have device "/device:GPU:0"
      with tf.device(merge_device("/job:worker")):
        # Nodes created here have device "/job:worker/device:GPU:0"
        with tf.device(merge_device("/device:CPU:0")):
          # Nodes created here have device "/job:worker/device:CPU:0"
          with tf.device(merge_device("/job:ps")):
            # Nodes created here have device "/job:ps/device:CPU:0"

  Args:
    spec: A `DeviceSpec` or a device spec string (partially) describing the
      device that should be used for all nodes created in the scope of
      the returned device function's with block. An existing `MergeDevice`
      is returned unchanged.

  Returns:
    A MergeDevice object with the above-described behavior. Results are
    cached per `spec`, so repeated calls with an equal spec return the same
    object.

  Raises:
    ValueError: if the spec was not valid.
  """
  if isinstance(spec, MergeDevice):
    return spec

  # Compare against None explicitly: the cache's "missing" sentinel is None.
  merger = _cached_mergers.get(spec)
  if merger is not None:
    return merger
  merger = MergeDevice(spec)
  # No locking needed, since updates are stateless: racing callers each build
  # an equivalent MergeDevice for the same spec, and the last write wins.
  _cached_mergers[spec] = merger
  return merger
98
99
class MergeDevice(object):
  """Wraps a device specification (DeviceSpec or str) with merge functionality.

  When called, this class will merge a node_def with its own spec. It also
  exposes a `shortcut_string_merge` method which can significantly improve
  performance of device placement.
  """

  __slots__ = ["_spec"]

  def __init__(self, spec):
    """Stores `spec` (parsed or snapshotted as needed) for later merges.

    Args:
      spec: A `DeviceSpecV1`, a `DeviceSpecV2`, or a device spec string.
    """
    # DeviceSpecV1 derives from DeviceSpecV2, so it must be checked first.
    # Otherwise the V2 branch would also match V1 specs, and the snapshot
    # below would never be taken.
    if isinstance(spec, device_spec.DeviceSpecV1):
      # Capture a snapshot of spec. V1 specs are presumably mutable (hence the
      # snapshot); copying keeps later caller mutations from altering this
      # merger or the `_string_merge_cache` keys built from `self._spec`.
      self._spec = spec.__class__.from_string(spec.to_string())
    elif isinstance(spec, device_spec.DeviceSpecV2):
      self._spec = spec
    else:
      self._spec = DeviceSpec.from_string(spec)

  def __call__(self, node_def):
    """Returns `self._spec` merged with `node_def.device` as a DeviceSpec."""
    # In general a user may create a device function which takes into account
    # arbitrary properties of an op. (For instance dynamically placing ops based
    # on type.) So even though the standard DeviceSpec route only uses the
    # device attribute, we take an entire node_def to maintain a consistent
    # signature with general device functions.
    current_device = DeviceSpec.from_string(node_def.device or "")
    return self._spec.make_merged_spec(current_device)

  def shortcut_string_merge(self, node_def):
    """Merge a node def without materializing a full DeviceSpec object.

    Often a device merge is invoked in order to generate a string which can be
    passed into the c api. In such a case, we can cache the
      node_def.device  ->  merge_result_string

    map, and in most cases avoid:
      - Materializing a copy of self._spec (In the case of DeviceSpecV1)
      - Materializing a DeviceSpec for node_def.device
      - A DeviceSpec.merge_from invocation

    In practice the cache hit rate for this function is very high, because the
    number of invocations when iterating through the device stack is much
    larger than the number of devices.

    Args:
      node_def: An Operation (or Operation-like) to merge device constraints
        with self._spec

    Returns:
      A string containing the merged device specification.
    """
    device = node_def.device or ""

    merge_key = (self._spec, device)
    result = _string_merge_cache.get(merge_key)
    if result is None:
      # This update is not atomic, however because the merge is stateless
      # we don't need to lock when updating the cache.
      result = self(node_def).to_string()
      _string_merge_cache[merge_key] = result

    return result

  def __repr__(self):
    return "{} (spec: {})".format(
        super(MergeDevice, self).__repr__(), self._spec.to_string())

  @property
  def is_null_merge(self):
    """Indicate whether the wrapped spec is empty.

    In the degenerate case where self._spec is an empty specification, a caller
    may wish to skip a merge step entirely. (However this class does not have
    enough information to make that determination.)

    Returns:
      A boolean indicating whether a device merge will be trivial.
    """
    # An empty canonical string means the spec constrains no field.
    return not self._spec.to_string()
179