# Copyright 2023 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import codecs
from typing import Optional

from libcpp.cast cimport static_cast

from grpc import _observability


# Capsule names used when exchanging raw pointers with the observability
# plugin through PyCapsule objects.
cdef const char* CLIENT_CALL_TRACER = "client_call_tracer"
cdef const char* SERVER_CALL_TRACER_FACTORY = "server_call_tracer_factory"


def set_server_call_tracer_factory(object observability_plugin) -> None:
    """Register the plugin's server call tracer factory globally.

    Args:
      observability_plugin: An object exposing
        ``create_server_call_tracer_factory()``. The value it returns is
        unwrapped as a PyCapsule named ``"server_call_tracer_factory"``.
    """
    capsule = observability_plugin.create_server_call_tracer_factory()
    # The lookup is keyed by SERVER_CALL_TRACER_FACTORY; per the CPython
    # capsule API a name mismatch is reported as an error rather than
    # yielding a pointer.
    capsule_ptr = cpython.PyCapsule_GetPointer(capsule, SERVER_CALL_TRACER_FACTORY)
    _register_server_call_tracer_factory(capsule_ptr)


def clear_server_call_tracer_factory() -> None:
    """Reset the global server call tracer factory by registering NULL."""
    _register_server_call_tracer_factory(NULL)


def maybe_save_server_trace_context(RequestCallEvent event) -> None:
    """Hand the server call's trace context to the observability plugin.

    Returns without doing anything when the active plugin is falsy, when its
    ``tracing_enabled`` flag is off, or when no call tracer is stored in the
    call's context.

    Args:
      event: The request-call event whose ``call.c_call`` is inspected for a
        server call tracer.
    """
    cdef ServerCallTracer* server_call_tracer
    with _observability.get_plugin() as plugin:
        if not (plugin and plugin.tracing_enabled):
            return
        server_call_tracer = static_cast['ServerCallTracer*'](_get_call_tracer(event.call.c_call))
        if server_call_tracer == NULL:
            # Nothing was stored in the call context; dereferencing the
            # pointer below would crash the process, so skip saving instead.
            return
        # TraceId() and SpanId() return hex strings; hex-decode them to bytes
        # and then convert the bytes to str.
        trace_id = _decode(codecs.decode(server_call_tracer.TraceId(), 'hex_codec'))
        span_id = _decode(codecs.decode(server_call_tracer.SpanId(), 'hex_codec'))
        is_sampled = server_call_tracer.IsSampled()
        plugin.save_trace_context(trace_id, span_id, is_sampled)


# Stores `capsule_ptr`, reinterpreted as a ClientCallTracer*, in the call's
# context under GRPC_CONTEXT_CALL_TRACER_ANNOTATION_INTERFACE.
# NOTE(review): the trailing NULL argument is presumably the destroy callback,
# which would leave ownership of the tracer with the caller - confirm against
# the grpc_call_context_set declaration.
cdef void _set_call_tracer(grpc_call* call, void* capsule_ptr):
    cdef ClientCallTracer* call_tracer = <ClientCallTracer*>capsule_ptr
    grpc_call_context_set(call, GRPC_CONTEXT_CALL_TRACER_ANNOTATION_INTERFACE, call_tracer, NULL)


# Returns whatever pointer is stored in the call's context under
# GRPC_CONTEXT_CALL_TRACER_ANNOTATION_INTERFACE, untyped; callers cast it to
# the tracer type they expect.
cdef void* _get_call_tracer(grpc_call* call):
    cdef void* call_tracer = grpc_call_context_get(call, GRPC_CONTEXT_CALL_TRACER_ANNOTATION_INTERFACE)
    return call_tracer


# Casts `capsule_ptr` to a ServerCallTracerFactory* and passes it to
# ServerCallTracerFactory.RegisterGlobal. A NULL pointer is forwarded as-is,
# which is how clear_server_call_tracer_factory() resets the registration.
cdef void _register_server_call_tracer_factory(void* capsule_ptr):
    cdef ServerCallTracerFactory* call_tracer_factory = <ServerCallTracerFactory*>capsule_ptr
    ServerCallTracerFactory.RegisterGlobal(call_tracer_factory)