1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities for multi-worker distribution strategies.""" 16 17from tensorflow.core.protobuf import cluster_pb2 18from tensorflow.python.distribute import distribute_coordinator_context as dc_context 19from tensorflow.python.training import server_lib 20 21 22def normalize_cluster_spec(cluster_spec): 23 """Makes `cluster_spec` into a `ClusterSpec` object. 24 25 Args: 26 cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the 27 cluster configurations. 28 29 Returns: 30 a `ClusterSpec` object. 31 32 Raises: 33 ValueError: if `cluster_spec` is not a dict or a `ClusterSpec` or a 34 `ClusterDef`. 35 """ 36 if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)): 37 return server_lib.ClusterSpec(cluster_spec) 38 elif not isinstance(cluster_spec, server_lib.ClusterSpec): 39 raise ValueError( 40 "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a " 41 "`tf.train.ClusterDef` object") 42 return cluster_spec 43 44 45def task_count(cluster_spec, task_type): 46 try: 47 return cluster_spec.num_tasks(task_type) 48 except ValueError: 49 return 0 50 51 52def _validate_cluster_spec(cluster_spec, 53 task_type, 54 task_id): 55 """Validates `cluster_spec`. 56 57 It checks: 58 1) task type is one of "chief", "worker", "ps", "evaluator", or not provided 59 (None). 60 2) whether there is such a task type as `task_type` in the `cluster_spec`. The 61 only exception is `evaluator`. In other words, it is still a valid 62 configuration when `task_type` is `evaluator` but it doesn't appear in 63 `cluster_spec`. This is to be compatible with `TF_CONFIG` in Estimator. 64 3) whether there is at most one "chief" job. 65 4) whether there is at most one "evaluator" job. 66 5) whether the `task_id` is smaller than the number of tasks for that 67 particular `task_type`. 68 69 Args: 70 cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated. 71 task_type: string indicating the type of the task. 72 task_id: the id of the `task_type` in this cluster. 73 74 Raises: 75 ValueError: if `cluster_spec` fails any check. 76 """ 77 allowed_task_types = ("chief", "worker", "evaluator", "ps", None) 78 79 cluster_spec = normalize_cluster_spec(cluster_spec) 80 81 if any(job not in allowed_task_types for job in cluster_spec.jobs): 82 raise ValueError("Disallowed task type found in cluster spec. Allowed " 83 "types are {} and the cluster spec is {}.".format( 84 allowed_task_types, cluster_spec)) 85 86 if task_type not in allowed_task_types: 87 raise ValueError( 88 "Unrecognized task_type: {}, valid task types are: {}".format( 89 task_type, allowed_task_types)) 90 91 if (task_type and task_type not in cluster_spec.jobs and 92 task_type != "evaluator"): 93 raise ValueError("`task_type` %r not found in cluster_spec." % task_type) 94 95 if task_count(cluster_spec, "chief") > 1: 96 raise ValueError("There must be at most one 'chief' job.") 97 98 if task_count(cluster_spec, "evaluator") > 1: 99 raise ValueError("There must be at most one 'evaluator' job.") 100 101 # The `evaluator` job is allowed to be missing in `cluster_spec`. 102 if task_type in cluster_spec.jobs and task_id >= task_count( 103 cluster_spec, task_type): 104 raise ValueError( 105 "The `task_id` %d exceeds the maximum id of %s." % (task_id, task_type)) 106 107 108def is_chief(cluster_spec=None, task_type=None, task_id=None): 109 """Returns whether the given task is chief in the cluster. 110 111 Since there is at most one evaluator and the evaluator itself should be 112 independent of the training cluster, the evaluator job is also a chief job on 113 its own. 114 115 If this is currently running under a `_WorkerContext` of distribute 116 coordinator, the arguments can be omitted as the result is already available. 117 118 Args: 119 cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the 120 cluster configurations. 121 task_type: the task type in the cluster. 122 task_id: the task id in the cluster. 123 124 Returns: 125 a boolean indicating whether the given task is chief. 126 127 Raises: 128 ValueError: if `task_type` is not in the `cluster_spec` or `task_id` exceeds 129 the maximum id of the `task_type`. 130 """ 131 if has_worker_context(): 132 # If a worker context exists, use the value provided by it. 133 return dc_context.get_current_worker_context().is_chief 134 135 _validate_cluster_spec(cluster_spec, task_type, task_id) 136 cluster_spec = normalize_cluster_spec(cluster_spec).as_dict() 137 138 if task_type == "chief" or task_type == "evaluator": 139 return True 140 141 # If chief not in the cluster_spec, use the first worker as chief. This is 142 # common in CollectiveAllReduceStrategy. 143 if ("chief" not in cluster_spec and task_type == "worker" and task_id == 0): 144 return True 145 return False 146 147 148def collective_leader(cluster_spec, task_type, task_id): 149 """Return the job name for the leader of for collective ops. 150 151 Args: 152 cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the 153 cluster configurations. 154 task_type: the task type in the cluster. 155 task_id: the task id in the cluster. 156 157 Returns: 158 a string indicating the leader job name or empty string if no need to set 159 leader job. 160 """ 161 cluster_spec = normalize_cluster_spec(cluster_spec) 162 163 # No need to set collective leader for local. 164 if not cluster_spec.as_dict(): 165 return "" 166 167 _validate_cluster_spec(cluster_spec, task_type, task_id) 168 169 # Only one evaluator, so no need to set collective leader. 170 if task_type == "evaluator": 171 return "" 172 173 # Use chief if chief is in the cluster. 174 if "chief" in cluster_spec.jobs: 175 return "/job:chief/replica:0/task:0" 176 177 # Use worker 0 if no chief job. 178 assert "worker" in cluster_spec.jobs 179 return "/job:worker/replica:0/task:0" 180 181 182def coordination_leader(cluster_spec): 183 """Return the task name of the coordination service leader. 184 185 Args: 186 cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object sxpecifying the 187 cluster configurations. 188 189 Returns: 190 a string indicating the task name of the coordination service leader. 191 """ 192 cluster_spec = normalize_cluster_spec(cluster_spec) 193 194 # No need to set coordination service leader for local. 195 if not cluster_spec.as_dict(): 196 return "" 197 198 # Use chief if chief is in the cluster. 199 if "chief" in cluster_spec.jobs: 200 return "/job:chief/replica:0/task:0" 201 202 # Use worker 0 if no chief job. 203 assert "worker" in cluster_spec.jobs 204 return "/job:worker/replica:0/task:0" 205 206 207def worker_count(cluster_spec, task_type): 208 """Returns the number of workers in the cluster.""" 209 _validate_cluster_spec(cluster_spec, task_type, task_id=0) 210 cluster_spec = normalize_cluster_spec(cluster_spec).as_dict() 211 212 # Other jobs such as "ps" shouldn't call this function. 213 if task_type not in ["chief", "worker", "evaluator"]: 214 raise ValueError("Unexpected `task_type` %r" % task_type) 215 216 if task_type == "evaluator": 217 # The "evaluator" is in its own cluster or its own partition of a cluster. 218 # So we don't have to count "chief" or "worker" if the current task is an 219 # "evaluator". 220 return len(cluster_spec["evaluator"]) 221 else: 222 # In the non-evaluator case, we return the total number of "chief" and 223 # "worker" tasks as the "chief" is also a worker. 224 return (len(cluster_spec.get("chief", [])) + len( 225 cluster_spec.get("worker", []))) 226 227 228def id_in_cluster(cluster_spec, task_type, task_id): 229 """Returns a unique id for the task in the `task_type`'s cluster. 230 231 It returns an id ranging from [0, `worker_count(task_type, task_id)`). 232 233 Note: this function assumes that "evaluate" job is in its own cluster or its 234 own partition of a cluster. 235 236 Args: 237 cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated. 238 task_type: string indicating the type of the task. 239 task_id: the id of the `task_type` in this cluster. 240 241 Returns: 242 an int indicating the unique id. 243 244 Throws: 245 ValueError: if `task_type` is not "chief", "worker" or "evaluator". 246 """ 247 _validate_cluster_spec(cluster_spec, task_type, task_id) 248 cluster_spec = normalize_cluster_spec(cluster_spec).as_dict() 249 250 # The "chief" job has always id 0 and there is at most one and "worker" jobs 251 # come after it. 252 if task_type == "chief": 253 return 0 254 255 if task_type == "worker": 256 return task_id + len(cluster_spec.get("chief", [])) 257 258 # The "evaluator" is in its own cluster or its own partition of a cluster. 259 if task_type == "evaluator": 260 return task_id 261 262 # We currently don't assign ids to other tasks. 263 raise ValueError("There is no id for task_type %r" % task_type) 264 265 266def should_save_checkpoint(): 267 """Returns whether the current worker should save checkpoints. 268 269 In multi-worker training, if saving checkpoint is requested by user, or needed 270 for fault-tolerance, the cluster should save checkpoint but not necessarily 271 every worker in the cluster should. 272 273 TODO(rchao): Consider generalizing this util to be `should_save_file` as there 274 can be other files to save such as summary. 275 276 Returns: 277 Whether this particular worker in the cluster should save checkpoints. 278 """ 279 return dc_context.get_current_worker_context().should_checkpoint 280 281 282def should_load_checkpoint(): 283 """Returns whether the current worker should load checkpoints. 284 285 In multi-worker training, if loading checkpoint is requested by user, or 286 needed for fault-tolerance, the cluster should load checkpoint but not 287 necessarily every worker in the cluster should. 288 289 Returns: 290 Whether this particular worker in the cluster should load checkpoints. 291 """ 292 return dc_context.get_current_worker_context().experimental_should_init 293 294 295def wait_for_other_workers(): 296 """Waits for other workers to reach the same call to this method.""" 297 return dc_context.get_current_worker_context().wait_for_other_workers() 298 299 300def has_worker_context(): 301 """Returns whether a worker context has been entered.""" 302 return dc_context.get_current_worker_context() is not None 303