xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/multi_worker_util.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for multi-worker distribution strategies."""
16
17from tensorflow.core.protobuf import cluster_pb2
18from tensorflow.python.distribute import distribute_coordinator_context as dc_context
19from tensorflow.python.training import server_lib
20
21
def normalize_cluster_spec(cluster_spec):
  """Makes `cluster_spec` into a `ClusterSpec` object.

  Args:
    cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
      cluster configurations.

  Returns:
    a `ClusterSpec` object.

  Raises:
    ValueError: if `cluster_spec` is not a dict or a `ClusterSpec` or a
      `ClusterDef`.
  """
  if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
    return server_lib.ClusterSpec(cluster_spec)
  elif not isinstance(cluster_spec, server_lib.ClusterSpec):
    # NOTE: message previously read "`cluster_spec' should be dict ..." with
    # mismatched backtick/quote and a missing article; fixed here.
    raise ValueError(
        "`cluster_spec` should be a dict or a `tf.train.ClusterSpec` or a "
        "`tf.train.ClusterDef` object")
  return cluster_spec
43
44
def task_count(cluster_spec, task_type):
  """Returns the number of `task_type` tasks in `cluster_spec`, or 0 if the
  job is absent."""
  try:
    count = cluster_spec.num_tasks(task_type)
  except ValueError:
    # `num_tasks` raises ValueError for an unknown job name; treat as empty.
    count = 0
  return count
50
51
def _validate_cluster_spec(cluster_spec,
                           task_type,
                           task_id):
  """Checks that `cluster_spec`, `task_type` and `task_id` are consistent.

  The validation rules are:
  1) `task_type` is one of "chief", "worker", "ps", "evaluator", or not
     provided (None).
  2) `task_type`, when provided, appears in `cluster_spec` — except for
     "evaluator", which may legitimately be missing. This keeps compatibility
     with `TF_CONFIG` as used by Estimator.
  3) there is at most one "chief" job.
  4) there is at most one "evaluator" job.
  5) `task_id` is smaller than the number of tasks of that `task_type`.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
    task_type: string indicating the type of the task.
    task_id: the id of the `task_type` in this cluster.

  Raises:
    ValueError: if `cluster_spec` fails any check.
  """
  allowed_task_types = ("chief", "worker", "evaluator", "ps", None)

  cluster_spec = normalize_cluster_spec(cluster_spec)

  has_disallowed_job = any(
      job not in allowed_task_types for job in cluster_spec.jobs)
  if has_disallowed_job:
    raise ValueError("Disallowed task type found in cluster spec. Allowed "
                     "types are {} and the cluster spec is {}.".format(
                         allowed_task_types, cluster_spec))

  if task_type not in allowed_task_types:
    raise ValueError(
        "Unrecognized task_type: {}, valid task types are: {}".format(
            task_type, allowed_task_types))

  if (task_type and task_type != "evaluator" and
      task_type not in cluster_spec.jobs):
    raise ValueError("`task_type` %r not found in cluster_spec." % task_type)

  # Chief and evaluator are singleton jobs.
  for singleton_job in ("chief", "evaluator"):
    if task_count(cluster_spec, singleton_job) > 1:
      raise ValueError("There must be at most one '%s' job." % singleton_job)

  # "evaluator" may be missing from `cluster_spec`, so only bound-check
  # `task_id` for jobs that are actually present.
  if task_type in cluster_spec.jobs and task_id >= task_count(
      cluster_spec, task_type):
    raise ValueError(
        "The `task_id` %d exceeds the maximum id of %s." % (task_id, task_type))
106
107
def is_chief(cluster_spec=None, task_type=None, task_id=None):
  """Returns whether the given task is chief in the cluster.

  Since there is at most one evaluator and the evaluator itself should be
  independent of the training cluster, the evaluator job is also a chief job on
  its own.

  If this is currently running under a `_WorkerContext` of distribute
  coordinator, the arguments can be omitted as the result is already available.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.
    task_type: the task type in the cluster.
    task_id: the task id in the cluster.

  Returns:
    a boolean indicating whether the given task is chief.

  Raises:
    ValueError: if `task_type` is not in the `cluster_spec` or `task_id` exceeds
      the maximum id of the `task_type`.
  """
  if has_worker_context():
    # The worker context has already computed chief-ness; trust it.
    return dc_context.get_current_worker_context().is_chief

  _validate_cluster_spec(cluster_spec, task_type, task_id)
  cluster_dict = normalize_cluster_spec(cluster_spec).as_dict()

  if task_type in ("chief", "evaluator"):
    return True

  # With no explicit "chief" job, worker 0 acts as chief. This is common in
  # CollectiveAllReduceStrategy.
  return (task_type == "worker" and task_id == 0 and
          "chief" not in cluster_dict)
146
147
def collective_leader(cluster_spec, task_type, task_id):
  """Return the job name for the leader of for collective ops.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.
    task_type: the task type in the cluster.
    task_id: the task id in the cluster.

  Returns:
    a string indicating the leader job name or empty string if no need to set
    leader job.
  """
  cluster_spec = normalize_cluster_spec(cluster_spec)

  # Local training needs no collective leader.
  if not cluster_spec.as_dict():
    return ""

  _validate_cluster_spec(cluster_spec, task_type, task_id)

  # The single evaluator runs independently, so no leader is needed.
  if task_type == "evaluator":
    return ""

  # Prefer the chief when one exists; otherwise worker 0 takes the role.
  if "chief" in cluster_spec.jobs:
    return "/job:chief/replica:0/task:0"

  assert "worker" in cluster_spec.jobs
  return "/job:worker/replica:0/task:0"
180
181
def coordination_leader(cluster_spec):
  """Return the task name of the coordination service leader.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.

  Returns:
    a string indicating the task name of the coordination service leader.
  """
  cluster_spec = normalize_cluster_spec(cluster_spec)

  # Local training needs no coordination service leader.
  if not cluster_spec.as_dict():
    return ""

  # Prefer the chief when one exists; otherwise worker 0 takes the role.
  if "chief" in cluster_spec.jobs:
    return "/job:chief/replica:0/task:0"

  assert "worker" in cluster_spec.jobs
  return "/job:worker/replica:0/task:0"
205
206
def worker_count(cluster_spec, task_type):
  """Returns the number of workers in the cluster."""
  _validate_cluster_spec(cluster_spec, task_type, task_id=0)
  cluster_dict = normalize_cluster_spec(cluster_spec).as_dict()

  # Jobs such as "ps" have no notion of worker count and shouldn't call this.
  if task_type not in ("chief", "worker", "evaluator"):
    raise ValueError("Unexpected `task_type` %r" % task_type)

  if task_type == "evaluator":
    # The "evaluator" is in its own cluster or its own partition of a cluster,
    # so "chief"/"worker" tasks are not counted towards it.
    return len(cluster_dict["evaluator"])

  # In the non-evaluator case the "chief" also acts as a worker, so the total
  # is the sum of both job types.
  chief_tasks = cluster_dict.get("chief", [])
  worker_tasks = cluster_dict.get("worker", [])
  return len(chief_tasks) + len(worker_tasks)
226
227
def id_in_cluster(cluster_spec, task_type, task_id):
  """Returns a unique id for the task in the `task_type`'s cluster.

  It returns an id ranging from [0, `worker_count(task_type, task_id)`).

  Note: this function assumes that "evaluate" job is in its own cluster or its
  own partition of a cluster.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
    task_type: string indicating the type of the task.
    task_id: the id of the `task_type` in this cluster.

  Returns:
    an int indicating the unique id.

  Raises:
    ValueError: if `task_type` is not "chief", "worker" or "evaluator".
  """
  _validate_cluster_spec(cluster_spec, task_type, task_id)
  cluster_dict = normalize_cluster_spec(cluster_spec).as_dict()

  if task_type == "chief":
    # There is at most one chief, and it always gets id 0.
    return 0

  if task_type == "worker":
    # Workers are numbered after the (optional) chief.
    return len(cluster_dict.get("chief", [])) + task_id

  if task_type == "evaluator":
    # The evaluator lives in its own cluster/partition, so its id is its own
    # task id.
    return task_id

  # We currently don't assign ids to other tasks.
  raise ValueError("There is no id for task_type %r" % task_type)
264
265
def should_save_checkpoint():
  """Returns whether the current worker should save checkpoints.

  In multi-worker training, if saving checkpoint is requested by user, or needed
  for fault-tolerance, the cluster should save checkpoint but not necessarily
  every worker in the cluster should.

  TODO(rchao): Consider generalizing this util to be `should_save_file` as there
  can be other files to save such as summary.

  Returns:
      Whether this particular worker in the cluster should save checkpoints.
  """
  worker_context = dc_context.get_current_worker_context()
  return worker_context.should_checkpoint
280
281
def should_load_checkpoint():
  """Returns whether the current worker should load checkpoints.

  In multi-worker training, if loading checkpoint is requested by user, or
  needed for fault-tolerance, the cluster should load checkpoint but not
  necessarily every worker in the cluster should.

  Returns:
      Whether this particular worker in the cluster should load checkpoints.
  """
  worker_context = dc_context.get_current_worker_context()
  return worker_context.experimental_should_init
293
294
def wait_for_other_workers():
  """Waits for other workers to reach the same call to this method."""
  worker_context = dc_context.get_current_worker_context()
  return worker_context.wait_for_other_workers()
298
299
def has_worker_context():
  """Returns whether a worker context has been entered."""
  worker_context = dc_context.get_current_worker_context()
  return worker_context is not None
303