# Copyright 2023 The gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import abc
import contextlib
import logging
import threading
from typing import Any, Generator, Generic, List, Optional, TypeVar

from grpc._cython import cygrpc as _cygrpc

_LOGGER = logging.getLogger(__name__)

_channel = Any  # _channel.py imports this module.
ClientCallTracerCapsule = TypeVar("ClientCallTracerCapsule")
ServerCallTracerFactoryCapsule = TypeVar("ServerCallTracerFactoryCapsule")

# Re-entrant lock guarding every read and write of _OBSERVABILITY_PLUGIN.
_plugin_lock: threading.RLock = threading.RLock()
_OBSERVABILITY_PLUGIN: Optional["ObservabilityPlugin"] = None
# RPCs whose method name contains one of these service names are not reported
# by maybe_record_rpc_latency.
_SERVICES_TO_EXCLUDE: List[bytes] = [
    b"google.monitoring.v3.MetricService",
    b"google.devtools.cloudtrace.v2.TraceService",
]


class ObservabilityPlugin(
    Generic[ClientCallTracerCapsule, ServerCallTracerFactoryCapsule],
    metaclass=abc.ABCMeta,
):
    """Abstract base class for observability plugin.

    *This is a semi-private class that was intended for the exclusive use of
     the gRPC team.*

    The ClientCallTracerCapsule and ServerCallTracerFactoryCapsule created by
    this plugin should be injected into gRPC core using observability_init at
    the start of a program, before any channels/servers are built.

    Any future methods added to this interface cannot have the
    @abc.abstractmethod annotation.

    Attributes:
      _tracing_enabled: A bool indicating whether tracing is enabled.
      _stats_enabled: A bool indicating whether stats (metrics) is enabled.
    """

    _tracing_enabled: bool = False
    _stats_enabled: bool = False

    @abc.abstractmethod
    def create_client_call_tracer(
        self, method_name: bytes, target: bytes
    ) -> ClientCallTracerCapsule:
        """Creates a ClientCallTracerCapsule.

        After the plugin is registered, if tracing or stats is enabled, this
        method will be called after a call is created; the ClientCallTracer
        created by this method will be saved to the call context.

        The ClientCallTracer is an object which implements the
        `grpc_core::ClientCallTracer` interface and is wrapped in a PyCapsule
        using `client_call_tracer` as the name.

        Args:
          method_name: The method name of the call in byte format.
          target: The channel target of the call in byte format.

        Returns:
          A PyCapsule which stores a ClientCallTracer object.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def delete_client_call_tracer(
        self, client_call_tracer: ClientCallTracerCapsule
    ) -> None:
        """Deletes the ClientCallTracer stored in ClientCallTracerCapsule.

        After the plugin is registered, if tracing or stats is enabled, this
        method will be called at the end of the call to destroy the
        ClientCallTracer.

        The ClientCallTracer is an object which implements the
        `grpc_core::ClientCallTracer` interface and is wrapped in a PyCapsule
        using `client_call_tracer` as the name.

        Args:
          client_call_tracer: A PyCapsule which stores a ClientCallTracer
            object.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def save_trace_context(
        self, trace_id: str, span_id: str, is_sampled: bool
    ) -> None:
        """Saves the trace_id and span_id related to the current span.

        After the plugin is registered, if tracing is enabled, this method
        will be called after the server finishes sending the response.

        This method can be used to propagate census context.

        Args:
          trace_id: The identifier for the trace associated with the span as a
            32-character hexadecimal encoded string,
            e.g. 26ed0036f2eff2b7317bccce3e28d01f
          span_id: The identifier for the span as a 16-character hexadecimal
            encoded string, e.g. 113ec879e62583bc
          is_sampled: A bool indicating whether the span is sampled.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def create_server_call_tracer_factory(
        self,
    ) -> ServerCallTracerFactoryCapsule:
        """Creates a ServerCallTracerFactoryCapsule.

        After the plugin is registered, if tracing or stats is enabled, this
        method will be called by observability_init; the
        ServerCallTracerFactory created by this method will be registered to
        gRPC core.

        The ServerCallTracerFactory is an object which implements the
        `grpc_core::ServerCallTracerFactory` interface and is wrapped in a
        PyCapsule using `server_call_tracer_factory` as the name.

        Returns:
          A PyCapsule which stores a ServerCallTracerFactory object.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def record_rpc_latency(
        self, method: str, target: str, rpc_latency: float, status_code: Any
    ) -> None:
        """Record the latency of the RPC.

        After the plugin is registered, if stats is enabled, this method will
        be called at the end of each RPC.

        Args:
          method: The fully-qualified name of the RPC method being invoked.
          target: The target name of the RPC method being invoked.
          rpc_latency: The latency for the RPC in milliseconds (this module's
            maybe_record_rpc_latency multiplies the elapsed time by 1000
            before passing it here); equals the time between when the client
            invokes the RPC and when the client receives the status.
          status_code: An element of grpc.StatusCode in string format
            representing the final status for the RPC.
        """
        raise NotImplementedError()

    def set_tracing(self, enable: bool) -> None:
        """Enable or disable tracing.

        Args:
          enable: A bool indicating whether tracing should be enabled.
        """
        self._tracing_enabled = enable

    def set_stats(self, enable: bool) -> None:
        """Enable or disable stats (metrics).

        Args:
          enable: A bool indicating whether stats should be enabled.
        """
        self._stats_enabled = enable

    @property
    def tracing_enabled(self) -> bool:
        """Whether tracing is currently enabled on this plugin."""
        return self._tracing_enabled

    @property
    def stats_enabled(self) -> bool:
        """Whether stats (metrics) is currently enabled on this plugin."""
        return self._stats_enabled

    @property
    def observability_enabled(self) -> bool:
        """Whether tracing or stats (or both) is enabled on this plugin."""
        return self.tracing_enabled or self.stats_enabled


