xref: /aosp_15_r20/external/tensorflow/tensorflow/python/tpu/client/client.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Cloud TPU Client."""

from concurrent import futures
import datetime
import json
import logging
import os
import time
import urllib

from absl import flags

_GOOGLE_API_CLIENT_INSTALLED = True
try:
  from googleapiclient import discovery  # pylint: disable=g-import-not-at-top
  from oauth2client import client  # pylint: disable=g-import-not-at-top
except ImportError:
  _GOOGLE_API_CLIENT_INSTALLED = False

FLAGS = flags.FLAGS

flags.DEFINE_bool('runtime_oom_exit', True,
                  'Exit the script when the TPU runtime is OOM.')
flags.DEFINE_bool('hbm_oom_exit', True,
                  'Exit the script when the TPU HBM is OOM.')

_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
_DEFAULT_TPUCONFIG_VARIABLE = 'TPU_CONFIG'
_ENDPOINTS_SEPARATOR = ','
_DEFAULT_ENV_VARIABLE = 'TPU_NAME'
_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL'
_GCE_METADATA_URL_ENV_VARIABLE = 'GCE_METADATA_IP'
_DEFAULT_ENDPOINT_PORT = '8470'
_OOM_EVENT_COOL_TIME_SEC = 90
_VERSION_SWITCHER_ENDPOINT = 'http://{}:8475/requestversion'
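# Port 8475 on each TPU worker serves the runtime version switcher used by
# Client.runtime_version() and Client.configure_tpu_version() below.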


def _utcnow():
  """A wrapper function around datetime.datetime.utcnow.

  This function exists for unit-testing purposes; it is not easy to use
  StubOutWithMock with the datetime.datetime class directly.

  Returns:
    datetime.datetime: the current UTC time.
  """
  return datetime.datetime.utcnow()


def _environment_discovery_url():
  return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE)


def _gce_metadata_endpoint():
  return 'http://' + os.environ.get(_GCE_METADATA_URL_ENV_VARIABLE,
                                    'metadata.google.internal')


def _request_compute_metadata(path):
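  # For example (illustrative, derived from the URL format below), the path
  # 'project/project-id' is fetched from
  # http://metadata.google.internal/computeMetadata/v1/project/project-id,
  # or from the host named by GCE_METADATA_IP when that variable is set.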
  req = urllib.request.Request(
      '%s/computeMetadata/v1/%s' % (_gce_metadata_endpoint(), path),
      headers={'Metadata-Flavor': 'Google'})
  resp = urllib.request.urlopen(req)
  return _as_text(resp.read())


def _environment_var_to_network_endpoints(endpoints):
  """Yields a dict with an IP address and port for each endpoint."""
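  # Illustrative only (a hypothetical endpoints string, not read from this
  # module): 'grpc://10.1.2.3:8470,10.1.2.4' yields
  # {'ipAddress': '10.1.2.3', 'port': '8470'} and
  # {'ipAddress': '10.1.2.4', 'port': '8470'}.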
  for endpoint in endpoints.split(','):
    grpc_prefix = 'grpc://'
    if endpoint.startswith(grpc_prefix):
      endpoint = endpoint.split(grpc_prefix)[1]
    parts = endpoint.split(':')
    ip_address = parts[0]
    port = _DEFAULT_ENDPOINT_PORT
    if len(parts) > 1:
      port = parts[1]
    yield {
        'ipAddress': ip_address,
        'port': port
    }


def _get_tpu_node_config():
  tpu_config_env = os.environ.get(_DEFAULT_TPUCONFIG_VARIABLE)
  if tpu_config_env:
    return json.loads(tpu_config_env)
  return None


def _get_tpu_name(tpu):
  if tpu:
    return tpu

  for e in [_GKE_ENV_VARIABLE, _DEFAULT_ENV_VARIABLE]:
    if e in os.environ:
      return os.environ[e]
  return None


def _as_text(s):
  if isinstance(s, bytes):
    return s.decode('utf-8')
  return s


class Client:
  """Client for working with the Cloud TPU API.

  This client is intended to be used for resolving a TPU name to its IP
  addresses.

  It's recommended to use this class as a context manager to utilize all
  functionality.
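
  Example usage (a minimal sketch; the TPU name, zone, and project below are
  placeholders, not values defined in this module):

    c = Client(tpu='my-tpu', zone='us-central1-b', project='my-project')
    if c.api_available():
      print(c.state(), c.health())
    print(c.network_endpoints())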
  """

  def __init__(self,
               tpu=None,
               zone=None,
               project=None,
               credentials='default',
               service=None,
               discovery_url=None):
    if isinstance(tpu, list):
      if not tpu:
        raise ValueError('At least one TPU must be specified.')
      if len(tpu) != 1:
        raise NotImplementedError(
            'Using multiple TPUs in a single session is not yet implemented')
      tpu = tpu[0]

    tpu = _get_tpu_name(tpu)

    if tpu is None:
      tpu_node_config = _get_tpu_node_config()
      if tpu_node_config:
        tpu = tpu_node_config.get('tpu_node_name')
        project = project or tpu_node_config.get('project')
        zone = zone or tpu_node_config.get('zone')
      else:
        raise ValueError('Please provide a TPU Name to connect to.')

    self._tpu = _as_text(tpu)

    self._use_api = not self._tpu.startswith('grpc://')
    self._service = service

    self._credentials = None
    self._project = None
    self._zone = None
    self._discovery_url = None
    if self._use_api:
      if credentials != 'default':
        self._credentials = credentials
      # Automatically detect project and zone if unspecified.
      if project:
        self._project = _as_text(project)
      else:
        self._project = _request_compute_metadata('project/project-id')
      if zone:
        self._zone = _as_text(zone)
      else:
        zone_path = _request_compute_metadata('instance/zone')
        self._zone = zone_path.split('/')[-1]
      self._discovery_url = _environment_discovery_url() or discovery_url

  def _symptom_msg(self, msg):
    """Return the structured Symptom message."""
    return 'Symptom: ' + msg

  def _oom_event(self, symptoms):
    """Check if a runtime OOM event is reported."""
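    # Each symptom is expected to look roughly like
    # {'symptomType': 'OUT_OF_MEMORY',
    #  'createTime': '2020-01-01T00:00:00.123Z'}
    # (a shape inferred from the fields read below, not a documented schema).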
    if not symptoms:
      return False
    for symptom in reversed(symptoms):
      if symptom['symptomType'] != 'OUT_OF_MEMORY':
        continue
      oom_datetime_str = symptom['createTime'].split('.')[0]
      oom_datetime = datetime.datetime.strptime(oom_datetime_str,
                                                '%Y-%m-%dT%H:%M:%S')
      time_diff = _utcnow() - oom_datetime
      if time_diff < datetime.timedelta(seconds=_OOM_EVENT_COOL_TIME_SEC):
        logging.warning(
            self._symptom_msg(
                'a runtime OOM occurred ~{} seconds ago. The model script '
                'will terminate automatically. To prevent future OOM events, '
                'please consider reducing the model size. To disable this '
                'behavior, set the flag --runtime_oom_exit=false when '
                'starting the script.'.format(time_diff.seconds)))
        return True
    return False

  def _hbm_oom_event(self, symptoms):
    """Check if an HBM OOM event is reported."""
    if not symptoms:
      return False
    for symptom in reversed(symptoms):
      if symptom['symptomType'] != 'HBM_OUT_OF_MEMORY':
        continue
      oom_datetime_str = symptom['createTime'].split('.')[0]
      oom_datetime = datetime.datetime.strptime(oom_datetime_str,
                                                '%Y-%m-%dT%H:%M:%S')
      time_diff = _utcnow() - oom_datetime
      if time_diff < datetime.timedelta(seconds=_OOM_EVENT_COOL_TIME_SEC):
        logging.warning(
            self._symptom_msg(
                'an HBM OOM occurred ~{} seconds ago. The model script will '
                'terminate automatically. To prevent future HBM OOM events, '
                'please consider reducing the model size. To disable this '
                'behavior, set the flag --hbm_oom_exit=false when starting '
                'the script.'.format(time_diff.seconds)))
        return True
    return False

  def _tpu_service(self):
    """Creates a new Cloud TPU API object.

    This works around an issue where the underlying HTTP connection sometimes
    times out when the script has been running for too long. Other methods in
    this object call this method to get a new API object whenever they need
    to communicate with the Cloud API.

    Raises:
      RuntimeError: If the dependent Python packages are missing.

    Returns:
      A Google Cloud TPU API object.
    """
    if self._service:
      return self._service

    if not _GOOGLE_API_CLIENT_INSTALLED:
      raise RuntimeError('Missing runtime dependency on the Google API client. '
                         'Run `pip install cloud-tpu-client` to fix.')

    credentials = self._credentials
    if credentials is None or credentials == 'default':
      credentials = client.GoogleCredentials.get_application_default()

    if self._discovery_url:
      return discovery.build(
          'tpu',
          'v1',
          credentials=credentials,
          discoveryServiceUrl=self._discovery_url,
          cache_discovery=False)
    else:
      return discovery.build(
          'tpu', 'v1', credentials=credentials, cache_discovery=False)

  def _full_name(self):
    """Returns the full Cloud name for this TPU."""
    return 'projects/%s/locations/%s/nodes/%s' % (
        self._project, self._zone, self._tpu)

  def _fetch_cloud_tpu_metadata(self):
    """Returns the TPU metadata object from the TPU Get API call."""
    service = self._tpu_service()
    try:
      r = service.projects().locations().nodes().get(name=self._full_name())
      return r.execute()
    except Exception as e:
      raise ValueError("Could not look up TPU metadata from name '%s'. Please "
                       'double-check the tpu argument in the TPUClusterResolver '
                       'constructor. Exception: %s' % (self._tpu, e))

  def _get_tpu_property(self, key):
    if self._use_api:
      metadata = self._fetch_cloud_tpu_metadata()
      return metadata.get(key)

    return None

  def __enter__(self):
    self._open = True

  def __exit__(self, type, value, traceback):  # pylint: disable=redefined-builtin
    del type, value, traceback

  def recoverable(self):
    """Returns true if the TPU is in a state where training should eventually resume.

    If false, the TPU is in an unrecoverable state and should be recreated.
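
    A typical use (an illustrative sketch, not code from this module):

      if not client.recoverable():
        ...  # recreate the TPU rather than retrying training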
    """
    state = self.state()
    symptoms = self.symptoms()
    if state and state in ['TERMINATED', 'PREEMPTED']:
      return False
    elif FLAGS.runtime_oom_exit and self._oom_event(symptoms):
      return False
    elif FLAGS.hbm_oom_exit and self._hbm_oom_event(symptoms):
      return False
    return True

  def symptoms(self):
    """Return Cloud TPU Symptoms of the TPU."""
    return self._get_tpu_property('symptoms')

  def state(self):
    """Return state of the TPU."""
    return self._get_tpu_property('state')

  def health(self):
    """Return health of the TPU."""
    return self._get_tpu_property('health')

  def runtime_version(self):
    """Return runtime version of the TPU."""

    if not self._use_api:
      # Fallback on getting version directly from TPU.
      url = _VERSION_SWITCHER_ENDPOINT.format(
          self.network_endpoints()[0]['ipAddress'])
      try:
        req = urllib.request.Request(url)
        resp = urllib.request.urlopen(req)
        version_details = json.loads(resp.read())
        return version_details.get('currentVersion')
      except urllib.error.HTTPError as e:
        status_code = e.code
        if status_code == 404:
          return None
        else:
          raise e
    return self._get_tpu_property('tensorflowVersion')

  def accelerator_type(self):
    """Return accelerator type of the TPU."""
    return self._get_tpu_property('acceleratorType')

  def api_available(self):
    """Return whether the Cloud TPU API is available; if not, certain features will not work."""
    return self._use_api

  def name(self):
    """Return the name of the TPU, or the IP address if a name was not provided."""
    return self._tpu

  def get_local_ip(self):
    """Return the local IP address of the Google Cloud VM the workload is running on."""
    return _request_compute_metadata('instance/network-interfaces/0/ip')

  def network_endpoints(self):
    """Return a list of TPU endpoints."""
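    # Each endpoint is a dict of the form {'ipAddress': ..., 'port': ...},
    # whether it comes from the Cloud TPU API response or from parsing the
    # TPU_NAME / GKE endpoint string.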
    if not self._use_api:
      return list(_environment_var_to_network_endpoints(self._tpu))
    response = self._fetch_cloud_tpu_metadata()

    if response.get('state') != 'READY':
      raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
                         (self._tpu, response.get('state')))
    if 'networkEndpoints' in response:
      return response['networkEndpoints']
    else:
      return [{'ipAddress': response['ipAddress'], 'port': response['port']}]

  def wait_for_healthy(self, timeout_s=1200, interval=30):
    """Wait for TPU to become healthy or raise error if timeout reached.

    Args:
      timeout_s (int): The timeout in seconds to wait for the TPU to become
        healthy.
      interval (int): The interval in seconds to poll the TPU for health.

    Raises:
      RuntimeError: If the TPU doesn't become healthy by the timeout.
    """
    timeout = time.time() + timeout_s
    while self.health() != 'HEALTHY':
      logging.warning(
          ('Waiting for TPU "%s" with state "%s" '
           'and health "%s" to become healthy'),
          self.name(), self.state(), self.health())
      if time.time() + interval > timeout:
        raise RuntimeError(
            'Timed out waiting for TPU "%s" to become healthy' % self.name())
      time.sleep(interval)

    logging.warning('TPU "%s" is healthy.', self.name())

  def configure_tpu_version(self, version, restart_type='always'):
    """Configure TPU software version.

    Args:
      version (string): Version of software to configure the TPU with.
      restart_type (string): Restart behaviour when switching versions,
        defaults to always restart. Options are 'always', 'ifNeeded'.

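    Example (a sketch; the version string below is only a placeholder):
      client.configure_tpu_version('2.8.0', restart_type='ifNeeded')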
    """

    def configure_worker(worker):
      """Configure individual TPU worker.

      Args:
        worker: A dict with the field ipAddress where the configure request will
          be sent.
      """
      ip_address = worker['ipAddress']
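      # The resulting URL has the form
      # http://<ip_address>:8475/requestversion/<version>?restartType=<restart_type>.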
      url = (_VERSION_SWITCHER_ENDPOINT + '/{}?restartType={}').format(
          ip_address, version, restart_type)
      req = urllib.request.Request(url, data=b'')
      try:
        urllib.request.urlopen(req)
      except urllib.error.HTTPError as e:
        status_code = e.code
        if status_code == 404:
          raise Exception(
              'TensorFlow version {} is not available on Cloud TPU; '
              'try a previous nightly version or refer to '
              'https://cloud.google.com/tpu/docs/release-notes for '
              'the latest official version.'.format(version))
        else:
          raise Exception('Failed to configure worker {}'.format(ip_address))

    workers = self.network_endpoints()

    with futures.ThreadPoolExecutor(max_workers=len(workers)) as executor:
      # Iterating over the executor.map() results re-raises any exception
      # raised inside configure_worker.
      results = executor.map(configure_worker, workers)
      for result in results:
        if result:
          result.result()
434