# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Device function for replicated training."""
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import device as pydev
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util.tf_export import tf_export

# This is a tuple of PS ops used by tf.estimator.Estimator, which should work
# in almost all cases.
STANDARD_PS_OPS = ("Variable", "VariableV2", "AutoReloadVariable",
                   "MutableHashTable", "MutableHashTableV2",
                   "MutableHashTableOfTensors", "MutableHashTableOfTensorsV2",
                   "MutableDenseHashTable", "MutableDenseHashTableV2",
                   "VarHandleOp", "BoostedTreesEnsembleResourceHandleOp",
                   "BoostedTreesQuantileStreamResourceHandleOp",
                   "ResourceConditionalAccumulator",
                   "DecisionTreeResource")
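
# For a custom set of ps ops, callers can pass `ps_ops` explicitly to
# `replica_device_setter`, e.g. `list(STANDARD_PS_OPS) + ["MyResourceHandleOp"]`
# (a hypothetical op name, for illustration only).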


class _RoundRobinStrategy:
  """Returns the next ps task index for placement in round-robin order.

  This class is not to be used directly by users.  See instead
  `replica_device_setter()` below.
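
  For example (a minimal sketch; `ops` stands for any sequence of
  `Operation`s, whose contents this strategy ignores):

  ```python
  strategy = _RoundRobinStrategy(num_tasks=3)
  [strategy(op) for op in ops]  # -> [0, 1, 2, 0, 1, ...]
  ```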
  """

  def __init__(self, num_tasks):
    """Create a new `_RoundRobinStrategy`.

    Args:
      num_tasks: Number of ps tasks to cycle among.
    """
    self._num_tasks = num_tasks
    self._next_task = 0

  def __call__(self, unused_op):
    """Choose a ps task index for the given `Operation`.

    Args:
      unused_op: An `Operation` to be placed on ps.

    Returns:
      The next ps task index to use for the `Operation`, cycling through the
      range `[0, num_tasks)`.
    """
    task = self._next_task
    self._next_task = (self._next_task + 1) % self._num_tasks
    return task


class _ReplicaDeviceChooser:
  """Class to choose devices for Ops in a replicated training setup.

  This class is not to be used directly by users.  See instead
  `replica_device_setter()` below.
  """

  def __init__(self, ps_tasks, ps_device, worker_device, merge_devices, ps_ops,
               ps_strategy):
    """Create a new `_ReplicaDeviceChooser`.

    Args:
      ps_tasks: Number of tasks in the `ps` job.
      ps_device: String.  Name of the `ps` job.
      worker_device: String.  Name of the `worker` job.
      merge_devices: Boolean. Set to True to allow merging of device specs.
      ps_ops: List of strings representing `Operation` types that need to be
        placed on `ps` devices.
      ps_strategy: A callable invoked for every ps `Operation` (i.e. matched by
        `ps_ops`), that takes the `Operation` and returns the ps task index to
        use.
    """
    self._ps_tasks = ps_tasks
    self._ps_device = ps_device
    self._worker_device = worker_device
    self._merge_devices = merge_devices
    self._ps_ops = ps_ops
    self._ps_strategy = ps_strategy

  def device_function(self, op):
    """Choose a device for `op`.

    Args:
      op: an `Operation`.

    Returns:
      The device to use for the `Operation`.
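
    For example (illustrative, assuming 2 ps tasks, the default devices and
    the standard `ps_ops`): a `VariableV2` op is placed on "/job:ps/task:0",
    while a `MatMul` op falls through to "/job:worker".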
    """
    # If we don't return early here, either merge_devices is True, or op.device
    # is empty (in which case merging is a no-op). So we can always merge below.
    if not self._merge_devices and op.device:
      return op.device

    current_device = pydev.DeviceSpec.from_string(op.device or "")

    # The ps_device will be used for the specified ops (ps_ops) whenever it is
    # present and ps_tasks is non-zero. However, its task number will only be
    # set (using ps_strategy) if ps_device has a job field and current_device
    # either has no job field or the same one.
    node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
    if self._ps_tasks and self._ps_device and node_def.op in self._ps_ops:
      ps_device = pydev.DeviceSpec.from_string(self._ps_device)

      current_job, ps_job = current_device.job, ps_device.job
      if ps_job and (not current_job or current_job == ps_job):
        ps_device = ps_device.replace(task=self._ps_strategy(op))

      ps_device = ps_device.make_merged_spec(current_device)
      return ps_device.to_string()

    worker_device = pydev.DeviceSpec.from_string(self._worker_device or "")
    worker_device = worker_device.make_merged_spec(current_device)
    return worker_device.to_string()


@tf_export(v1=["train.replica_device_setter"])
def replica_device_setter(ps_tasks=0,
                          ps_device="/job:ps",
                          worker_device="/job:worker",
                          merge_devices=True,
                          cluster=None,
                          ps_ops=None,
                          ps_strategy=None):
  """Return a `device function` to use when building a Graph for replicas.

  Device functions are used in a `with tf.device(device_function):` statement
  to automatically assign devices to `Operation` objects as they are
  constructed. Device constraints are added from the inner-most context first,
  working outwards; merging only sets fields that are not already set by a
  more inner context. Currently the fields are (job, task, cpu/gpu).

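  For instance (a minimal sketch; the device strings are illustrative and
  assume a `ps` job with at least one task), an inner context supplies the
  fields the setter leaves unset:

  ```python
  with tf.compat.v1.device(
      tf.compat.v1.train.replica_device_setter(ps_tasks=2)):
    with tf.compat.v1.device("/device:GPU:0"):
      v = tf.Variable(...)  # "/job:ps/task:0/device:GPU:0": job and task come
                            # from the setter, the device from the inner scope.
  ```
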
  If `cluster` is `None` and `ps_tasks` is 0, this function returns `None`,
  which is a no-op when passed to `tf.device()`. Otherwise, the value of
  `ps_tasks` is derived from `cluster`.

  By default, only Variable ops are placed on ps tasks, and the placement
  strategy is round-robin over all ps tasks. A custom `ps_strategy` may be used
  to do more intelligent placement, such as
  `tf.contrib.training.GreedyLoadBalancingStrategy`.

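  Any callable that takes an `Operation` and returns a task index in
  `[0, ps_tasks)` can serve as a strategy. For instance (a minimal sketch, not
  part of this module; it spreads ops across tasks by hashing their names):

  ```python
  class _NameHashStrategy:
    """Illustrative only: picks a ps task by hashing the op name."""

    def __init__(self, num_tasks):
      self._num_tasks = num_tasks

    def __call__(self, op):
      return hash(op.name) % self._num_tasks

  setter = tf.compat.v1.train.replica_device_setter(
      ps_tasks=2, ps_strategy=_NameHashStrategy(2))
  ```
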
  For example,

  ```python
  # To build a cluster with two ps tasks on hosts ps0 and ps1, and 3 worker
  # tasks on hosts worker0, worker1 and worker2.
  cluster_spec = {
      "ps": ["ps0:2222", "ps1:2222"],
      "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]}
  with tf.compat.v1.device(
      tf.compat.v1.train.replica_device_setter(cluster=cluster_spec)):
    # Build your graph
    v1 = tf.Variable(...)  # assigned to /job:ps/task:0
    v2 = tf.Variable(...)  # assigned to /job:ps/task:1
    v3 = tf.Variable(...)  # assigned to /job:ps/task:0
  # Run compute
  ```

  Args:
    ps_tasks: Number of tasks in the `ps` job.  Ignored if `cluster` is
      provided.
    ps_device: String.  Device of the `ps` job.  If empty, no `ps` job is
      used.  Defaults to `/job:ps`.
    worker_device: String.  Device of the `worker` job.  If empty, no `worker`
      job is used.
    merge_devices: `Boolean`. If `True`, merges device specifications rather
      than overriding them: a field is set only if it is not already set by an
      inner context. If `False`, ops that already have a device keep it
      unchanged.
    cluster: `ClusterDef` proto or `ClusterSpec`.
    ps_ops: List of strings representing `Operation` types that need to be
      placed on `ps` devices.  If `None`, defaults to `STANDARD_PS_OPS`.
    ps_strategy: A callable invoked for every ps `Operation` (i.e. matched by
      `ps_ops`), that takes the `Operation` and returns the ps task index to
      use.  If `None`, defaults to a round-robin strategy across all `ps`
      devices.

  Returns:
    A function to pass to `tf.device()`, or `None` if no `ps` job is
    configured (see above).

  Raises:
    TypeError: If `cluster` is not a dictionary or `ClusterDef` protocol
      buffer, or if `ps_strategy` is provided but not callable.
  """
  if cluster is not None:
    if isinstance(cluster, server_lib.ClusterSpec):
      cluster_spec = cluster.as_dict()
    else:
      cluster_spec = server_lib.ClusterSpec(cluster).as_dict()
    # Get ps_job_name from ps_device by parsing its job field.
    ps_job_name = pydev.DeviceSpec.from_string(ps_device).job
    if ps_job_name not in cluster_spec or cluster_spec[ps_job_name] is None:
      return None
    ps_tasks = len(cluster_spec[ps_job_name])

  if ps_tasks == 0:
    return None

  if ps_ops is None:
    # TODO(sherrym): Variables in the LOCAL_VARIABLES collection should not be
    # placed on the parameter server.
    ps_ops = list(STANDARD_PS_OPS)

  if not merge_devices:
    logging.warning(
        "DEPRECATION: It is recommended to set merge_devices=true in "
        "replica_device_setter")
  if ps_strategy is None:
    ps_strategy = _RoundRobinStrategy(ps_tasks)
  if not callable(ps_strategy):
    raise TypeError("ps_strategy must be callable")
  chooser = _ReplicaDeviceChooser(ps_tasks, ps_device, worker_device,
                                  merge_devices, ps_ops, ps_strategy)
  return chooser.device_function