1# Copyright 2023 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15cimport cpython
16from cython.operator cimport dereference
17
18import enum
19import functools
20import logging
21import os
22from threading import Thread
23from typing import List, Mapping, Tuple, Union
24
25from grpc_observability import _observability
26
# Time, in seconds, that the export thread waits for the next batch of census
# data before it wakes up anyway. Can be overridden through the
# GRPC_PYTHON_CENSUS_EXPORT_BATCH_INTERVAL_SECS environment variable.
# TODO(xuanwn): change interval to a more appropriate number
CENSUS_EXPORT_BATCH_INTERVAL_SECS = float(os.environ.get('GRPC_PYTHON_CENSUS_EXPORT_BATCH_INTERVAL_SECS', 0.5))
# Timeout, in seconds, passed to Thread.join() when the export thread is shut
# down. Can be overridden through the GRPC_PYTHON_CENSUS_EXPORT_THREAD_TIMEOUT
# environment variable.
GRPC_PYTHON_CENSUS_EXPORT_THREAD_TIMEOUT = float(os.environ.get('GRPC_PYTHON_CENSUS_EXPORT_THREAD_TIMEOUT', 10))
# Names given to the PyCapsule objects created in this module; the same name
# must be supplied again to validate a capsule or read its pointer.
cdef const char* CLIENT_CALL_TRACER = "client_call_tracer"
cdef const char* SERVER_CALL_TRACER_FACTORY = "server_call_tracer_factory"
# Flag polled by the export thread; set to True to request that it stops.
cdef bint GLOBAL_SHUTDOWN_EXPORT_THREAD = False
# The threading.Thread that runs _export_census_data.
cdef object GLOBAL_EXPORT_THREAD

_LOGGER = logging.getLogger(__name__)
37
38
class _CyMetricsName:
  """Python-visible holder for the Cython-level metric name constants.

  Each attribute stores the value of the corresponding kRpc*MeasureName
  constant. The values are used to build the MetricsName enum.

  NOTE(review): the attribute names CY_CLIENT_SNET_MESSSAGES_PER_RPC and
  CY_CLIENT_SEND_BYTES_PER_RPC look like misspellings of "SENT_MESSAGES" and
  "SENT_BYTES"; they are kept because other code refers to them by name.
  """
  CY_CLIENT_API_LATENCY = kRpcClientApiLatencyMeasureName
  CY_CLIENT_SNET_MESSSAGES_PER_RPC = kRpcClientSentMessagesPerRpcMeasureName
  CY_CLIENT_SEND_BYTES_PER_RPC = kRpcClientSentBytesPerRpcMeasureName
  CY_CLIENT_RECEIVED_MESSAGES_PER_RPC = kRpcClientReceivedMessagesPerRpcMeasureName
  CY_CLIENT_RECEIVED_BYTES_PER_RPC = kRpcClientReceivedBytesPerRpcMeasureName
  CY_CLIENT_ROUNDTRIP_LATENCY = kRpcClientRoundtripLatencyMeasureName
  CY_CLIENT_COMPLETED_RPC = kRpcClientCompletedRpcMeasureName
  CY_CLIENT_SERVER_LATENCY = kRpcClientServerLatencyMeasureName
  CY_CLIENT_STARTED_RPCS = kRpcClientStartedRpcsMeasureName
  CY_CLIENT_RETRIES_PER_CALL = kRpcClientRetriesPerCallMeasureName
  CY_CLIENT_TRANSPARENT_RETRIES_PER_CALL = kRpcClientTransparentRetriesPerCallMeasureName
  CY_CLIENT_RETRY_DELAY_PER_CALL = kRpcClientRetryDelayPerCallMeasureName
  CY_CLIENT_TRANSPORT_LATENCY = kRpcClientTransportLatencyMeasureName
  CY_SERVER_SENT_MESSAGES_PER_RPC = kRpcServerSentMessagesPerRpcMeasureName
  CY_SERVER_SENT_BYTES_PER_RPC = kRpcServerSentBytesPerRpcMeasureName
  CY_SERVER_RECEIVED_MESSAGES_PER_RPC = kRpcServerReceivedMessagesPerRpcMeasureName
  CY_SERVER_RECEIVED_BYTES_PER_RPC = kRpcServerReceivedBytesPerRpcMeasureName
  CY_SERVER_SERVER_LATENCY = kRpcServerServerLatencyMeasureName
  CY_SERVER_COMPLETED_RPC = kRpcServerCompletedRpcMeasureName
  CY_SERVER_STARTED_RPCS = kRpcServerStartedRpcsMeasureName
