# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Cloud TPU Client."""

from concurrent import futures
import datetime
import json
import logging
import os
import time
# Import the submodules explicitly: a bare `import urllib` does not guarantee
# that `urllib.request` / `urllib.error` are loaded.
import urllib.error
import urllib.request

from absl import flags

_GOOGLE_API_CLIENT_INSTALLED = True
try:
  from googleapiclient import discovery  # pylint: disable=g-import-not-at-top
  from oauth2client import client  # pylint: disable=g-import-not-at-top
except ImportError:
  _GOOGLE_API_CLIENT_INSTALLED = False

FLAGS = flags.FLAGS

flags.DEFINE_bool('runtime_oom_exit', True,
                  'Exit the script when the TPU runtime is OOM.')
flags.DEFINE_bool('hbm_oom_exit', True,
                  'Exit the script when the TPU HBM is OOM.')

_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
_DEFAULT_TPUCONFIG_VARIABLE = 'TPU_CONFIG'
_ENDPOINTS_SEPARATOR = ','
_DEFAULT_ENV_VARIABLE = 'TPU_NAME'
_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL'
_GCE_METADATA_URL_ENV_VARIABLE = 'GCE_METADATA_IP'
_DEFAULT_ENDPOINT_PORT = '8470'
_OOM_EVENT_COOL_TIME_SEC = 90
_VERSION_SWITCHER_ENDPOINT = 'http://{}:8475/requestversion'
# A TPU "name" starting with this prefix is a direct gRPC address list rather
# than a Cloud TPU node name, so the Cloud TPU API is not used for it.
_GRPC_PREFIX = 'grpc://'


def _utcnow():
  """A wrapper function around the current UTC time.

  This function is created for unit testing purpose. It's not easy to do
  StubOutWithMock with datetime.datetime package.

  Returns:
    A naive datetime.datetime holding the current UTC time, comparable with
    the naive datetimes parsed from symptom `createTime` fields.
  """
  # datetime.datetime.utcnow() is deprecated since Python 3.12; this produces
  # the same naive UTC value.
  return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)


def _environment_discovery_url():
  """Returns the discovery URL override from the environment, or None."""
  return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE)


def _gce_metadata_endpoint():
  """Returns the base URL of the GCE metadata server."""
  return 'http://' + os.environ.get(_GCE_METADATA_URL_ENV_VARIABLE,
                                    'metadata.google.internal')


def _request_compute_metadata(path):
  """Fetches `path` from the GCE compute metadata server.

  Args:
    path: Path below `computeMetadata/v1/`, e.g. 'project/project-id'.

  Returns:
    The response body as text.
  """
  req = urllib.request.Request(
      '%s/computeMetadata/v1/%s' % (_gce_metadata_endpoint(), path),
      headers={'Metadata-Flavor': 'Google'})
  # Use the response as a context manager so the connection is closed.
  with urllib.request.urlopen(req) as resp:
    return _as_text(resp.read())


def _environment_var_to_network_endpoints(endpoints):
  """Yields a dict with ip address and port for each endpoint.

  Args:
    endpoints: A string of endpoints separated by `_ENDPOINTS_SEPARATOR`, each
      of the form `[grpc://]ip[:port]`. The port defaults to
      `_DEFAULT_ENDPOINT_PORT` when omitted.

  Yields:
    Dicts of the form {'ipAddress': ip, 'port': port}.
  """
  for endpoint in endpoints.split(_ENDPOINTS_SEPARATOR):
    if endpoint.startswith(_GRPC_PREFIX):
      endpoint = endpoint[len(_GRPC_PREFIX):]
    parts = endpoint.split(':')
    ip_address = parts[0]
    port = parts[1] if len(parts) > 1 else _DEFAULT_ENDPOINT_PORT
    yield {
        'ipAddress': ip_address,
        'port': port
    }


def _get_tpu_node_config():
  """Returns the JSON-decoded TPU_CONFIG environment variable, or None."""
  tpu_config_env = os.environ.get(_DEFAULT_TPUCONFIG_VARIABLE)
  if tpu_config_env:
    return json.loads(tpu_config_env)
  return None


def _get_tpu_name(tpu):
  """Returns `tpu` if truthy, else the first TPU name set in the environment."""
  if tpu:
    return tpu

  for e in [_GKE_ENV_VARIABLE, _DEFAULT_ENV_VARIABLE]:
    if e in os.environ:
      return os.environ[e]
  return None