@contextlib.contextmanager
def get_plugin() -> Generator[Optional[ObservabilityPlugin], None, None]:
    """Get the ObservabilityPlugin in _observability module.

    This is a context manager: the module-level plugin lock is held for the
    whole duration of the `with` body, so the plugin cannot be replaced by
    another thread while the caller is using it.

    Yields:
      The ObservabilityPlugin currently registered with the _observability
      module. Or None if no plugin exists at the time of calling this method.
    """
    with _plugin_lock:
        yield _OBSERVABILITY_PLUGIN


def set_plugin(observability_plugin: Optional[ObservabilityPlugin]) -> None:
    """Save ObservabilityPlugin to _observability module.

    Passing None clears the currently registered plugin.

    Args:
      observability_plugin: The ObservabilityPlugin to save.

    Raises:
      ValueError: If an ObservabilityPlugin was already registered at the
        time of calling this method.
    """
    global _OBSERVABILITY_PLUGIN  # pylint: disable=global-statement
    with _plugin_lock:
        # Only a replace-while-set is rejected; clearing is always allowed.
        if observability_plugin and _OBSERVABILITY_PLUGIN:
            raise ValueError("observability_plugin was already set!")
        _OBSERVABILITY_PLUGIN = observability_plugin


def observability_init(observability_plugin: ObservabilityPlugin) -> None:
    """Initialize observability with provided ObservabilityPlugin.

    This method must be called at the start of a program, before any
    channels/servers are built.

    Args:
      observability_plugin: The ObservabilityPlugin to use.

    Raises:
      ValueError: If an ObservabilityPlugin was already registered at the
        time of calling this method.
    """
    set_plugin(observability_plugin)
    try:
        _cygrpc.set_server_call_tracer_factory(observability_plugin)
    except Exception:  # pylint:disable=broad-except
        # Best effort: a failure here is logged, not raised, and the plugin
        # stays registered.
        _LOGGER.exception("Failed to set server call tracer factory!")


def observability_deinit() -> None:
    """Clear the observability context, including ObservabilityPlugin and
    ServerCallTracerFactory

    This method must be called after exiting the observability context so
    that it's possible to re-initialize again.
    """
    set_plugin(None)
    _cygrpc.clear_server_call_tracer_factory()


def delete_call_tracer(client_call_tracer_capsule: Any) -> None:
    """Deletes the ClientCallTracer stored in ClientCallTracerCapsule.

    This method will be called at the end of the call to destroy the
    ClientCallTracer. It is a no-op when no plugin is registered or when
    neither tracing nor stats is enabled on the registered plugin.

    The ClientCallTracer is an object which implements the
    `grpc_core::ClientCallTracer` interface and is wrapped in a PyCapsule
    using `client_call_tracer` as the name.

    Args:
      client_call_tracer_capsule: A PyCapsule which stores a ClientCallTracer
        object.
    """
    with get_plugin() as plugin:
        if not (plugin and plugin.observability_enabled):
            return
        plugin.delete_client_call_tracer(client_call_tracer_capsule)


def maybe_record_rpc_latency(state: "_channel._RPCState") -> None:
    """Record the latency of the RPC, if the plugin is registered and stats is
    enabled.

    This method will be called at the end of each RPC. RPCs whose method name
    contains one of the services in _SERVICES_TO_EXCLUDE are never recorded.

    Args:
      state: a grpc._channel._RPCState object which contains the stats related
        to the RPC.
    """
    # TODO(xuanwn): use channel args to exclude those metrics.
    # Encode once, outside the membership test; note this is a substring
    # match against the method name, not a prefix match.
    method_bytes = state.method.encode("utf8")
    if any(service in method_bytes for service in _SERVICES_TO_EXCLUDE):
        return
    with get_plugin() as plugin:
        if not (plugin and plugin.stats_enabled):
            return
        rpc_latency_s = state.rpc_end_time - state.rpc_start_time
        rpc_latency_ms = rpc_latency_s * 1000
        plugin.record_rpc_latency(
            state.method, state.target, rpc_latency_ms, state.code
        )