60
@enum.unique
class MetricsName(enum.Enum):
  """Enum of the metric names exposed to Python.

  Each member's value is the matching attribute of _CyMetricsName, so a
  member can be looked up from the Cython-level metric value.

  NOTE(review): _CyMetricsName.CY_CLIENT_TRANSPORT_LATENCY has no member
  here; confirm whether that omission is intentional.
  """
  CLIENT_STARTED_RPCS = _CyMetricsName.CY_CLIENT_STARTED_RPCS
  CLIENT_API_LATENCY = _CyMetricsName.CY_CLIENT_API_LATENCY
  CLIENT_SNET_MESSSAGES_PER_RPC = _CyMetricsName.CY_CLIENT_SNET_MESSSAGES_PER_RPC
  CLIENT_SEND_BYTES_PER_RPC = _CyMetricsName.CY_CLIENT_SEND_BYTES_PER_RPC
  CLIENT_RECEIVED_MESSAGES_PER_RPC = _CyMetricsName.CY_CLIENT_RECEIVED_MESSAGES_PER_RPC
  CLIENT_RECEIVED_BYTES_PER_RPC = _CyMetricsName.CY_CLIENT_RECEIVED_BYTES_PER_RPC
  CLIENT_ROUNDTRIP_LATENCY = _CyMetricsName.CY_CLIENT_ROUNDTRIP_LATENCY
  CLIENT_COMPLETED_RPC = _CyMetricsName.CY_CLIENT_COMPLETED_RPC
  CLIENT_SERVER_LATENCY = _CyMetricsName.CY_CLIENT_SERVER_LATENCY
  CLIENT_RETRIES_PER_CALL = _CyMetricsName.CY_CLIENT_RETRIES_PER_CALL
  CLIENT_TRANSPARENT_RETRIES_PER_CALL = _CyMetricsName.CY_CLIENT_TRANSPARENT_RETRIES_PER_CALL
  CLIENT_RETRY_DELAY_PER_CALL = _CyMetricsName.CY_CLIENT_RETRY_DELAY_PER_CALL
  SERVER_SENT_MESSAGES_PER_RPC = _CyMetricsName.CY_SERVER_SENT_MESSAGES_PER_RPC
  SERVER_SENT_BYTES_PER_RPC = _CyMetricsName.CY_SERVER_SENT_BYTES_PER_RPC
  SERVER_RECEIVED_MESSAGES_PER_RPC = _CyMetricsName.CY_SERVER_RECEIVED_MESSAGES_PER_RPC
  SERVER_RECEIVED_BYTES_PER_RPC = _CyMetricsName.CY_SERVER_RECEIVED_BYTES_PER_RPC
  SERVER_SERVER_LATENCY = _CyMetricsName.CY_SERVER_SERVER_LATENCY
  SERVER_COMPLETED_RPC = _CyMetricsName.CY_SERVER_COMPLETED_RPC
  SERVER_STARTED_RPCS = _CyMetricsName.CY_SERVER_STARTED_RPCS
82
# Reverse lookup table: Cython-level metric value -> MetricsName member.
# Delay map creation due to circular dependencies
_CY_METRICS_NAME_TO_PY_METRICS_NAME_MAPPING = {x.value: x for x in MetricsName}
85
def cyobservability_init(object exporter) -> None:
  """Initialize native observability and start the export thread.

  Args:
    exporter: The _observability.Exporter handed to the export thread.
  """
  exporter: _observability.Exporter

  NativeObservabilityInit()
  _start_exporting_thread(exporter)
91
92
def _start_exporting_thread(object exporter) -> None:
  """Create and start the thread that runs _export_census_data.

  The thread object is stored in GLOBAL_EXPORT_THREAD so that it can be joined
  on shutdown.

  Args:
    exporter: The _observability.Exporter passed to _export_census_data.
  """
  exporter: _observability.Exporter

  global GLOBAL_EXPORT_THREAD
  global GLOBAL_SHUTDOWN_EXPORT_THREAD
  # Clear any earlier shutdown request before the new thread starts polling it.
  GLOBAL_SHUTDOWN_EXPORT_THREAD = False
  # TODO(xuanwn): Change it to daemon thread.
  GLOBAL_EXPORT_THREAD = Thread(target=_export_census_data, args=(exporter,))
  GLOBAL_EXPORT_THREAD.start()