def _as_text(s):
  """Decodes `bytes` as UTF-8; returns any other value unchanged."""
  if isinstance(s, bytes):
    return s.decode('utf-8')
  return s


def _recent_symptom_age_seconds(symptoms, symptom_type):
  """Returns the age in seconds of a recent symptom of `symptom_type`, or None.

  The symptom list is walked from the end (presumably newest first). The first
  matching symptom whose `createTime` lies less than `_OOM_EVENT_COOL_TIME_SEC`
  seconds before `_utcnow()` determines the result.

  Args:
    symptoms: A list of symptom dicts with 'symptomType' and 'createTime'
      keys, or None.
    symptom_type: The 'symptomType' value to look for.

  Returns:
    The non-negative integer age in seconds of that symptom, or None if no
    matching symptom is recent enough.
  """
  if not symptoms:
    return None
  for symptom in reversed(symptoms):
    if symptom['symptomType'] != symptom_type:
      continue
    # Drop the fractional-seconds part before parsing.
    oom_datetime_str = symptom['createTime'].split('.')[0]
    oom_datetime = datetime.datetime.strptime(oom_datetime_str,
                                              '%Y-%m-%dT%H:%M:%S')
    time_diff = _utcnow() - oom_datetime
    if time_diff < datetime.timedelta(seconds=_OOM_EVENT_COOL_TIME_SEC):
      # Clamp at zero: with clock skew `time_diff` can be negative, and
      # `timedelta.seconds` would then report a nonsense value near 86399.
      return max(0, int(time_diff.total_seconds()))
  return None


