xref: /aosp_15_r20/external/grpc-grpc/src/python/grpcio/grpc/_observability.py (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1# Copyright 2023 The gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from __future__ import annotations
16
17import abc
18import contextlib
19import logging
20import threading
21from typing import Any, Generator, Generic, List, Optional, TypeVar
22
23from grpc._cython import cygrpc as _cygrpc
24
_LOGGER = logging.getLogger(__name__)

# Stand-in for the grpc._channel module, used only in type annotations:
# _channel.py imports this module, so importing it here would be circular.
_channel = Any  # _channel.py imports this module.
# Opaque capsule types produced by an ObservabilityPlugin implementation.
ClientCallTracerCapsule = TypeVar("ClientCallTracerCapsule")
ServerCallTracerFactoryCapsule = TypeVar("ServerCallTracerFactoryCapsule")

# Guards _OBSERVABILITY_PLUGIN; held by get_plugin() for the duration of its
# `with` body and by set_plugin() while (un)registering a plugin.
_plugin_lock: threading.RLock = threading.RLock()
# The currently registered plugin, or None when none is registered.
_OBSERVABILITY_PLUGIN: Optional["ObservabilityPlugin"] = None
# Byte strings checked (as substrings) against the UTF-8 encoded method name
# in maybe_record_rpc_latency; matching RPCs have no latency recorded.
_SERVICES_TO_EXCLUDE: List[bytes] = [
    b"google.monitoring.v3.MetricService",
    b"google.devtools.cloudtrace.v2.TraceService",
]
37
38
class ObservabilityPlugin(
    Generic[ClientCallTracerCapsule, ServerCallTracerFactoryCapsule],
    metaclass=abc.ABCMeta,
):
    """Abstract base class for observability plugin.

    *This is a semi-private class that was intended for the exclusive use of
     the gRPC team.*

    The ClientCallTracerCapsule and ServerCallTracerFactoryCapsule created by
    this plugin should be injected into gRPC core using observability_init at
    the start of a program, before any channels/servers are built.

    Any future methods added to this interface cannot have the
    @abc.abstractmethod annotation: doing so would make every existing
    subclass that does not implement them impossible to instantiate.

    Attributes:
      _stats_enabled: A bool indicating whether stats (metrics) is enabled.
      _tracing_enabled: A bool indicating whether tracing is enabled.
    """

    # Class-level defaults; set_tracing()/set_stats() shadow them with
    # instance attributes.
    _tracing_enabled: bool = False
    _stats_enabled: bool = False

    @abc.abstractmethod
    def create_client_call_tracer(
        self, method_name: bytes, target: bytes
    ) -> ClientCallTracerCapsule:
        """Creates a ClientCallTracerCapsule.

        After the plugin is registered, if tracing or stats is enabled, this
        method will be called after a call was created; the ClientCallTracer
        created by this method will be saved to the call context.

        The ClientCallTracer is an object which implements `grpc_core::ClientCallTracer`
        interface and wrapped in a PyCapsule using `client_call_tracer` as name.

        Args:
          method_name: The method name of the call in byte format.
          target: The channel target of the call in byte format.

        Returns:
          A PyCapsule which stores a ClientCallTracer object.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def delete_client_call_tracer(
        self, client_call_tracer: ClientCallTracerCapsule
    ) -> None:
        """Deletes the ClientCallTracer stored in ClientCallTracerCapsule.

        After the plugin is registered, if tracing or stats is enabled, this
        method will be called at the end of the call to destroy the
        ClientCallTracer.

        The ClientCallTracer is an object which implements `grpc_core::ClientCallTracer`
        interface and wrapped in a PyCapsule using `client_call_tracer` as name.

        Args:
          client_call_tracer: A PyCapsule which stores a ClientCallTracer object.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def save_trace_context(
        self, trace_id: str, span_id: str, is_sampled: bool
    ) -> None:
        """Saves the trace_id and span_id related to the current span.

        After the plugin is registered, if tracing is enabled, this method
        will be called after the server finished sending the response.

        This method can be used to propagate census context.

        Args:
          trace_id: The identifier for the trace associated with the span as a
            32-character hexadecimal encoded string,
            e.g. 26ed0036f2eff2b7317bccce3e28d01f
          span_id: The identifier for the span as a 16-character hexadecimal encoded
            string. e.g. 113ec879e62583bc
          is_sampled: A bool indicating whether the span is sampled.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def create_server_call_tracer_factory(
        self,
    ) -> ServerCallTracerFactoryCapsule:
        """Creates a ServerCallTracerFactoryCapsule.

        After the plugin is registered, if tracing or stats is enabled, this
        method will be called by observability_init; the
        ServerCallTracerFactory created by this method will be registered to
        gRPC core.

        The ServerCallTracerFactory is an object which implements
        `grpc_core::ServerCallTracerFactory` interface and wrapped in a PyCapsule
        using `server_call_tracer_factory` as name.

        Returns:
          A PyCapsule which stores a ServerCallTracerFactory object.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def record_rpc_latency(
        self, method: str, target: str, rpc_latency: float, status_code: Any
    ) -> None:
        """Record the latency of the RPC.

        After the plugin is registered, if stats is enabled, this method will
        be called at the end of each RPC.

        Args:
          method: The fully-qualified name of the RPC method being invoked.
          target: The target name of the RPC method being invoked.
          rpc_latency: The latency for the RPC in milliseconds, equal to the
            time between when the client invokes the RPC and when the client
            receives the status. (The module's maybe_record_rpc_latency passes
            the elapsed time multiplied by 1000.)
          status_code: The final status of the RPC, taken from the RPC state's
            ``code`` attribute (presumably a grpc.StatusCode; confirm the exact
            form implementations receive).
        """
        raise NotImplementedError()

    def set_tracing(self, enable: bool) -> None:
        """Enable or disable tracing.

        Args:
          enable: A bool indicating whether tracing should be enabled.
        """
        self._tracing_enabled = enable

    def set_stats(self, enable: bool) -> None:
        """Enable or disable stats (metrics).

        Args:
          enable: A bool indicating whether stats should be enabled.
        """
        self._stats_enabled = enable

    @property
    def tracing_enabled(self) -> bool:
        """Whether tracing is currently enabled."""
        return self._tracing_enabled

    @property
    def stats_enabled(self) -> bool:
        """Whether stats (metrics) is currently enabled."""
        return self._stats_enabled

    @property
    def observability_enabled(self) -> bool:
        """True when either tracing or stats is enabled."""
        return self.tracing_enabled or self.stats_enabled
188
189
@contextlib.contextmanager
def get_plugin() -> Generator[Optional[ObservabilityPlugin], None, None]:
    """Context manager giving access to the registered ObservabilityPlugin.

    The module's plugin lock is held for the whole ``with`` body, so another
    thread cannot register or clear the plugin (via set_plugin) while the
    caller is using it.

    Yields:
      The ObservabilityPlugin currently registered with this module, or None
      when no plugin is registered at the time of entering the context.
    """
    _plugin_lock.acquire()
    try:
        yield _OBSERVABILITY_PLUGIN
    finally:
        # Release even if the caller's `with` body raised.
        _plugin_lock.release()
200
201
def set_plugin(observability_plugin: Optional[ObservabilityPlugin]) -> None:
    """Register, or clear, the module-wide ObservabilityPlugin.

    A falsy argument (e.g. None) always overwrites the stored plugin, which
    clears the registration. A plugin is only accepted when no plugin is
    currently registered.

    Args:
      observability_plugin: The ObservabilityPlugin to store, or None to
        clear the currently stored one.

    Raises:
      ValueError: If a plugin is passed while another ObservabilityPlugin is
        already registered.
    """
    global _OBSERVABILITY_PLUGIN  # pylint: disable=global-statement
    with _plugin_lock:
        if not observability_plugin:
            # Clearing never conflicts with an existing registration.
            _OBSERVABILITY_PLUGIN = observability_plugin
            return
        if _OBSERVABILITY_PLUGIN:
            raise ValueError("observability_plugin was already set!")
        _OBSERVABILITY_PLUGIN = observability_plugin
217
218
def observability_init(observability_plugin: ObservabilityPlugin) -> None:
    """Set up observability using the given ObservabilityPlugin.

    This must be called at the start of a program, before any
    channels/servers are built. The plugin is first registered with this
    module, then handed to the Cython layer to install the server call
    tracer factory.

    Args:
      observability_plugin: The ObservabilityPlugin to use.

    Raises:
      ValueError: If an ObservabilityPlugin was already registered at the
        time of calling this method.
    """
    # A registration conflict propagates to the caller...
    set_plugin(observability_plugin)
    # ...but a failure installing the factory is only logged; the plugin
    # stays registered in that case.
    try:
        install_factory = _cygrpc.set_server_call_tracer_factory
        install_factory(observability_plugin)
    except Exception:  # pylint:disable=broad-except
        _LOGGER.exception("Failed to set server call tracer factory!")
237
238
def observability_deinit() -> None:
    """Tear down the state set up by observability_init.

    Unregisters the ObservabilityPlugin from this module and then clears the
    ServerCallTracerFactory in the Cython layer. It has to be called after
    exiting the observability context so that observability can be
    initialized again.
    """
    # Python-side registration is dropped first, then the core-side factory.
    set_plugin(observability_plugin=None)
    _cygrpc.clear_server_call_tracer_factory()
248
249
def delete_call_tracer(client_call_tracer_capsule: Any) -> None:
    """Destroy the ClientCallTracer held in a ClientCallTracerCapsule.

    Intended to be called at the end of a call. The capsule is forwarded to
    the registered plugin's delete_client_call_tracer, but only when a plugin
    is registered and has tracing or stats enabled; otherwise nothing happens.

    The ClientCallTracer is an object which implements `grpc_core::ClientCallTracer`
    interface and wrapped in a PyCapsule using `client_call_tracer` as the name.

    Args:
      client_call_tracer_capsule: A PyCapsule which stores a ClientCallTracer object.
    """
    with get_plugin() as plugin:
        # NOTE(review): a capsule created while observability was enabled is
        # not deleted here if observability was disabled (or the plugin was
        # cleared) before the call ended — confirm callers handle that case.
        if plugin and plugin.observability_enabled:
            plugin.delete_client_call_tracer(client_call_tracer_capsule)
265
266
def maybe_record_rpc_latency(state: "_channel._RPCState") -> None:
    """Record the latency of the RPC, if the plugin is registered and stats is enabled.

    This method will be called at the end of each RPC. RPCs whose UTF-8
    encoded method name contains any entry of _SERVICES_TO_EXCLUDE are
    skipped. The latency handed to the plugin is in milliseconds.

    Args:
      state: a grpc._channel._RPCState object which contains the stats related to the
        RPC. Its ``method``, ``target``, ``code``, ``rpc_start_time`` and
        ``rpc_end_time`` attributes are read.
    """
    # TODO(xuanwn): use channel args to exclude those metrics.
    # Encode the method name once instead of once per excluded service.
    encoded_method = state.method.encode("utf8")
    if any(service in encoded_method for service in _SERVICES_TO_EXCLUDE):
        return
    with get_plugin() as plugin:
        if not (plugin and plugin.stats_enabled):
            return
        # Assumes rpc_start_time/rpc_end_time are in seconds (as the `_s`
        # suffix suggests) — TODO confirm against _channel._RPCState.
        rpc_latency_s = state.rpc_end_time - state.rpc_start_time
        rpc_latency_ms = rpc_latency_s * 1000
        plugin.record_rpc_latency(
            state.method, state.target, rpc_latency_ms, state.code
        )
288