# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Device function for replicated training."""
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import device as pydev
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util.tf_export import tf_export

# This is a tuple of PS ops used by tf.estimator.Estimator which should work in
# almost all cases.
STANDARD_PS_OPS = ("Variable", "VariableV2", "AutoReloadVariable",
                   "MutableHashTable", "MutableHashTableV2",
                   "MutableHashTableOfTensors", "MutableHashTableOfTensorsV2",
                   "MutableDenseHashTable", "MutableDenseHashTableV2",
                   "VarHandleOp", "BoostedTreesEnsembleResourceHandleOp",
                   "BoostedTreesQuantileStreamResourceHandleOp",
                   "ResourceConditionalAccumulator",
                   "DecisionTreeResource")


class _RoundRobinStrategy:
  """Returns the next ps task index for placement in round-robin order.

  This class is not to be used directly by users. See instead
  `replica_device_setter()` below.
  """

  def __init__(self, num_tasks):
    """Create a new `_RoundRobinStrategy`.

    Args:
      num_tasks: Number of ps tasks to cycle among.
    """
    self._num_tasks = num_tasks
    self._next_task = 0

  def __call__(self, unused_op):
    """Choose a ps task index for the given `Operation`.

    Args:
      unused_op: An `Operation` to be placed on ps.

    Returns:
      The next ps task index to use for the `Operation`, in the range
      `[0, num_tasks)`.
    """
    task = self._next_task
    self._next_task = (self._next_task + 1) % self._num_tasks
    return task
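

# For illustration only (not used by the library): a `_RoundRobinStrategy`
# over three ps tasks ignores the op it is given and simply cycles through
# task indices. A hypothetical snippet, assuming any op-like argument:
#
#   strategy = _RoundRobinStrategy(num_tasks=3)
#   [strategy(None) for _ in range(5)]  # -> [0, 1, 2, 0, 1]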


class _ReplicaDeviceChooser:
  """Class to choose devices for Ops in a replicated training setup.

  This class is not to be used directly by users. See instead
  `replica_device_setter()` below.
  """

  def __init__(self, ps_tasks, ps_device, worker_device, merge_devices, ps_ops,
               ps_strategy):
    """Create a new `_ReplicaDeviceChooser`.

    Args:
      ps_tasks: Number of tasks in the `ps` job.
      ps_device: String. Name of the `ps` job.
      worker_device: String. Name of the `worker` job.
      merge_devices: Boolean. Set to True to allow merging of device specs.
      ps_ops: List of strings representing `Operation` types that need to be
        placed on `ps` devices.
      ps_strategy: A callable invoked for every ps `Operation` (i.e. matched by
        `ps_ops`), that takes the `Operation` and returns the ps task index to
        use.
    """
    self._ps_tasks = ps_tasks
    self._ps_device = ps_device
    self._worker_device = worker_device
    self._merge_devices = merge_devices
    self._ps_ops = ps_ops
    self._ps_strategy = ps_strategy

  def device_function(self, op):
    """Choose a device for `op`.

    Args:
      op: an `Operation`.

    Returns:
      The device to use for the `Operation`.
    """
    # If we don't return early here, either merge_devices is True, or op.device
    # is empty (in which case merging is a no-op). So we can always merge below.
    if not self._merge_devices and op.device:
      return op.device

    current_device = pydev.DeviceSpec.from_string(op.device or "")

    # The ps_device will be used for the specified ops (ps_ops) whenever it is
    # present and ps_tasks is non-zero. However, its task number will only be
    # set (using ps_strategy) if there is a job field in ps_device that won't be
    # changed by the job field (if present) in current_device.
    node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
    if self._ps_tasks and self._ps_device and node_def.op in self._ps_ops:
      ps_device = pydev.DeviceSpec.from_string(self._ps_device)

      current_job, ps_job = current_device.job, ps_device.job
      if ps_job and (not current_job or current_job == ps_job):
        ps_device = ps_device.replace(task=self._ps_strategy(op))

      ps_device = ps_device.make_merged_spec(current_device)
      return ps_device.to_string()

    worker_device = pydev.DeviceSpec.from_string(self._worker_device or "")
    worker_device = worker_device.make_merged_spec(current_device)
    return worker_device.to_string()
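

# For illustration only: in `device_function` above, `make_merged_spec` lets
# constraints already present on the op's current device win, and the ps or
# worker spec fills in only the fields left unset. A hypothetical snippet:
#
#   ps = pydev.DeviceSpec.from_string("/job:ps/task:1")
#   cur = pydev.DeviceSpec.from_string("/device:CPU:0")
#   ps.make_merged_spec(cur).to_string()  # -> "/job:ps/task:1/device:CPU:0"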


@tf_export(v1=["train.replica_device_setter"])
def replica_device_setter(ps_tasks=0,
                          ps_device="/job:ps",
                          worker_device="/job:worker",
                          merge_devices=True,
                          cluster=None,
                          ps_ops=None,
                          ps_strategy=None):
  """Return a `device function` to use when building a Graph for replicas.

  Device functions are used in a `with tf.device(device_function):` statement
  to automatically assign devices to `Operation` objects as they are
  constructed. Device constraints are added from the innermost context first,
  working outwards. The merging behavior adds constraints to fields that are
  not yet set by an inner context. Currently the fields are (job, task,
  cpu/gpu).

  If `cluster` is `None`, and `ps_tasks` is 0, the returned function is a
  no-op. Otherwise, the value of `ps_tasks` is derived from `cluster`.

  By default, only Variable ops are placed on ps tasks, and the placement
  strategy is round-robin over all ps tasks. A custom `ps_strategy` may be used
  to do more intelligent placement, such as
  `tf.contrib.training.GreedyLoadBalancingStrategy`.

  For example,

  ```python
  # To build a cluster with two ps tasks on hosts ps0 and ps1, and three
  # worker tasks on hosts worker0, worker1 and worker2.
  cluster_spec = {
      "ps": ["ps0:2222", "ps1:2222"],
      "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]}
  with tf.compat.v1.device(
      tf.compat.v1.train.replica_device_setter(cluster=cluster_spec)):
    # Build your graph
    v1 = tf.Variable(...)  # assigned to /job:ps/task:0
    v2 = tf.Variable(...)  # assigned to /job:ps/task:1
    v3 = tf.Variable(...)  # assigned to /job:ps/task:0
    # Run compute
  ```

  Args:
    ps_tasks: Number of tasks in the `ps` job. Ignored if `cluster` is
      provided.
    ps_device: String. Device of the `ps` job. If empty, no `ps` job is used.
      Defaults to `/job:ps`.
    worker_device: String. Device of the `worker` job. If empty, no `worker`
      job is used.
    merge_devices: `Boolean`. If `True`, merges device specifications rather
      than overriding them: fields already set on an op's device are kept, and
      only unset fields are filled in. If `False`, ops that already have a
      device assigned are left unchanged.
    cluster: `ClusterDef` proto or `ClusterSpec`.
    ps_ops: List of strings representing `Operation` types that need to be
      placed on `ps` devices. If `None`, defaults to `STANDARD_PS_OPS`.
    ps_strategy: A callable invoked for every ps `Operation` (i.e. matched by
      `ps_ops`), that takes the `Operation` and returns the ps task index to
      use. If `None`, defaults to a round-robin strategy across all `ps`
      devices.

  Returns:
    A function to pass to `tf.device()`.

  Raises:
    TypeError: If `cluster` is not a dictionary or `ClusterDef` protocol
      buffer, or if `ps_strategy` is provided but not callable.
  """
  if cluster is not None:
    if isinstance(cluster, server_lib.ClusterSpec):
      cluster_spec = cluster.as_dict()
    else:
      cluster_spec = server_lib.ClusterSpec(cluster).as_dict()
    # Get ps_job_name from ps_device by stripping "/job:".
    ps_job_name = pydev.DeviceSpec.from_string(ps_device).job
    if ps_job_name not in cluster_spec or cluster_spec[ps_job_name] is None:
      return None
    ps_tasks = len(cluster_spec[ps_job_name])

  if ps_tasks == 0:
    return None

  if ps_ops is None:
    # TODO(sherrym): Variables in the LOCAL_VARIABLES collection should not be
    # placed on the parameter server.
    ps_ops = list(STANDARD_PS_OPS)

  if not merge_devices:
    logging.warning(
        "DEPRECATION: It is recommended to set merge_devices=true in "
        "replica_device_setter")
  if ps_strategy is None:
    ps_strategy = _RoundRobinStrategy(ps_tasks)
  if not callable(ps_strategy):
    raise TypeError("ps_strategy must be callable")
  chooser = _ReplicaDeviceChooser(ps_tasks, ps_device, worker_device,
                                  merge_devices, ps_ops, ps_strategy)
  return chooser.device_function
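

# A minimal usage sketch (illustrative only, not part of the library API).
# It assumes a TF1-style graph built through `tensorflow.compat.v1` and shows
# variables being spread across the two ps tasks in round-robin order.
if __name__ == "__main__":
  import tensorflow.compat.v1 as tf  # assumed import path for the sketch

  demo_cluster_spec = {
      "ps": ["ps0:2222", "ps1:2222"],
      "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
  }
  with tf.Graph().as_default():
    with tf.device(tf.train.replica_device_setter(cluster=demo_cluster_spec)):
      v1 = tf.Variable(1.0, name="v1")  # expected: /job:ps/task:0
      v2 = tf.Variable(2.0, name="v2")  # expected: /job:ps/task:1
      v3 = tf.Variable(3.0, name="v3")  # expected: /job:ps/task:0
      for v in (v1, v2, v3):
        print(v.op.name, "->", v.device)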