# cython: language_level=2
# distutils: language = c++

# Test case for defining a XLA custom call target in Cython, and registering
# it via the xla_client SWIG API.

from cpython.pycapsule cimport PyCapsule_New


# Custom call target: reads the first float behind data_ptr[0] and
# data_ptr[1] and stores their difference (first minus second) into the
# first float slot of out_ptr. xla_custom_call_status is accepted but not
# used by this function.
cdef void test_subtract_f32(void* out_ptr, void** data_ptr,
                            void* xla_custom_call_status) nogil:
    cdef float a = (<float*>(data_ptr[0]))[0]
    cdef float b = (<float*>(data_ptr[1]))[0]
    cdef float* out = <float*>(out_ptr)
    out[0] = a - b


# Custom call target: reads the first float behind ins[0], adds opaque_len
# (converted to float) and stores the sum into the first float slot of
# out_buffer. Only the length of the opaque payload is used; its contents
# and xla_custom_call_status are ignored.
#
# Declared nogil like test_subtract_f32: the body touches only C data, and
# the qualifier makes Cython reject any accidental Python operation here.
cdef void test_add_input_and_opaque_len(void* out_buffer, void** ins,
                                        const char* opaque,
                                        size_t opaque_len,
                                        void* xla_custom_call_status) nogil:
    cdef float a = (<float*>(ins[0]))[0]
    cdef float b = <float>opaque_len
    cdef float* out = <float*>(out_buffer)
    out[0] = a + b


# Maps a target name (bytes) to a PyCapsule wrapping the C function pointer.
cpu_custom_call_targets = {}


# Wraps the raw function pointer `fn` in a PyCapsule named
# "xla._CUSTOM_CALL_TARGET" and stores it in cpu_custom_call_targets under
# `fn_name`. No destructor is attached to the capsule (NULL).
cdef register_custom_call_target(fn_name, void* fn):
    # Bound to a C string literal so the pointer handed to PyCapsule_New
    # stays valid for the capsule's lifetime (the C API requires the name
    # to outlive the capsule).
    cdef const char* name = "xla._CUSTOM_CALL_TARGET"
    cpu_custom_call_targets[fn_name] = PyCapsule_New(fn, name, NULL)


register_custom_call_target(b"test_subtract_f32", <void*>(test_subtract_f32))
register_custom_call_target(b"test_add_input_and_opaque_len",
                            <void*>(test_add_input_and_opaque_len))