# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of ClusterResolvers for GCE instance groups."""

from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
from tensorflow.python.training.server_lib import ClusterSpec
from tensorflow.python.util.tf_export import tf_export


_GOOGLE_API_CLIENT_INSTALLED = True
try:
  from googleapiclient import discovery  # pylint: disable=g-import-not-at-top
  from oauth2client.client import GoogleCredentials  # pylint: disable=g-import-not-at-top
except ImportError:
  _GOOGLE_API_CLIENT_INSTALLED = False


@tf_export('distribute.cluster_resolver.GCEClusterResolver')
class GCEClusterResolver(ClusterResolver):
  """ClusterResolver for Google Compute Engine.

  This is an implementation of cluster resolvers for the Google Compute Engine
  instance group platform. By specifying a project, zone, and instance group,
  this will retrieve the IP address of all the instances within the instance
  group and expose them as a ClusterSpec suitable for use for distributed
  TensorFlow.

  Note: this cluster resolver cannot retrieve `task_type`, `task_id` or
  `rpc_layer`. To use it with some distribution strategies like
  `tf.distribute.experimental.MultiWorkerMirroredStrategy`, you will need to
  specify `task_type` and `task_id` in the constructor.

  Usage example with tf.distribute.Strategy:

    ```Python
    # On worker 0
    cluster_resolver = GCEClusterResolver("my-project", "us-west1-a",
                                          "my-instance-group",
                                          task_type="worker", task_id=0)
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
        cluster_resolver=cluster_resolver)

    # On worker 1
    cluster_resolver = GCEClusterResolver("my-project", "us-west1-a",
                                          "my-instance-group",
                                          task_type="worker", task_id=1)
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
        cluster_resolver=cluster_resolver)
    ```
  """

  def __init__(self,
               project,
               zone,
               instance_group,
               port=8470,
               task_type='worker',
               task_id=0,
               rpc_layer='grpc',
               credentials='default',
               service=None):
    """Creates a new GCEClusterResolver object.

    This takes in a few parameters and creates a GCEClusterResolver object. It
    will then use these parameters to query the GCE API for the IP addresses of
    each instance in the instance group.

    Args:
      project: Name of the GCE project.
      zone: Zone of the GCE instance group.
      instance_group: Name of the GCE instance group.
      port: Port of the listening TensorFlow server (default: 8470)
      task_type: Name of the TensorFlow job this GCE instance group of VM
        instances belong to.
      task_id: The task index for this particular VM, within the GCE
        instance group. In particular, every single instance should be assigned
        a unique ordinal index within an instance group manually so that they
        can be distinguished from each other.
      rpc_layer: The RPC layer TensorFlow should use to communicate across
        instances.
      credentials: GCE Credentials. If nothing is specified, this defaults to
        GoogleCredentials.get_application_default().
      service: The GCE API object returned by the googleapiclient.discovery
        function. (Default: discovery.build('compute', 'v1')). If you specify a
        custom service object, then the credentials parameter will be ignored.

    Raises:
      ImportError: If the googleapiclient is not installed and no `service`
        object was supplied.
    """
    self._project = project
    self._zone = zone
    self._instance_group = instance_group
    self._task_type = task_type
    self._task_id = task_id
    self._rpc_layer = rpc_layer
    self._port = port
    self._credentials = credentials

    # The 'default' sentinel is swapped for application-default credentials
    # only when the Google API client imported successfully; otherwise the
    # sentinel is kept and the ImportError below covers the no-service case.
    if credentials == 'default' and _GOOGLE_API_CLIENT_INSTALLED:
      self._credentials = GoogleCredentials.get_application_default()

    if service is not None:
      # A caller-supplied service object is used as-is.
      self._service = service
      return

    if not _GOOGLE_API_CLIENT_INSTALLED:
      raise ImportError('googleapiclient must be installed before using the '
                        'GCE cluster resolver')
    self._service = discovery.build(
        'compute', 'v1',
        credentials=self._credentials)

  def cluster_spec(self):
    """Returns a ClusterSpec object based on the latest instance group info.

    This returns a ClusterSpec object for use based on information from the
    specified instance group. We will retrieve the information from the GCE APIs
    every time this method is called.

    The resulting task list is sorted by its "ip:port" strings, so a task index
    refers to a position in that sorted order.

    Returns:
      A ClusterSpec containing host information retrieved from GCE. The single
      job is named after `task_type`; its task list is empty when no instance
      is returned by the API.
    """
    instance_groups = self._service.instanceGroups()
    request = instance_groups.listInstances(
        project=self._project,
        zone=self._zone,
        instanceGroups=self._instance_group,
        body={'instanceState': 'RUNNING'},
        orderBy='name')

    worker_list = []

    # Follow pagination until listInstances_next reports no further page.
    while request is not None:
      response = request.execute()

      # The response may carry no 'items' key when nothing matches the filter;
      # treat that as an empty page instead of failing with a KeyError.
      for instance in response.get('items', []):
        # 'instance' holds a '/'-separated resource path; the last path
        # component is used as the instance name.
        instance_name = instance['instance'].split('/')[-1]

        instance_request = self._service.instances().get(
            project=self._project,
            zone=self._zone,
            instance=instance_name)

        if instance_request is not None:
          instance_details = instance_request.execute()
          # Only the first network interface's internal IP is used.
          ip_address = instance_details['networkInterfaces'][0]['networkIP']
          worker_list.append('%s:%s' % (ip_address, self._port))

      request = instance_groups.listInstances_next(
          previous_request=request,
          previous_response=response)

    worker_list.sort()
    return ClusterSpec({self._task_type: worker_list})

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Returns the address of a task in the cluster, queried from GCE.

    Every call rebuilds the cluster spec via `cluster_spec()`.

    Args:
      task_type: Job name to look up. Falls back to the resolver's own
        `task_type` when None.
      task_id: Task index to look up. Falls back to the resolver's own
        `task_id` when None.
      rpc_layer: RPC layer to prefix the address with. Falls back to the
        resolver's own `rpc_layer` when not given (or falsy).

    Returns:
      The task address as "<rpc_layer>://<ip>:<port>", or just "<ip>:<port>"
      when no RPC layer is configured. An empty string is returned when either
      the task type or the task id is still None after applying the fallbacks.
    """
    task_type = task_type if task_type is not None else self._task_type
    task_id = task_id if task_id is not None else self._task_id

    if task_type is None or task_id is None:
      return ''

    address = self.cluster_spec().task_address(task_type, task_id)
    effective_rpc_layer = rpc_layer or self._rpc_layer
    if effective_rpc_layer:
      return '%s://%s' % (effective_rpc_layer, address)
    return address

  @property
  def task_type(self):
    """The job name given at construction time; it cannot be reassigned."""
    return self._task_type

  @property
  def task_id(self):
    """The task index of this resolver within its job."""
    return self._task_id

  @task_type.setter
  def task_type(self, task_type):
    # The job name is also the key of the ClusterSpec built by cluster_spec(),
    # so reassignment is rejected.
    raise RuntimeError(
        'You cannot reset the task_type of the GCEClusterResolver after it has '
        'been created.')

  @task_id.setter
  def task_id(self, task_id):
    self._task_id = task_id

  @property
  def rpc_layer(self):
    """The RPC layer used to prefix addresses returned by `master()`."""
    return self._rpc_layer

  @rpc_layer.setter
  def rpc_layer(self, rpc_layer):
    self._rpc_layer = rpc_layer