102
def activate_config(object py_config) -> None:
  """Enable census tracing and/or stats as requested by `py_config`.

  Args:
    py_config: An "_observability_config.GcpObservabilityConfig". Its
      `tracing_enabled`, `sampling_rate` and `stats_enabled` attributes are
      read.
  """
  if py_config.tracing_enabled:
    EnablePythonCensusTracing(True)
    # Save sampling rate to global sampler.
    ProbabilitySampler.Get().SetThreshold(py_config.sampling_rate)

  if py_config.stats_enabled:
    EnablePythonCensusStats(True)
113
def activate_stats() -> None:
  """Enable Python census stats collection."""
  EnablePythonCensusStats(True)
116
117
def create_client_call_tracer(bytes method_name, bytes target, bytes trace_id,
                              bytes parent_span_id=b'') -> cpython.PyObject:
  """Create a ClientCallTracer and save to PyCapsule.

  The capsule is named CLIENT_CALL_TRACER and is created without a destructor
  (NULL), so the tracer is not freed when the capsule is collected.

  Args:
    method_name: Method name, as bytes.
    target: Target, as bytes.
    trace_id: Trace id, as bytes.
    parent_span_id: Parent span id, as bytes; defaults to b''.

  Returns: A grpc_observability._observability.ClientCallTracerCapsule object.
  """
  cdef char* method_ptr = cpython.PyBytes_AsString(method_name)
  cdef char* target_ptr = cpython.PyBytes_AsString(target)
  cdef char* trace_id_ptr = cpython.PyBytes_AsString(trace_id)
  cdef char* parent_span_id_ptr = cpython.PyBytes_AsString(parent_span_id)

  cdef void* tracer = CreateClientCallTracer(method_ptr, target_ptr,
                                             trace_id_ptr, parent_span_id_ptr)
  return cpython.PyCapsule_New(tracer, CLIENT_CALL_TRACER, NULL)
132
133
def create_server_call_tracer_factory_capsule() -> cpython.PyObject:
  """Create a ServerCallTracerFactory and save to PyCapsule.

  The capsule is named SERVER_CALL_TRACER_FACTORY and is created without a
  destructor (NULL).

  Returns: A grpc_observability._observability.ServerCallTracerFactoryCapsule object.
  """
  cdef void* factory = CreateServerCallTracerFactory()
  return cpython.PyCapsule_New(factory, SERVER_CALL_TRACER_FACTORY, NULL)
142
143
def delete_client_call_tracer(object client_call_tracer) -> None:
  """Delete the ClientCallTracer held by a capsule.

  Does nothing when `client_call_tracer` is not a valid capsule named
  CLIENT_CALL_TRACER.

  Args:
    client_call_tracer: A grpc._observability.ClientCallTracerCapsule.
  """
  if not cpython.PyCapsule_IsValid(client_call_tracer, CLIENT_CALL_TRACER):
    return

  call_tracer_ptr = <ClientCallTracer*>cpython.PyCapsule_GetPointer(
      client_call_tracer, CLIENT_CALL_TRACER)
  del call_tracer_ptr
151
152
def _c_label_to_labels(vector[Label] c_labels) -> Mapping[str, str]:
  """Convert a vector of Label into a dict of decoded key -> decoded value.

  If several labels share a key, the last one wins.
  """
  return {_decode(c_label.key): _decode(c_label.value) for c_label in c_labels}
158
159
def _c_measurement_to_measurement(object measurement
  ) -> Mapping[str, Union[enum, Mapping[str, Union[float, int]]]]:
  """Convert Cython Measurement to Python measurement.

  Args:
  measurement: Actual measurement represented by Cython type Measurement, using object here
   since Cython refuses to automatically convert a union with unsafe type combinations.

  Returns:
   A mapping object with keys and values as following:
    name -> cMetricsName
    type -> MeasurementType
    value -> {value_double: float | value_int: int}
  """
  metric_name = measurement['name']
  measurement_type = measurement['type']
  # Only the union member that matches the measurement type is copied over.
  if measurement_type == kMeasurementDouble:
    value_key = 'value_double'
  else:
    value_key = 'value_int'
  return {
      'name': metric_name,
      'type': measurement_type,
      'value': {value_key: measurement['value'][value_key]},
  }
