1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""gRPC debug server in Python.""" 16# pylint: disable=g-bad-import-order 17import collections 18import json 19import queue 20import threading 21import time 22 23from concurrent import futures 24import grpc 25 26from tensorflow.core.debug import debug_service_pb2 27from tensorflow.core.framework import graph_pb2 28from tensorflow.python.debug.lib import debug_graphs 29from tensorflow.python.debug.lib import debug_service_pb2_grpc 30from tensorflow.python.platform import tf_logging as logging 31from tensorflow.python.util import compat 32 33DebugWatch = collections.namedtuple("DebugWatch", 34 ["node_name", "output_slot", "debug_op"]) 35 36 37def _state_change(new_state, node_name, output_slot, debug_op): 38 state_change = debug_service_pb2.EventReply.DebugOpStateChange() 39 state_change.state = new_state 40 state_change.node_name = node_name 41 state_change.output_slot = output_slot 42 state_change.debug_op = debug_op 43 return state_change 44 45 46class EventListenerBaseStreamHandler: 47 """Per-stream handler of EventListener gRPC streams.""" 48 49 def __init__(self): 50 """Constructor of EventListenerBaseStreamHandler.""" 51 52 def on_core_metadata_event(self, event): 53 """Callback for core metadata. 54 55 Args: 56 event: The Event proto that carries a JSON string in its 57 `log_message.message` field. 58 59 Returns: 60 `None` or an `EventReply` proto to be sent back to the client. If `None`, 61 an `EventReply` proto construct with the default no-arg constructor will 62 be sent back to the client. 63 """ 64 raise NotImplementedError( 65 "on_core_metadata_event() is not implemented in the base servicer " 66 "class") 67 68 def on_graph_def(self, graph_def, device_name, wall_time): 69 """Callback for Event proto received through the gRPC stream. 70 71 This Event proto carries a GraphDef, encoded as bytes, in its graph_def 72 field. 73 74 Args: 75 graph_def: A GraphDef object. 76 device_name: Name of the device on which the graph was created. 77 wall_time: An epoch timestamp (in microseconds) for the graph. 78 79 Returns: 80 `None` or an `EventReply` proto to be sent back to the client. If `None`, 81 an `EventReply` proto construct with the default no-arg constructor will 82 be sent back to the client. 83 """ 84 raise NotImplementedError( 85 "on_graph_def() is not implemented in the base servicer class") 86 87 def on_value_event(self, event): 88 """Callback for Event proto received through the gRPC stream. 89 90 This Event proto carries a Tensor in its summary.value[0] field. 91 92 Args: 93 event: The Event proto from the stream to be processed. 94 """ 95 raise NotImplementedError( 96 "on_value_event() is not implemented in the base servicer class") 97 98 99class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer): 100 """Base Python class for gRPC debug server.""" 101 102 def __init__(self, server_port, stream_handler_class): 103 """Constructor. 104 105 Args: 106 server_port: (int) Port number to bind to. 107 stream_handler_class: A class of the base class 108 `EventListenerBaseStreamHandler` that will be used to constructor 109 stream handler objects during `SendEvents` calls. 110 """ 111 112 self._server_port = server_port 113 self._stream_handler_class = stream_handler_class 114 115 self._server_lock = threading.Lock() 116 self._server_started = False 117 self._stop_requested = False 118 119 self._debug_ops_state_change_queue = queue.Queue() 120 self._gated_grpc_debug_watches = set() 121 self._breakpoints = set() 122 123 def SendEvents(self, request_iterator, context): 124 """Implementation of the SendEvents service method. 125 126 This method receives streams of Event protos from the client, and processes 127 them in ways specified in the on_event() callback. The stream is 128 bi-directional, but currently only the client-to-server stream (i.e., the 129 stream from the debug ops to the server) is used. 130 131 Args: 132 request_iterator: The incoming stream of Event protos. 133 context: Server context. 134 135 Raises: 136 ValueError: If there are more than one core metadata events. 137 138 Yields: 139 An empty stream of responses. 140 """ 141 core_metadata_count = 0 142 143 # A map from GraphDef hash to a list of received chunks. 144 graph_def_chunks = {} 145 tensor_chunks = {} 146 147 stream_handler = None 148 for event in request_iterator: 149 if not stream_handler: 150 stream_handler = self._stream_handler_class() 151 152 if event.summary and event.summary.value: 153 # An Event proto carrying a tensor value. 154 maybe_tensor_event = self._process_tensor_event_in_chunks( 155 event, tensor_chunks) 156 if maybe_tensor_event: 157 event_reply = stream_handler.on_value_event(maybe_tensor_event) 158 if event_reply is not None: 159 yield self._process_debug_op_state_changes(event_reply) 160 else: 161 # Non-tensor-value Event. 162 if event.graph_def: 163 # GraphDef-carrying Event. 164 maybe_graph_def, maybe_device_name, maybe_wall_time = ( 165 self._process_encoded_graph_def_in_chunks( 166 event, graph_def_chunks)) 167 if maybe_graph_def: 168 reply = stream_handler.on_graph_def( 169 maybe_graph_def, maybe_device_name, maybe_wall_time) 170 yield self._process_debug_op_state_changes(reply) 171 elif event.log_message.message: 172 # Core metadata-carrying Event. 173 core_metadata_count += 1 174 if core_metadata_count > 1: 175 raise ValueError( 176 "Expected one core metadata event; received multiple") 177 reply = stream_handler.on_core_metadata_event(event) 178 yield self._process_debug_op_state_changes(reply) 179 180 def _process_debug_op_state_changes(self, event_reply=None): 181 """Dequeue and process all the queued debug-op state change protos. 182 183 Include all the debug-op state change protos in a `EventReply` proto. 184 185 Args: 186 event_reply: An `EventReply` to add the `DebugOpStateChange` protos to, 187 or `None`. 188 189 Returns: 190 An `EventReply` proto with the dequeued `DebugOpStateChange` protos (if 191 any) added. 192 """ 193 if event_reply is None: 194 event_reply = debug_service_pb2.EventReply() 195 while not self._debug_ops_state_change_queue.empty(): 196 state_change = self._debug_ops_state_change_queue.get() 197 debug_node_key = (state_change.node_name, state_change.output_slot, 198 state_change.debug_op) 199 if (state_change.state == 200 debug_service_pb2.EventReply.DebugOpStateChange.READ_WRITE): 201 logging.info("Adding breakpoint %s:%d:%s", state_change.node_name, 202 state_change.output_slot, state_change.debug_op) 203 self._breakpoints.add(debug_node_key) 204 elif (state_change.state == 205 debug_service_pb2.EventReply.DebugOpStateChange.READ_ONLY): 206 logging.info("Adding watchpoint %s:%d:%s", state_change.node_name, 207 state_change.output_slot, state_change.debug_op) 208 if debug_node_key in self._breakpoints: 209 self._breakpoints.discard(debug_node_key) 210 elif (state_change.state == 211 debug_service_pb2.EventReply.DebugOpStateChange.DISABLED): 212 logging.info("Removing watchpoint or breakpoint: %s:%d:%s", 213 state_change.node_name, state_change.output_slot, 214 state_change.debug_op) 215 if debug_node_key in self._breakpoints: 216 self._breakpoints.discard(debug_node_key) 217 else: 218 logging.warn( 219 "Attempting to remove a non-existent debug node key: %s", 220 debug_node_key) 221 new_state_change = event_reply.debug_op_state_changes.add() 222 new_state_change.CopyFrom(state_change) 223 return event_reply 224 225 def _process_tensor_event_in_chunks(self, event, tensor_chunks): 226 """Possibly reassemble event chunks. 227 228 Due to gRPC's message size limit, a large tensor can be encapsulated in 229 multiple Event proto chunks to be sent through the debugger stream. This 230 method keeps track of the chunks that have arrived, reassemble all chunks 231 corresponding to a tensor when they have arrived and return the reassembled 232 Event proto. 233 234 Args: 235 event: The single Event proto that has arrived. 236 tensor_chunks: A dict used to keep track of the Event protos that have 237 arrived but haven't been reassembled. 238 239 Returns: 240 If all Event protos corresponding to a tensor have arrived, returns the 241 reassembled Event proto. Otherwise, return None. 242 """ 243 244 value = event.summary.value[0] 245 debugger_plugin_metadata = json.loads( 246 compat.as_text(value.metadata.plugin_data.content)) 247 device_name = debugger_plugin_metadata["device"] 248 num_chunks = debugger_plugin_metadata["numChunks"] 249 chunk_index = debugger_plugin_metadata["chunkIndex"] 250 251 if num_chunks <= 1: 252 return event 253 254 debug_node_name = value.node_name 255 timestamp = int(event.wall_time) 256 tensor_key = "%s_%s_%d" % (device_name, debug_node_name, timestamp) 257 258 if tensor_key not in tensor_chunks: 259 tensor_chunks[tensor_key] = [None] * num_chunks 260 261 chunks = tensor_chunks[tensor_key] 262 if value.tensor.tensor_content: 263 chunks[chunk_index] = value.tensor 264 elif value.tensor.string_val: 265 chunks[chunk_index] = event 266 267 if None not in chunks: 268 if value.tensor.tensor_content: 269 event.summary.value[0].tensor.tensor_content = b"".join( 270 chunk.tensor_content for chunk in chunks) 271 del tensor_chunks[tensor_key] 272 return event 273 elif value.tensor.string_val: 274 merged_event = chunks[0] 275 for chunk in chunks[1:]: 276 merged_event.summary.value[0].tensor.string_val.extend( 277 list(chunk.summary.value[0].tensor.string_val)) 278 return merged_event 279 280 def _process_encoded_graph_def_in_chunks(self, 281 event, 282 graph_def_chunks): 283 """Process an Event proto containing a chunk of encoded GraphDef. 284 285 Args: 286 event: the Event proto containing the chunk of encoded GraphDef. 287 graph_def_chunks: A dict mapping keys for GraphDefs (i.e., 288 "<graph_def_hash>,<device_name>,<wall_time>") to a list of chunks of 289 encoded GraphDefs. 290 291 Returns: 292 If all chunks of the GraphDef have arrived, 293 return decoded GraphDef proto, device name, wall_time. 294 Otherwise, 295 return None, None, None. 296 """ 297 graph_def = graph_pb2.GraphDef() 298 index_bar_0 = event.graph_def.find(b"|") 299 index_bar_1 = event.graph_def.find(b"|", index_bar_0 + 1) 300 index_bar_2 = event.graph_def.find(b"|", index_bar_1 + 1) 301 graph_def_hash_device_timestamp = event.graph_def[:index_bar_0] 302 chunk_index = int(event.graph_def[index_bar_0 + 1 : index_bar_1]) 303 num_chunks = int(event.graph_def[index_bar_1 + 1 : index_bar_2]) 304 if graph_def_hash_device_timestamp not in graph_def_chunks: 305 graph_def_chunks[graph_def_hash_device_timestamp] = [None] * num_chunks 306 graph_def_chunks[graph_def_hash_device_timestamp][ 307 chunk_index] = event.graph_def[index_bar_2 + 1:] 308 if all(graph_def_chunks[graph_def_hash_device_timestamp]): 309 device_name = graph_def_hash_device_timestamp.split(b",")[1] 310 wall_time = int(graph_def_hash_device_timestamp.split(b",")[2]) 311 graph_def.ParseFromString( 312 b"".join(graph_def_chunks[graph_def_hash_device_timestamp])) 313 del graph_def_chunks[graph_def_hash_device_timestamp] 314 self._process_graph_def(graph_def) 315 return graph_def, device_name, wall_time 316 else: 317 return None, None, None 318 319 def _process_graph_def(self, graph_def): 320 for node_def in graph_def.node: 321 if (debug_graphs.is_debug_node(node_def.name) and 322 node_def.attr["gated_grpc"].b): 323 node_name, output_slot, _, debug_op = ( 324 debug_graphs.parse_debug_node_name(node_def.name)) 325 self._gated_grpc_debug_watches.add( 326 DebugWatch(node_name, output_slot, debug_op)) 327 328 def run_server(self, blocking=True): 329 """Start running the server. 330 331 Args: 332 blocking: If `True`, block until `stop_server()` is invoked. 333 334 Raises: 335 ValueError: If server stop has already been requested, or if the server 336 has already started running. 337 """ 338 self._server_lock.acquire() 339 try: 340 if self._stop_requested: 341 raise ValueError("Server has already stopped") 342 if self._server_started: 343 raise ValueError("Server has already started running") 344 345 no_max_message_sizes = [("grpc.max_receive_message_length", -1), 346 ("grpc.max_send_message_length", -1)] 347 self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10), 348 options=no_max_message_sizes) 349 debug_service_pb2_grpc.add_EventListenerServicer_to_server(self, 350 self.server) 351 self.server.add_insecure_port("[::]:%d" % self._server_port) 352 self.server.start() 353 self._server_started = True 354 finally: 355 self._server_lock.release() 356 357 if blocking: 358 while not self._stop_requested: 359 time.sleep(1.0) 360 361 def stop_server(self, grace=1.0): 362 """Request server stopping. 363 364 Once stopped, server cannot be stopped or started again. This method is 365 non-blocking. Call `wait()` on the returned event to block until the server 366 has completely stopped. 367 368 Args: 369 grace: Grace period in seconds to be used when calling `server.stop()`. 370 371 Raises: 372 ValueError: If server stop has already been requested, or if the server 373 has not started running yet. 374 375 Returns: 376 A threading.Event that will be set when the server has completely stopped. 377 """ 378 self._server_lock.acquire() 379 try: 380 if not self._server_started: 381 raise ValueError("Server has not started running") 382 if self._stop_requested: 383 raise ValueError("Server has already stopped") 384 385 self._stop_requested = True 386 return self.server.stop(grace=grace) 387 finally: 388 self._server_lock.release() 389 390 def request_watch(self, node_name, output_slot, debug_op, breakpoint=False): # pylint: disable=redefined-builtin 391 """Request enabling a debug tensor watchpoint or breakpoint. 392 393 This will let the server send a EventReply to the client side 394 (i.e., the debugged TensorFlow runtime process) to request adding a watch 395 key (i.e., <node_name>:<output_slot>:<debug_op>) to the list of enabled 396 watch keys. The list applies only to debug ops with the attribute 397 gated_grpc=True. 398 399 To disable the watch, use `request_unwatch()`. 400 401 Args: 402 node_name: (`str`) name of the node that the to-be-watched tensor belongs 403 to, e.g., "hidden/Weights". 404 output_slot: (`int`) output slot index of the tensor to watch. 405 debug_op: (`str`) name of the debug op to enable. This should not include 406 any attribute substrings. 407 breakpoint: (`bool`) Iff `True`, the debug op will block and wait until it 408 receives an `EventReply` response from the server. The `EventReply` 409 proto may carry a TensorProto that modifies the value of the debug op's 410 output tensor. 411 """ 412 self._debug_ops_state_change_queue.put( 413 _state_change( 414 debug_service_pb2.EventReply.DebugOpStateChange.READ_WRITE 415 if breakpoint 416 else debug_service_pb2.EventReply.DebugOpStateChange.READ_ONLY, 417 node_name, output_slot, debug_op)) 418 419 def request_unwatch(self, node_name, output_slot, debug_op): 420 """Request disabling a debug tensor watchpoint or breakpoint. 421 422 This is the opposite of `request_watch()`. 423 424 Args: 425 node_name: (`str`) name of the node that the to-be-watched tensor belongs 426 to, e.g., "hidden/Weights". 427 output_slot: (`int`) output slot index of the tensor to watch. 428 debug_op: (`str`) name of the debug op to enable. This should not include 429 any attribute substrings. 430 """ 431 self._debug_ops_state_change_queue.put( 432 _state_change( 433 debug_service_pb2.EventReply.DebugOpStateChange.DISABLED, node_name, 434 output_slot, debug_op)) 435 436 @property 437 def breakpoints(self): 438 """Get a set of the currently-activated breakpoints. 439 440 Returns: 441 A `set` of 3-tuples: (node_name, output_slot, debug_op), e.g., 442 {("MatMul", 0, "DebugIdentity")}. 443 """ 444 return self._breakpoints 445 446 def gated_grpc_debug_watches(self): 447 """Get the list of debug watches with attribute gated_grpc=True. 448 449 Since the server receives `GraphDef` from the debugged runtime, it can only 450 return such debug watches that it has received so far. 451 452 Returns: 453 A `list` of `DebugWatch` `namedtuples` representing the debug watches with 454 gated_grpc=True. Each `namedtuple` element has the attributes: 455 `node_name` as a `str`, 456 `output_slot` as an `int`, 457 `debug_op` as a `str`. 458 """ 459 return list(self._gated_grpc_debug_watches) 460 461 def SendTracebacks(self, request, context): 462 """Base implementation of the handling of SendTracebacks calls. 463 464 The base implementation does nothing with the incoming request. 465 Override in an implementation of the server if necessary. 466 467 Args: 468 request: A `CallTraceback` proto, containing information about the 469 type (e.g., graph vs. eager execution) and source-code traceback of the 470 call and (any) associated `tf.Graph`s. 471 context: Server context. 472 473 Returns: 474 A `EventReply` proto. 475 """ 476 return debug_service_pb2.EventReply() 477 478 def SendSourceFiles(self, request, context): 479 """Base implementation of the handling of SendSourceFiles calls. 480 481 The base implementation does nothing with the incoming request. 482 Override in an implementation of the server if necessary. 483 484 Args: 485 request: A `DebuggedSourceFiles` proto, containing the path, content, size 486 and last-modified timestamp of source files. 487 context: Server context. 488 489 Returns: 490 A `EventReply` proto. 491 """ 492 return debug_service_pb2.EventReply() 493