xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/python/custom_call_for_test.pyx (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# cython: language_level=2
2# distutils: language = c++
3
# Test case for defining an XLA custom call target in Cython and registering
# it via the xla_client Python API.
6
7from cpython.pycapsule cimport PyCapsule_New
8
cdef void test_subtract_f32(void* out_ptr, void** data_ptr,
                            void* xla_custom_call_status) nogil:
  # XLA CPU custom-call target: stores data_ptr[0][0] - data_ptr[1][0]
  # (both read as float32) into the first float of out_ptr.
  #
  # xla_custom_call_status is accepted only to match the expected callback
  # signature; it is never touched here. No Python objects are used, so the
  # function is declared nogil.
  cdef float* lhs = <float*>(data_ptr[0])
  cdef float* rhs = <float*>(data_ptr[1])
  (<float*>out_ptr)[0] = lhs[0] - rhs[0]
15
cdef void test_add_input_and_opaque_len(void* out_buffer, void** ins,
                                        const char* opaque, size_t opaque_len,
                                        void* xla_custom_call_status) nogil:
  # XLA CPU custom-call target for the signature that also receives the
  # opaque (backend-config) bytes: writes ins[0][0] + opaque_len into the
  # first float of out_buffer, so a test can observe that the opaque
  # payload's length reached the callback.
  #
  # `opaque` itself and xla_custom_call_status (presumably an XLA status
  # handle -- not inspected here) are unused.
  #
  # Declared nogil like test_subtract_f32: the body uses no Python objects,
  # and the pointer is handed to native code as a plain C callback, so Cython
  # should enforce that no GIL-requiring operation sneaks in.
  cdef float a = (<float*>(ins[0]))[0]
  # size_t -> float32 conversion: exact only for lengths up to 2**24.
  cdef float b = <float>opaque_len
  cdef float* out = <float*>(out_buffer)
  out[0] = a + b
23
24
# Registry of CPU custom-call targets: maps the target name (bytes) to a
# PyCapsule wrapping the C function pointer. Populated by
# register_custom_call_target.
cpu_custom_call_targets = dict()
26
cdef register_custom_call_target(fn_name, void* fn):
  # Wrap the raw function pointer `fn` in a PyCapsule named
  # "xla._CUSTOM_CALL_TARGET" and record it in cpu_custom_call_targets
  # under the key `fn_name`.
  #
  # PyCapsule_New requires a non-NULL name to outlive the capsule; a C string
  # literal has static storage, so that holds. No destructor is supplied
  # (NULL), so releasing the capsule does not free anything behind `fn`.
  cdef const char* capsule_name = "xla._CUSTOM_CALL_TARGET"
  capsule = PyCapsule_New(fn, capsule_name, NULL)
  cpu_custom_call_targets[fn_name] = capsule
30
# Publish both test targets at import time, keyed by their byte-string
# target names (presumably the names the XLA test refers to -- confirm
# against the caller).
register_custom_call_target(
    b"test_subtract_f32", <void*>(test_subtract_f32))
register_custom_call_target(
    b"test_add_input_and_opaque_len", <void*>(test_add_input_and_opaque_len))
34