184
185
def _c_annotation_to_annotations(vector[Annotation] c_annotations) -> List[Tuple[str, str]]:
  """Convert a vector of Annotation into (time_stamp, description) string pairs."""
  return [(_decode(c_annotation.time_stamp), _decode(c_annotation.description))
          for c_annotation in c_annotations]
192
193
def observability_deinit() -> None:
  """Stop the export thread, then disable census stats and tracing."""
  _shutdown_exporting_thread()
  EnablePythonCensusStats(False)
  EnablePythonCensusTracing(False)
198
199
@functools.lru_cache(maxsize=None)
def _cy_metric_name_to_py_metric_name(cMetricsName metric_name) -> MetricsName:
  """Map a Cython-level metric name to its MetricsName member.

  Results are memoized per metric name.

  Raises:
    ValueError: If `metric_name` has no entry in the mapping.
  """
  try:
    py_metric_name = _CY_METRICS_NAME_TO_PY_METRICS_NAME_MAPPING[metric_name]
  except KeyError:
    raise ValueError('Invalid metric name %s' % metric_name)
  else:
    return py_metric_name
206
207
def _get_stats_data(object measurement, object labels) -> _observability.StatsData:
  """Convert a Python measurement to StatsData.

  Args:
  measurement: A dict of type Mapping[str, Union[enum, Mapping[str, Union[float, int]]]]
    with keys and values as following:
      name -> cMetricsName
      type -> MeasurementType
      value -> {value_double: float | value_int: int}
  labels: Labels associated with stats data with type of dict[str, str].

  Returns:
    A StatsData with measure_double=True and value_float set for double
    measurements, or measure_double=False and value_int set otherwise.

  Raises:
    ValueError: If the measurement name has no matching MetricsName.
  """
  metric_name = _cy_metric_name_to_py_metric_name(measurement['name'])

  if measurement['type'] == kMeasurementDouble:
    return _observability.StatsData(
        name=metric_name,
        measure_double=True,
        value_float=measurement['value']['value_double'],
        labels=labels)

  return _observability.StatsData(
      name=metric_name,
      measure_double=False,
      value_int=measurement['value']['value_int'],
      labels=labels)
232
233
def _get_tracing_data(SpanCensusData span_data, vector[Label] span_labels,
                      vector[Annotation] span_annotations) -> _observability.TracingData:
  """Build a TracingData from a span record, its labels and its annotations.

  Byte-string fields of `span_data` are decoded to str; `should_sample` and
  `child_span_count` are passed through unchanged.
  """
  labels = _c_label_to_labels(span_labels)
  annotations = _c_annotation_to_annotations(span_annotations)
  return _observability.TracingData(
      name=_decode(span_data.name),
      start_time=_decode(span_data.start_time),
      end_time=_decode(span_data.end_time),
      trace_id=_decode(span_data.trace_id),
      span_id=_decode(span_data.span_id),
      parent_span_id=_decode(span_data.parent_span_id),
      status=_decode(span_data.status),
      should_sample=span_data.should_sample,
      child_span_count=span_data.child_span_count,
      span_labels=labels,
      span_annotations=annotations)
249
250
def _record_rpc_latency(object exporter, str method, str target, double rpc_latency, str status_code) -> None:
  """Export a single client API latency measurement through `exporter`.

  Args:
    exporter: The _observability.Exporter that receives the stats data.
    method: Method name; leading and trailing "/" characters are stripped
      before it is used as a label.
    target: Target, used as a label value.
    rpc_latency: Latency value to record. Declared as a C double so that the
      Python float passed by the caller keeps its full precision (a C-style
      `float` argument in Cython is single precision).
    status_code: Status code, used as a label value.
  """
  measurement = {
      'name': kRpcClientApiLatencyMeasureName,
      'type': kMeasurementDouble,
      'value': {'value_double': rpc_latency},
  }
  labels = {
      _decode(kClientMethod): method.strip("/"),
      _decode(kClientTarget): target,
      _decode(kClientStatus): status_code,
  }
  metric = _get_stats_data(measurement, labels)
  exporter.export_stats_data([metric])
