xref: /aosp_15_r20/external/tensorflow/tensorflow/python/profiler/trace.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Trace allows the profiler to trace Python events."""
16
17import functools
18
19from tensorflow.python.profiler.internal import _pywrap_traceme
20from tensorflow.python.util.tf_export import tf_export
21
# Whether the profiler is currently collecting Python trace events.
# This variable is modified by PythonHooks::Start/Stop() in C++. Such
# arrangement will reduce the number of calls through pybind11: Python code
# (e.g. Trace.__init__ and trace_wrapper) reads this plain module global
# instead of asking the C++ side whether tracing is active.
enabled: bool = False
25
26
@tf_export('profiler.experimental.Trace', v1=[])
class Trace(object):
  """Context manager that generates a trace event in the profiler.

  A trace event starts when the `Trace` object is created (in the usual
  `with Trace(...):` form, that is when the context is entered), and it stops
  and saves its result to the profiler when the context exits. Open the
  TensorBoard Profile tab and choose the trace viewer to view the trace event
  in the timeline.

  Trace events are created only when the profiler is enabled. Whether the
  profiler is enabled is checked once, when the `Trace` object is created.
  More information on how to use the profiler can be found at
  https://tensorflow.org/guide/profiler

  Example usage:
  ```python
  tf.profiler.experimental.start('logdir')
  for step in range(num_steps):
    # Creates a trace event for each training step with the step number.
    with tf.profiler.experimental.Trace("Train", step_num=step, _r=1):
      train_fn()
  tf.profiler.experimental.stop()
  ```
  """

  def __init__(self, name, **kwargs):
    """Creates a trace event in the profiler.

    Args:
      name: The name of the trace event.
      **kwargs: Keyword arguments added to the trace event.
                Both the key and value are of types that
                can be converted to strings, which will be
                interpreted by the profiler according to the
                traceme name.

      Example usage:

      ```python

        tf.profiler.experimental.start('logdir')
        for step in range(num_steps):
          # Creates a trace event for each training step with the
          # step number.
          with tf.profiler.experimental.Trace("Train", step_num=step):
            train_fn()
        tf.profiler.experimental.stop()

      ```
      The example above uses the keyword argument "step_num" to specify the
      training step being traced.
    """
    if enabled:
      # Creating _pywrap_traceme.TraceMe starts the clock.
      self._traceme = _pywrap_traceme.TraceMe(name, **kwargs)
    else:
      # Profiler is off: None makes set_metadata() and __exit__ no-ops.
      self._traceme = None

  def __enter__(self):
    """Returns this `Trace` so metadata can be attached via `as`."""
    # The clock was already started in __init__; starting it here instead
    # would require an extra Python->C++ call.
    return self

  def set_metadata(self, **kwargs):
    """Sets metadata in this trace event.

    This method enables setting metadata in a trace event after it is
    created. It does nothing if the profiler was disabled when this `Trace`
    was created, or if no metadata is passed.

    Args:
      **kwargs: metadata in key-value pairs.

    Example usage:

    ```python

      def call(function):
        with tf.profiler.experimental.Trace("call",
             function_name=function.name) as tm:
          binary, in_cache = jit_compile(function)
          tm.set_metadata(in_cache=in_cache)
          execute(binary)

    ```
    In this example, we want to trace how much time is spent on
    calling a function, which includes compilation and execution.
    The compilation can be either getting a cached copy of the
    binary or actually generating the binary, which is indicated
    by the boolean "in_cache" returned by jit_compile(). We need
    to use set_metadata() to pass in_cache because we did not know
    the in_cache value when the trace was created (and we cannot
    create the trace after jit_compile(), because we want
    to measure the entire duration of call()).
    """
    if self._traceme is not None and kwargs:
      self._traceme.SetMetadata(**kwargs)

  def __exit__(self, exc_type, exc_val, exc_tb):
    """Stops the trace event, if one was started.

    Returns None, so any exception raised inside the context propagates.
    """
    if self._traceme is not None:
      self._traceme.Stop()
124
125
def trace_wrapper(trace_name, **trace_kwargs):
  """Decorator alternative to `with Trace(): ...`.  It's faster.

  The wrapped function checks the module-level `enabled` flag itself. It only
  constructs a `Trace` when the profiler is enabled, so the disabled path costs
  a single global lookup per call.

  Args:
    trace_name: The name of the trace event, or a callable to be traced, in
      which case the name is inferred from qualname or name of the callable.
    **trace_kwargs: Keyword arguments added to the trace event. Both the key and
      value are of types that can be converted to strings, which will be
      interpreted by the profiler according to the traceme name. They are
      applied whether `trace_name` is a string or a callable.

  Returns:
    A decorator that can wrap a function and apply `Trace` scope if needed,
    or a decorated function if used as a decorator directly.

  Example usage:
    ```python

    @trace_wrapper('trace_name')
    def func(x, y, z):
      pass  # code to execute and apply `Trace` if needed.

    # Equivalent to
    # with Trace('trace_name'):
    #   func(1, 2, 3)
    func(1, 2, 3)
    ```

  or
    ```python

    @trace_wrapper
    def func(x, y, z):
      pass  # code to execute and apply `Trace` if needed.

    # Equivalent to
    # with Trace(func.__qualname__):
    #   func(1, 2, 3)
    func(1, 2, 3)
    ```

  """

  if callable(trace_name):
    func = trace_name
    # Prefer __qualname__; fall back to __name__ when it is missing or empty.
    name = (getattr(func, '__qualname__', None) or
            getattr(func, '__name__', 'unknown function'))
    # Forward trace_kwargs so `trace_wrapper(func, key=value)` records them
    # just like `trace_wrapper('name', key=value)(func)` does.
    return trace_wrapper(name, **trace_kwargs)(func)

  def inner_wrapper(func):

    @functools.wraps(func)
    def wrapped(*args, **kwargs):
      # Checking the flag here avoids creating a Trace when profiling is off.
      if enabled:
        with Trace(trace_name, **trace_kwargs):
          return func(*args, **kwargs)
      return func(*args, **kwargs)

    return wrapped

  return inner_wrapper
188