1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Implementation of ClusterResolvers for GCE instance groups."""
16
17from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
18from tensorflow.python.training.server_lib import ClusterSpec
19from tensorflow.python.util.tf_export import tf_export
20
21
22_GOOGLE_API_CLIENT_INSTALLED = True
23try:
24  from googleapiclient import discovery  # pylint: disable=g-import-not-at-top
25  from oauth2client.client import GoogleCredentials  # pylint: disable=g-import-not-at-top
26except ImportError:
27  _GOOGLE_API_CLIENT_INSTALLED = False
28
29
@tf_export('distribute.cluster_resolver.GCEClusterResolver')
class GCEClusterResolver(ClusterResolver):
  """ClusterResolver for Google Compute Engine.

  This is an implementation of cluster resolvers for the Google Compute Engine
  instance group platform. By specifying a project, zone, and instance group,
  this will retrieve the IP address of all the instances within the instance
  group and return a ClusterResolver object suitable for use for distributed
  TensorFlow.

  Note: this cluster resolver cannot retrieve `task_type`, `task_id` or
  `rpc_layer`. To use it with some distribution strategies like
  `tf.distribute.experimental.MultiWorkerMirroredStrategy`, you will need to
  specify `task_type` and `task_id` in the constructor.

  Usage example with tf.distribute.Strategy:

    ```Python
    # On worker 0
    cluster_resolver = GCEClusterResolver("my-project", "us-west1",
                                          "my-instance-group",
                                          task_type="worker", task_id=0)
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
        cluster_resolver=cluster_resolver)

    # On worker 1
    cluster_resolver = GCEClusterResolver("my-project", "us-west1",
                                          "my-instance-group",
                                          task_type="worker", task_id=1)
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
        cluster_resolver=cluster_resolver)
    ```
  """

  def __init__(self,
               project,
               zone,
               instance_group,
               port=8470,
               task_type='worker',
               task_id=0,
               rpc_layer='grpc',
               credentials='default',
               service=None):
    """Creates a new GCEClusterResolver object.

    This takes in a few parameters and creates a GCEClusterResolver object. It
    will then use these parameters to query the GCE API for the IP addresses of
    each instance in the instance group.

    Args:
      project: Name of the GCE project.
      zone: Zone of the GCE instance group.
      instance_group: Name of the GCE instance group.
      port: Port of the listening TensorFlow server (default: 8470)
      task_type: Name of the TensorFlow job this GCE instance group of VM
        instances belong to.
      task_id: The task index for this particular VM, within the GCE
        instance group. In particular, every single instance should be assigned
        a unique ordinal index within an instance group manually so that they
        can be distinguished from each other.
      rpc_layer: The RPC layer TensorFlow should use to communicate across
        instances.
      credentials: GCE Credentials. If nothing is specified, this defaults to
        GoogleCredentials.get_application_default().
      service: The GCE API object returned by the googleapiclient.discovery
        function. (Default: discovery.build('compute', 'v1')). If you specify a
        custom service object, then the credentials parameter will be ignored.

    Raises:
      ImportError: If the googleapiclient is not installed.
    """
    self._project = project
    self._zone = zone
    self._instance_group = instance_group
    self._task_type = task_type
    self._task_id = task_id
    self._rpc_layer = rpc_layer
    self._port = port
    self._credentials = credentials

    # The 'default' sentinel is only swapped for real application-default
    # credentials when the optional Google API client imported successfully;
    # otherwise the sentinel string is kept as-is.
    if credentials == 'default':
      if _GOOGLE_API_CLIENT_INSTALLED:
        self._credentials = GoogleCredentials.get_application_default()

    if service is None:
      # Building our own service object needs the optional dependency.
      if not _GOOGLE_API_CLIENT_INSTALLED:
        raise ImportError('googleapiclient must be installed before using the '
                          'GCE cluster resolver')
      self._service = discovery.build(
          'compute', 'v1',
          credentials=self._credentials)
    else:
      self._service = service

  def cluster_spec(self):
    """Returns a ClusterSpec object based on the latest instance group info.

    This returns a ClusterSpec object for use based on information from the
    specified instance group. We will retrieve the information from the GCE APIs
    every time this method is called.

    Returns:
      A ClusterSpec containing host information retrieved from GCE. The single
      job is named after this resolver's task type and lists one
      `<ip>:<port>` address per running instance, sorted as strings. The list
      is empty when the instance group has no running instances.
    """
    request_body = {'instanceState': 'RUNNING'}
    request = self._service.instanceGroups().listInstances(
        project=self._project,
        zone=self._zone,
        instanceGroups=self._instance_group,
        body=request_body,
        orderBy='name')

    worker_list = []

    # Walk every result page; `listInstances_next` yields None once exhausted.
    while request is not None:
      response = request.execute()

      # A page may carry no 'items' key when nothing matches the filter (for
      # example an instance group with no RUNNING instances); treat that as an
      # empty page rather than failing with a KeyError.
      for instance in response.get('items', []):
        # The 'instance' field is a resource URL; its last path segment is the
        # instance name expected by `instances().get`.
        instance_name = instance['instance'].split('/')[-1]

        instance_request = self._service.instances().get(
            project=self._project,
            zone=self._zone,
            instance=instance_name)

        if instance_request is not None:
          instance_details = instance_request.execute()
          # Use the internal IP of the first network interface.
          ip_address = instance_details['networkInterfaces'][0]['networkIP']
          instance_url = '%s:%s' % (ip_address, self._port)
          worker_list.append(instance_url)

      request = self._service.instanceGroups().listInstances_next(
          previous_request=request,
          previous_response=response)

    # Sort so the task ordering does not depend on API response order.
    worker_list.sort()
    return ClusterSpec({self._task_type: worker_list})

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Returns the address of the master for the given (or default) task.

    Args:
      task_type: Job name to look up. Falls back to the task type given at
        construction when None.
      task_id: Task index to look up. Falls back to this resolver's task id
        when None.
      rpc_layer: RPC protocol prefix to use. Falls back to this resolver's
        `rpc_layer` when empty or None.

    Returns:
      The task address from a freshly fetched cluster spec, prefixed with
      `<rpc_layer>://` when an RPC layer is available, or an empty string when
      either the task type or the task id resolves to None.
    """
    task_type = task_type if task_type is not None else self._task_type
    task_id = task_id if task_id is not None else self._task_id

    if task_type is not None and task_id is not None:
      master = self.cluster_spec().task_address(task_type, task_id)
      if rpc_layer or self._rpc_layer:
        return '%s://%s' % (rpc_layer or self._rpc_layer, master)
      else:
        return master

    return ''

  @property
  def task_type(self):
    """The TensorFlow job name supplied at construction."""
    return self._task_type

  @property
  def task_id(self):
    """The task index of this VM within the instance group."""
    return self._task_id

  @task_type.setter
  def task_type(self, task_type):
    """Always raises: the task type is fixed once the resolver is created."""
    raise RuntimeError(
        'You cannot reset the task_type of the GCEClusterResolver after it has '
        'been created.')

  @task_id.setter
  def task_id(self, task_id):
    """Overrides the task index used as the default by `master`."""
    self._task_id = task_id

  @property
  def rpc_layer(self):
    """The RPC layer prefix used when building the master address."""
    return self._rpc_layer

  @rpc_layer.setter
  def rpc_layer(self, rpc_layer):
    """Overrides the RPC layer prefix used by `master`."""
    self._rpc_layer = rpc_layer
208