265
266
cdef void _export_census_data(object exporter):
  """Main function running in export thread.

  Loops until GLOBAL_SHUTDOWN_EXPORT_THREAD is set. Each iteration waits, with
  the GIL released, until the census data buffer is non-empty or shutdown is
  requested, then flushes the buffer to `exporter`. One more flush is done
  after the loop exits.

  Args:
    exporter: The _observability.Exporter that receives the flushed data.
  """
  exporter: _observability.Exporter

  # Timeout handed to AwaitNextBatchLocked, converted from seconds to
  # milliseconds.
  cdef int export_interval_ms = CENSUS_EXPORT_BATCH_INTERVAL_SECS * 1000
  while True:
    with nogil:
      while not GLOBAL_SHUTDOWN_EXPORT_THREAD:
        lk = new unique_lock[mutex](g_census_data_buffer_mutex)
        # Wait for next batch of census data OR timeout at fixed interval.
        # Batch export census data to minimize the time we spend holding the GIL.
        AwaitNextBatchLocked(dereference(lk), export_interval_ms)

        # Break only when the buffer has data
        if not g_census_data_buffer.empty():
          del lk
          break
        else:
          del lk

    _flush_census_data(exporter)

    if GLOBAL_SHUTDOWN_EXPORT_THREAD:
      break # Break to shut down the exporting thread

  # Flush one last time before shutting down the thread
  _flush_census_data(exporter)
294
295
cdef void _flush_census_data(object exporter):
  """Drain the census data buffer and export its contents through `exporter`.

  While holding g_census_data_buffer_mutex, every buffered entry is converted
  to a StatsData (metric entries) or TracingData (all other entries) and
  removed from the buffer. The lock is then released and the two batches are
  passed to `exporter`. Nothing is exported when the buffer is empty on entry.

  The lock is released in a `finally` block so that a Python exception cannot
  leave the mutex locked. An entry whose conversion raises is logged and
  dropped; otherwise it would stay at the front of the buffer and fail again
  on every later flush.

  Args:
    exporter: The _observability.Exporter that receives the batches.
  """
  exporter: _observability.Exporter

  py_metrics_batch = []
  py_spans_batch = []
  lk = new unique_lock[mutex](g_census_data_buffer_mutex)
  try:
    if g_census_data_buffer.empty():
      return
    while not g_census_data_buffer.empty():
      c_census_data = g_census_data_buffer.front()
      try:
        if c_census_data.type == kMetricData:
          py_labels = _c_label_to_labels(c_census_data.labels)
          py_measurement = _c_measurement_to_measurement(c_census_data.measurement_data)
          py_metric = _get_stats_data(py_measurement, py_labels)
          py_metrics_batch.append(py_metric)
        else:
          py_span = _get_tracing_data(c_census_data.span_data, c_census_data.span_data.span_labels,
                                      c_census_data.span_data.span_annotations)
          py_spans_batch.append(py_span)
      except Exception:
        # Keep draining: one bad entry must not block the rest of the batch.
        _LOGGER.exception('Failed to convert census data for export, dropping it')
      g_census_data_buffer.pop()
  finally:
    # Always release the buffer mutex, even on early return or error.
    del lk

  # Export outside the lock so the exporter cannot stall buffer producers.
  exporter.export_stats_data(py_metrics_batch)
  exporter.export_tracing_data(py_spans_batch)
321
322
cdef void _shutdown_exporting_thread():
  """Request the export thread to stop and wait for it to finish.

  Sets GLOBAL_SHUTDOWN_EXPORT_THREAD and notifies g_census_data_buffer_cv with
  the GIL released, then joins GLOBAL_EXPORT_THREAD for at most
  GRPC_PYTHON_CENSUS_EXPORT_THREAD_TIMEOUT seconds.
  """
  with nogil:
    global GLOBAL_SHUTDOWN_EXPORT_THREAD
    GLOBAL_SHUTDOWN_EXPORT_THREAD = True
    # Presumably wakes the export thread out of AwaitNextBatchLocked so it
    # sees the flag without waiting for the timeout - TODO confirm.
    g_census_data_buffer_cv.notify_all()
  GLOBAL_EXPORT_THREAD.join(timeout=GRPC_PYTHON_CENSUS_EXPORT_THREAD_TIMEOUT)
329
330
cdef str _decode(bytes bytestring):
  """Decode `bytestring` to str.

  UTF-8 is tried first. If that fails, the error is logged and the bytes are
  decoded as latin1 instead. A value that is already a str is returned as is.
  """
  if isinstance(bytestring, (str,)):
    return <str>bytestring

  try:
    return bytestring.decode('utf8')
  except UnicodeDecodeError:
    _LOGGER.exception('Invalid encoding on %s', bytestring)
    return bytestring.decode('latin1')
340