# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""gRPC debug server in Python."""
# pylint: disable=g-bad-import-order
import collections
import json
import queue
import threading
import time

from concurrent import futures
import grpc

from tensorflow.core.debug import debug_service_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.python.debug.lib import debug_graphs
from tensorflow.python.debug.lib import debug_service_pb2_grpc
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat

DebugWatch = collections.namedtuple("DebugWatch",
                                    ["node_name", "output_slot", "debug_op"])

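# For example, a gated-gRPC watch on the first output of a node named
# "hidden/MatMul" via the DebugIdentity debug op (names illustrative) would be
# represented as:
#   DebugWatch(node_name="hidden/MatMul", output_slot=0,
#              debug_op="DebugIdentity")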

def _state_change(new_state, node_name, output_slot, debug_op):
  """Creates a `DebugOpStateChange` proto for a watch key and a new state."""
  state_change = debug_service_pb2.EventReply.DebugOpStateChange()
  state_change.state = new_state
  state_change.node_name = node_name
  state_change.output_slot = output_slot
  state_change.debug_op = debug_op
  return state_change


class EventListenerBaseStreamHandler:
  """Per-stream handler of EventListener gRPC streams."""

  def __init__(self):
    """Constructor of EventListenerBaseStreamHandler."""

  def on_core_metadata_event(self, event):
    """Callback for core metadata.

    Args:
      event: The Event proto that carries a JSON string in its
        `log_message.message` field.

    Returns:
      `None` or an `EventReply` proto to be sent back to the client. If `None`,
      an `EventReply` proto constructed with the default no-arg constructor
      will be sent back to the client.
    """
    raise NotImplementedError(
        "on_core_metadata_event() is not implemented in the base servicer "
        "class")

  def on_graph_def(self, graph_def, device_name, wall_time):
    """Callback for a GraphDef received through the gRPC stream.

    The underlying Event proto carries the GraphDef, encoded as bytes, in its
    graph_def field; the decoded GraphDef is passed to this callback.

    Args:
      graph_def: A GraphDef object.
      device_name: Name of the device on which the graph was created.
      wall_time: An epoch timestamp (in microseconds) for the graph.

    Returns:
      `None` or an `EventReply` proto to be sent back to the client. If `None`,
      an `EventReply` proto constructed with the default no-arg constructor
      will be sent back to the client.
    """
    raise NotImplementedError(
        "on_graph_def() is not implemented in the base servicer class")

  def on_value_event(self, event):
    """Callback for an Event proto received through the gRPC stream.

    This Event proto carries a Tensor in its summary.value[0] field.

    Args:
      event: The Event proto from the stream to be processed.
    """
    raise NotImplementedError(
        "on_value_event() is not implemented in the base servicer class")


class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
  """Base Python class for gRPC debug server."""

  def __init__(self, server_port, stream_handler_class):
    """Constructor.

    Args:
      server_port: (int) Port number to bind to.
      stream_handler_class: A subclass of `EventListenerBaseStreamHandler`
        that will be used to construct stream handler objects during
        `SendEvents` calls.
    """

    self._server_port = server_port
    self._stream_handler_class = stream_handler_class

    self._server_lock = threading.Lock()
    self._server_started = False
    self._stop_requested = False

    self._debug_ops_state_change_queue = queue.Queue()
    self._gated_grpc_debug_watches = set()
    self._breakpoints = set()

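  # A minimal stream-handler sketch (illustrative only; the
  # `_LoggingStreamHandler` class below is hypothetical, not part of this
  # module):
  #
  #   class _LoggingStreamHandler(EventListenerBaseStreamHandler):
  #
  #     def on_core_metadata_event(self, event):
  #       logging.info("Core metadata: %s", event.log_message.message)
  #
  #     def on_graph_def(self, graph_def, device_name, wall_time):
  #       logging.info("GraphDef from %s at t=%d", device_name, wall_time)
  #
  #     def on_value_event(self, event):
  #       logging.info("Tensor value from node %s",
  #                    event.summary.value[0].node_name)
  #
  #   servicer = EventListenerBaseServicer(6064, _LoggingStreamHandler)
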
  def SendEvents(self, request_iterator, context):
    """Implementation of the SendEvents service method.

    This method receives streams of Event protos from the client and processes
    them via the `on_core_metadata_event()`, `on_graph_def()` and
    `on_value_event()` callbacks of the stream handler. The stream is
    bi-directional, but currently only the client-to-server stream (i.e., the
    stream from the debug ops to the server) is used.

    Args:
      request_iterator: The incoming stream of Event protos.
      context: Server context.

    Raises:
      ValueError: If more than one core metadata event is received.

    Yields:
      `EventReply` protos, which carry any queued debug-op state changes.
    """
    core_metadata_count = 0

    # A map from GraphDef hash to a list of received chunks.
    graph_def_chunks = {}
    tensor_chunks = {}

    stream_handler = None
    for event in request_iterator:
      if not stream_handler:
        stream_handler = self._stream_handler_class()

      if event.summary and event.summary.value:
        # An Event proto carrying a tensor value.
        maybe_tensor_event = self._process_tensor_event_in_chunks(
            event, tensor_chunks)
        if maybe_tensor_event:
          event_reply = stream_handler.on_value_event(maybe_tensor_event)
          if event_reply is not None:
            yield self._process_debug_op_state_changes(event_reply)
      else:
        # Non-tensor-value Event.
        if event.graph_def:
          # GraphDef-carrying Event.
          maybe_graph_def, maybe_device_name, maybe_wall_time = (
              self._process_encoded_graph_def_in_chunks(
                  event, graph_def_chunks))
          if maybe_graph_def:
            reply = stream_handler.on_graph_def(
                maybe_graph_def, maybe_device_name, maybe_wall_time)
            yield self._process_debug_op_state_changes(reply)
        elif event.log_message.message:
          # Core metadata-carrying Event.
          core_metadata_count += 1
          if core_metadata_count > 1:
            raise ValueError(
                "Expected one core metadata event; received multiple")
          reply = stream_handler.on_core_metadata_event(event)
          yield self._process_debug_op_state_changes(reply)

  def _process_debug_op_state_changes(self, event_reply=None):
    """Dequeue and process all the queued debug-op state change protos.

    Include all the debug-op state change protos in an `EventReply` proto.

    Args:
      event_reply: An `EventReply` to add the `DebugOpStateChange` protos to,
        or `None`.

    Returns:
      An `EventReply` proto with the dequeued `DebugOpStateChange` protos (if
        any) added.
    """
    if event_reply is None:
      event_reply = debug_service_pb2.EventReply()
    while not self._debug_ops_state_change_queue.empty():
      state_change = self._debug_ops_state_change_queue.get()
      debug_node_key = (state_change.node_name, state_change.output_slot,
                        state_change.debug_op)
      if (state_change.state ==
          debug_service_pb2.EventReply.DebugOpStateChange.READ_WRITE):
        logging.info("Adding breakpoint %s:%d:%s", state_change.node_name,
                     state_change.output_slot, state_change.debug_op)
        self._breakpoints.add(debug_node_key)
      elif (state_change.state ==
            debug_service_pb2.EventReply.DebugOpStateChange.READ_ONLY):
        logging.info("Adding watchpoint %s:%d:%s", state_change.node_name,
                     state_change.output_slot, state_change.debug_op)
        if debug_node_key in self._breakpoints:
          self._breakpoints.discard(debug_node_key)
      elif (state_change.state ==
            debug_service_pb2.EventReply.DebugOpStateChange.DISABLED):
        logging.info("Removing watchpoint or breakpoint: %s:%d:%s",
                     state_change.node_name, state_change.output_slot,
                     state_change.debug_op)
        if debug_node_key in self._breakpoints:
          self._breakpoints.discard(debug_node_key)
        else:
          logging.warn(
              "Attempting to remove a non-existent debug node key: %s",
              debug_node_key)
      new_state_change = event_reply.debug_op_state_changes.add()
      new_state_change.CopyFrom(state_change)
    return event_reply

  def _process_tensor_event_in_chunks(self, event, tensor_chunks):
    """Possibly reassemble event chunks.

    Due to gRPC's message size limit, a large tensor can be encapsulated in
    multiple Event proto chunks to be sent through the debugger stream. This
    method keeps track of the chunks that have arrived, reassembles all chunks
    corresponding to a tensor once they have all arrived, and returns the
    reassembled Event proto.

    Args:
      event: The single Event proto that has arrived.
      tensor_chunks: A dict used to keep track of the Event protos that have
        arrived but haven't been reassembled.

    Returns:
      If all Event protos corresponding to a tensor have arrived, returns the
      reassembled Event proto. Otherwise, returns `None`.
    """

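    # The debugger plugin metadata is a JSON object serialized into
    # `value.metadata.plugin_data.content`. Judging from the keys read below,
    # the metadata of a chunked tensor looks like (values illustrative):
    #   {"device": "/job:localhost/replica:0/task:0/cpu:0",
    #    "numChunks": 3, "chunkIndex": 0}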
    value = event.summary.value[0]
    debugger_plugin_metadata = json.loads(
        compat.as_text(value.metadata.plugin_data.content))
    device_name = debugger_plugin_metadata["device"]
    num_chunks = debugger_plugin_metadata["numChunks"]
    chunk_index = debugger_plugin_metadata["chunkIndex"]

    if num_chunks <= 1:
      return event

    debug_node_name = value.node_name
    timestamp = int(event.wall_time)
    tensor_key = "%s_%s_%d" % (device_name, debug_node_name, timestamp)

    if tensor_key not in tensor_chunks:
      tensor_chunks[tensor_key] = [None] * num_chunks

    chunks = tensor_chunks[tensor_key]
    if value.tensor.tensor_content:
      chunks[chunk_index] = value.tensor
    elif value.tensor.string_val:
      chunks[chunk_index] = event

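    # Note the asymmetry established above: numeric tensors are chunked at the
    # level of raw `tensor_content` bytes, so bytes-carrying TensorProtos are
    # stored; string tensors are chunked at the level of `string_val` entries,
    # so whole Event protos are stored for merging below.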
    if None not in chunks:
      if value.tensor.tensor_content:
        event.summary.value[0].tensor.tensor_content = b"".join(
            chunk.tensor_content for chunk in chunks)
        del tensor_chunks[tensor_key]
        return event
      elif value.tensor.string_val:
        merged_event = chunks[0]
        for chunk in chunks[1:]:
          merged_event.summary.value[0].tensor.string_val.extend(
              list(chunk.summary.value[0].tensor.string_val))
        return merged_event

  def _process_encoded_graph_def_in_chunks(self,
                                           event,
                                           graph_def_chunks):
    """Process an Event proto containing a chunk of encoded GraphDef.

    Args:
      event: The Event proto containing the chunk of encoded GraphDef.
      graph_def_chunks: A dict mapping keys for GraphDefs (i.e.,
        "<graph_def_hash>,<device_name>,<wall_time>") to lists of chunks of
        encoded GraphDefs.

    Returns:
      If all chunks of the GraphDef have arrived, returns the decoded GraphDef
      proto, the device name and the wall time. Otherwise, returns
      `(None, None, None)`.
    """
    graph_def = graph_pb2.GraphDef()
    index_bar_0 = event.graph_def.find(b"|")
    index_bar_1 = event.graph_def.find(b"|", index_bar_0 + 1)
    index_bar_2 = event.graph_def.find(b"|", index_bar_1 + 1)
    graph_def_hash_device_timestamp = event.graph_def[:index_bar_0]
    chunk_index = int(event.graph_def[index_bar_0 + 1 : index_bar_1])
    num_chunks = int(event.graph_def[index_bar_1 + 1 : index_bar_2])
    if graph_def_hash_device_timestamp not in graph_def_chunks:
      graph_def_chunks[graph_def_hash_device_timestamp] = [None] * num_chunks
    graph_def_chunks[graph_def_hash_device_timestamp][
        chunk_index] = event.graph_def[index_bar_2 + 1:]
    if all(graph_def_chunks[graph_def_hash_device_timestamp]):
      device_name = graph_def_hash_device_timestamp.split(b",")[1]
      wall_time = int(graph_def_hash_device_timestamp.split(b",")[2])
      graph_def.ParseFromString(
          b"".join(graph_def_chunks[graph_def_hash_device_timestamp]))
      del graph_def_chunks[graph_def_hash_device_timestamp]
      self._process_graph_def(graph_def)
      return graph_def, device_name, wall_time
    else:
      return None, None, None

  def _process_graph_def(self, graph_def):
    """Records gated-gRPC debug watches found in the nodes of a GraphDef."""
    for node_def in graph_def.node:
      if (debug_graphs.is_debug_node(node_def.name) and
          node_def.attr["gated_grpc"].b):
        node_name, output_slot, _, debug_op = (
            debug_graphs.parse_debug_node_name(node_def.name))
        self._gated_grpc_debug_watches.add(
            DebugWatch(node_name, output_slot, debug_op))

  def run_server(self, blocking=True):
    """Start running the server.

    Args:
      blocking: If `True`, block until `stop_server()` is invoked.

    Raises:
      ValueError: If server stop has already been requested, or if the server
        has already started running.
    """
    self._server_lock.acquire()
    try:
      if self._stop_requested:
        raise ValueError("Server has already stopped")
      if self._server_started:
        raise ValueError("Server has already started running")

      no_max_message_sizes = [("grpc.max_receive_message_length", -1),
                              ("grpc.max_send_message_length", -1)]
      self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10),
                                options=no_max_message_sizes)
      debug_service_pb2_grpc.add_EventListenerServicer_to_server(self,
                                                                 self.server)
      self.server.add_insecure_port("[::]:%d" % self._server_port)
      self.server.start()
      self._server_started = True
    finally:
      self._server_lock.release()

    if blocking:
      while not self._stop_requested:
        time.sleep(1.0)
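
  # A minimal usage sketch (assuming a concrete subclass of
  # `EventListenerBaseStreamHandler`, here the hypothetical
  # `_LoggingStreamHandler` sketched above; the port is illustrative):
  #
  #   servicer = EventListenerBaseServicer(6064, _LoggingStreamHandler)
  #   servicer.run_server(blocking=False)  # Serve on a background thread.
  #   ...                                  # Debugged program runs meanwhile.
  #   servicer.stop_server().wait()        # Block until fully stopped.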

  def stop_server(self, grace=1.0):
    """Request server stopping.

    Once stopped, the server cannot be stopped or started again. This method
    is non-blocking. Call `wait()` on the returned event to block until the
    server has completely stopped.

    Args:
      grace: Grace period in seconds to be used when calling `server.stop()`.

    Returns:
      A threading.Event that will be set when the server has completely
      stopped.

    Raises:
      ValueError: If server stop has already been requested, or if the server
        has not started running yet.
    """
    self._server_lock.acquire()
    try:
      if not self._server_started:
        raise ValueError("Server has not started running")
      if self._stop_requested:
        raise ValueError("Server has already stopped")

      self._stop_requested = True
      return self.server.stop(grace=grace)
    finally:
      self._server_lock.release()

  def request_watch(self, node_name, output_slot, debug_op, breakpoint=False):  # pylint: disable=redefined-builtin
    """Request enabling a debug tensor watchpoint or breakpoint.

    This will let the server send an `EventReply` to the client side
    (i.e., the debugged TensorFlow runtime process) to request adding a watch
    key (i.e., <node_name>:<output_slot>:<debug_op>) to the list of enabled
    watch keys. The list applies only to debug ops with the attribute
    gated_grpc=True.

    To disable the watch, use `request_unwatch()`.

    Args:
      node_name: (`str`) name of the node that the to-be-watched tensor belongs
        to, e.g., "hidden/Weights".
      output_slot: (`int`) output slot index of the tensor to watch.
      debug_op: (`str`) name of the debug op to enable. This should not include
        any attribute substrings.
      breakpoint: (`bool`) Iff `True`, the debug op will block and wait until it
        receives an `EventReply` response from the server. The `EventReply`
        proto may carry a TensorProto that modifies the value of the debug op's
        output tensor.
    """
    self._debug_ops_state_change_queue.put(
        _state_change(
            debug_service_pb2.EventReply.DebugOpStateChange.READ_WRITE
            if breakpoint
            else debug_service_pb2.EventReply.DebugOpStateChange.READ_ONLY,
            node_name, output_slot, debug_op))
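
  # For example, to request a breakpoint on the first output of a node named
  # "hidden/MatMul" watched by a DebugIdentity debug op (names illustrative):
  #
  #   servicer.request_watch("hidden/MatMul", 0, "DebugIdentity",
  #                          breakpoint=True)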

  def request_unwatch(self, node_name, output_slot, debug_op):
    """Request disabling a debug tensor watchpoint or breakpoint.

    This is the opposite of `request_watch()`.

    Args:
      node_name: (`str`) name of the node that the watched tensor belongs
        to, e.g., "hidden/Weights".
      output_slot: (`int`) output slot index of the watched tensor.
      debug_op: (`str`) name of the debug op to disable. This should not
        include any attribute substrings.
    """
    self._debug_ops_state_change_queue.put(
        _state_change(
            debug_service_pb2.EventReply.DebugOpStateChange.DISABLED, node_name,
            output_slot, debug_op))

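  # And to disable the watch or breakpoint enabled above (names illustrative):
  #
  #   servicer.request_unwatch("hidden/MatMul", 0, "DebugIdentity")
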
  @property
  def breakpoints(self):
    """Get a set of the currently-activated breakpoints.

    Returns:
      A `set` of 3-tuples: (node_name, output_slot, debug_op), e.g.,
        {("MatMul", 0, "DebugIdentity")}.
    """
    return self._breakpoints

  def gated_grpc_debug_watches(self):
    """Get the list of debug watches with attribute gated_grpc=True.

    Since the server learns about debug watches from the `GraphDef`s it
    receives from the debugged runtime, it can return only the debug watches
    it has seen so far.

    Returns:
      A `list` of `DebugWatch` `namedtuples` representing the debug watches
      with gated_grpc=True. Each `namedtuple` element has the attributes:
        `node_name` as a `str`,
        `output_slot` as an `int`,
        `debug_op` as a `str`.
    """
    return list(self._gated_grpc_debug_watches)

  def SendTracebacks(self, request, context):
    """Base implementation of the handling of SendTracebacks calls.

    The base implementation does nothing with the incoming request.
    Override in an implementation of the server if necessary.

    Args:
      request: A `CallTraceback` proto, containing information about the
        type (e.g., graph vs. eager execution) and source-code traceback of the
        call and (any) associated `tf.Graph`s.
      context: Server context.

    Returns:
      An `EventReply` proto.
    """
    return debug_service_pb2.EventReply()

  def SendSourceFiles(self, request, context):
    """Base implementation of the handling of SendSourceFiles calls.

    The base implementation does nothing with the incoming request.
    Override in an implementation of the server if necessary.

    Args:
      request: A `DebuggedSourceFiles` proto, containing the path, content, size
        and last-modified timestamp of source files.
      context: Server context.

    Returns:
      An `EventReply` proto.
    """
    return debug_service_pb2.EventReply()
493