class Client:
  """Client for working with the Cloud TPU API.

  This client is intended to be used for resolving tpu name to ip addresses.

  It's recommended to use this library as a context manager
  (`with Client(...) as c:`) to utilize all functionality.
  """

  def __init__(self,
               tpu=None,
               zone=None,
               project=None,
               credentials='default',
               service=None,
               discovery_url=None):
    """Creates a client for a single TPU.

    Args:
      tpu: TPU name, a `grpc://` address list, or a one-element list of
        either. Falls back to the environment / TPU_CONFIG when unset.
      zone: GCE zone; detected from the metadata server when unset.
      project: GCP project; detected from the metadata server when unset.
      credentials: Credentials for the Cloud TPU API, or 'default' to use
        application-default credentials.
      service: Optional pre-built Cloud TPU API object to use.
      discovery_url: Optional API discovery URL. The environment variable
        TPU_API_DISCOVERY_URL takes precedence over it.

    Raises:
      ValueError: If no TPU is given or can be discovered.
      NotImplementedError: If more than one TPU is given.
    """
    if isinstance(tpu, list):
      if not tpu:
        raise ValueError('At least one TPU must be specified.')
      if len(tpu) != 1:
        raise NotImplementedError(
            'Using multiple TPUs in a single session is not yet implemented')
      tpu = tpu[0]

    tpu = _get_tpu_name(tpu)

    if tpu is None:
      tpu_node_config = _get_tpu_node_config()
      if tpu_node_config:
        tpu = tpu_node_config.get('tpu_node_name')
        project = project or tpu_node_config.get('project')
        zone = zone or tpu_node_config.get('zone')
      # Also covers a TPU_CONFIG lacking 'tpu_node_name', which previously
      # failed below with an opaque AttributeError on None.
      if tpu is None:
        raise ValueError('Please provide a TPU Name to connect to.')

    self._tpu = _as_text(tpu)

    self._use_api = not self._tpu.startswith(_GRPC_PREFIX)
    self._service = service

    self._credentials = None
    self._project = None
    self._zone = None
    self._discovery_url = None
    if self._use_api:
      if credentials != 'default':
        self._credentials = credentials
      # Automatically detect project and zone if unspecified.
      if project:
        self._project = _as_text(project)
      else:
        self._project = _request_compute_metadata('project/project-id')
      if zone:
        self._zone = _as_text(zone)
      else:
        zone_path = _request_compute_metadata('instance/zone')
        self._zone = zone_path.split('/')[-1]
      self._discovery_url = _environment_discovery_url() or discovery_url

  def _symptom_msg(self, msg):
    """Return the structured Symptom message."""
    return 'Symptom: ' + msg

  def _oom_event(self, symptoms):
    """Check if a runtime OOM event is reported.

    Logs a warning and returns True if an 'OUT_OF_MEMORY' symptom occurred
    within the last `_OOM_EVENT_COOL_TIME_SEC` seconds.
    """
    seconds_ago = _recent_symptom_age_seconds(symptoms, 'OUT_OF_MEMORY')
    if seconds_ago is None:
      return False
    logging.warning(
        self._symptom_msg(
            'a recent runtime OOM has occurred ~{} seconds ago. The model '
            'script will terminate automatically. To prevent future OOM '
            'events, please consider reducing the model size. To disable this '
            'behavior, set flag --runtime_oom_exit=false when starting the '
            'script.'.format(seconds_ago)))
    return True

  def _hbm_oom_event(self, symptoms):
    """Check if a HBM OOM event is reported.

    Logs a warning and returns True if an 'HBM_OUT_OF_MEMORY' symptom
    occurred within the last `_OOM_EVENT_COOL_TIME_SEC` seconds.
    """
    seconds_ago = _recent_symptom_age_seconds(symptoms, 'HBM_OUT_OF_MEMORY')
    if seconds_ago is None:
      return False
    logging.warning(
        self._symptom_msg(
            'a recent HBM OOM has occurred ~{} seconds ago. The model '
            'script will terminate automatically. To prevent future HBM OOM '
            'events, please consider reducing the model size. To disable this '
            'behavior, set flag --hbm_oom_exit=false when starting the '
            'script.'.format(seconds_ago)))
    return True

  def _tpu_service(self):
    """Creates a new Cloud TPU API object.

    This works around an issue where the underlying HTTP connection sometimes
    times out when the script has been running for too long. Other methods in
    this object call this method to get a new API object whenever they need
    to communicate with the Cloud API. A `service` passed to the constructor
    is returned as-is.

    Raises:
      RuntimeError: If the dependent Python packages are missing.

    Returns:
      A Google Cloud TPU API object.
    """
    if self._service:
      return self._service

    if not _GOOGLE_API_CLIENT_INSTALLED:
      raise RuntimeError('Missing runtime dependency on the Google API client. '
                         'Run `pip install cloud-tpu-client` to fix.')

    credentials = self._credentials
    if credentials is None or credentials == 'default':
      credentials = client.GoogleCredentials.get_application_default()

    build_kwargs = {'credentials': credentials, 'cache_discovery': False}
    if self._discovery_url:
      build_kwargs['discoveryServiceUrl'] = self._discovery_url
    return discovery.build('tpu', 'v1', **build_kwargs)

  def _full_name(self):
    """Returns the full Cloud name for this TPU."""
    return 'projects/%s/locations/%s/nodes/%s' % (
        self._project, self._zone, self._tpu)

  def _fetch_cloud_tpu_metadata(self):
    """Returns the TPU metadata object from the TPU Get API call.

    Raises:
      ValueError: If the API lookup fails for any reason (original exception
        is chained as the cause).
    """
    service = self._tpu_service()
    try:
      r = service.projects().locations().nodes().get(name=self._full_name())
      return r.execute()
    except Exception as e:
      raise ValueError("Could not lookup TPU metadata from name '%s'. Please "
                       'doublecheck the tpu argument in the TPUClusterResolver '
                       'constructor. Exception: %s' % (self._tpu, e)) from e

  def _get_tpu_property(self, key):
    """Returns `key` from the TPU metadata, or None when the API is unused."""
    if self._use_api:
      metadata = self._fetch_cloud_tpu_metadata()
      return metadata.get(key)

    return None

  def __enter__(self):
    self._open = True
    # Return the client so `with Client(...) as c:` binds it.
    return self

  def __exit__(self, type, value, traceback):  # pylint: disable=redefined-builtin
    del type, value, traceback

  def recoverable(self):
    """Returns true if the TPU is in a state where training should eventually resume.

    If false the TPU is in a unrecoverable state and should be recreated.
    """
    state = self.state()
    symptoms = self.symptoms()
    if state in ('TERMINATED', 'PREEMPTED'):
      return False
    elif FLAGS.runtime_oom_exit and self._oom_event(symptoms):
      return False
    elif FLAGS.hbm_oom_exit and self._hbm_oom_event(symptoms):
      return False
    return True

  def symptoms(self):
    """Return Cloud TPU Symptoms of the TPU."""
    return self._get_tpu_property('symptoms')

  def state(self):
    """Return state of the TPU."""
    return self._get_tpu_property('state')

  def health(self):
    """Return health of the TPU."""
    return self._get_tpu_property('health')

  def runtime_version(self):
    """Return runtime version of the TPU.

    Without the Cloud API, the version is requested from the version-switcher
    endpoint of the first TPU worker; a 404 from it yields None and any other
    HTTP error is re-raised.
    """
    if self._use_api:
      return self._get_tpu_property('tensorflowVersion')

    # Fallback on getting version directly from TPU.
    url = _VERSION_SWITCHER_ENDPOINT.format(
        self.network_endpoints()[0]['ipAddress'])
    try:
      with urllib.request.urlopen(urllib.request.Request(url)) as resp:
        version_details = json.loads(resp.read())
    except urllib.error.HTTPError as e:
      if e.code == 404:
        return None
      raise
    return version_details.get('currentVersion')

  def accelerator_type(self):
    """Return accelerator type of the TPU."""
    return self._get_tpu_property('acceleratorType')

  def api_available(self):
    """Return if the Cloud TPU API is available, if not certain features will not work."""
    return self._use_api

  def name(self):
    """Return the name of the tpu, or the ip address if name is not provided."""
    return self._tpu

  def get_local_ip(self):
    """Return the local ip address of the Google Cloud VM the workload is running on."""
    return _request_compute_metadata('instance/network-interfaces/0/ip')

  def network_endpoints(self):
    """Return a list of tpu endpoints.

    Raises:
      RuntimeError: If the Cloud API reports the TPU state is not 'READY'.
    """
    if not self._use_api:
      return list(_environment_var_to_network_endpoints(self._tpu))
    response = self._fetch_cloud_tpu_metadata()

    if response.get('state') != 'READY':
      raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
                         (self._tpu, response.get('state')))
    if 'networkEndpoints' in response:
      return response['networkEndpoints']
    else:
      return [{'ipAddress': response['ipAddress'], 'port': response['port']}]

  def wait_for_healthy(self, timeout_s=1200, interval=30):
    """Wait for TPU to become healthy or raise error if timeout reached.

    Args:
      timeout_s (int): The timeout in seconds for waiting TPU to become healthy.
      interval (int): The interval in seconds to poll the TPU for health.

    Raises:
      RuntimeError: If the TPU doesn't become healthy by the timeout.
    """
    timeout = time.time() + timeout_s
    while True:
      # Query health once per poll and log the same value that was checked,
      # instead of issuing a second API call just for the log line.
      health = self.health()
      if health == 'HEALTHY':
        break
      logging.warning(
          ('Waiting for TPU "%s" with state "%s" '
           'and health "%s" to become healthy'),
          self.name(), self.state(), health)
      if time.time() + interval > timeout:
        raise RuntimeError(
            'Timed out waiting for TPU "%s" to become healthy' % self.name())
      time.sleep(interval)

    logging.warning('TPU "%s" is healthy.', self.name())

  def configure_tpu_version(self, version, restart_type='always'):
    """Configure TPU software version.

    Sends one version-switch request per TPU worker, in parallel.

    Args:
      version (string): Version of software to configure the TPU with.
      restart_type (string): Restart behaviour when switching versions,
        defaults to always restart. Options are 'always', 'ifNeeded'.

    Raises:
      ValueError: If the TPU reports no network endpoints.
      Exception: If a worker rejects the request (404 means the version is
        not available).
    """

    def configure_worker(worker):
      """Configure individual TPU worker.

      Args:
        worker: A dict with the field ipAddress where the configure request will
          be sent.
      """
      ip_address = worker['ipAddress']
      url = (_VERSION_SWITCHER_ENDPOINT + '/{}?restartType={}').format(
          ip_address, version, restart_type)
      req = urllib.request.Request(url, data=b'')
      try:
        resp = urllib.request.urlopen(req)
      except urllib.error.HTTPError as e:
        if e.code == 404:
          raise Exception(
              'Tensorflow version {} is not available on Cloud TPU, '
              'try a previous nightly version or refer to '
              'https://cloud.google.com/tpu/docs/release-notes for '
              'the latest official version.'.format(version)) from e
        raise Exception(
            'Failed to configure worker {}'.format(ip_address)) from e
      # The body is not needed; release the connection.
      resp.close()

    workers = self.network_endpoints()
    if not workers:
      # ThreadPoolExecutor(max_workers=0) would fail with an unrelated error.
      raise ValueError(
          'TPU "%s" has no network endpoints to configure.' % self._tpu)

    with futures.ThreadPoolExecutor(max_workers=len(workers)) as executor:
      # Executor.map re-raises a worker's exception in this thread when the
      # corresponding result is consumed, so drain the iterator.
      list(executor.map(configure_worker, workers))