1 #pragma once 2 3 #include <ATen/Config.h> 4 #include <ATen/Functions.h> 5 #include <c10/macros/Macros.h> 6 #include <torch/csrc/Export.h> 7 #include <cstdint> 8 #include <vector> 9 10 #define FOR_ALL_EXTERNAL_FUNCTIONS(_) \ 11 _(nnc_aten_adaptive_avg_pool2d) \ 12 _(nnc_aten_addmm) \ 13 _(nnc_aten_conv2d) \ 14 _(nnc_aten_conv1d) \ 15 _(nnc_aten_conv1d_out) \ 16 _(nnc_aten_dequantize) \ 17 _(nnc_aten_dequantize_out) \ 18 _(nnc_aten_embedding) \ 19 _(nnc_aten_matmul) \ 20 _(nnc_aten_mv) \ 21 _(nnc_aten_mm) \ 22 _(nnc_aten_mean) \ 23 _(nnc_aten_max_red) \ 24 _(nnc_aten_max_red_out) \ 25 _(nnc_aten_quantized_conv1d) \ 26 _(nnc_aten_quantized_conv1d_out) \ 27 _(nnc_aten_quantized_conv2d) \ 28 _(nnc_aten_quantized_conv2d_out) \ 29 _(nnc_aten_quantized_conv2d_relu) \ 30 _(nnc_aten_quantized_conv2d_relu_out) \ 31 _(nnc_aten_quantized_linear) \ 32 _(nnc_aten_quantized_linear_out) \ 33 _(nnc_aten_quantized_linear_relu) \ 34 _(nnc_aten_quantized_add) \ 35 _(nnc_aten_quantized_cat) \ 36 _(nnc_aten_quantized_mul) \ 37 _(nnc_aten_quantized_mul_out) \ 38 _(nnc_aten_quantized_mul_scalar) \ 39 _(nnc_aten_quantized_mul_scalar_out) \ 40 _(nnc_aten_quantized_relu) \ 41 _(nnc_aten_quantized_sigmoid) \ 42 _(nnc_aten_quantized_sigmoid_out) \ 43 _(nnc_aten_quantize_per_tensor) \ 44 _(nnc_aten_quantize_per_tensor_out) \ 45 _(nnc_aten_triangular_solve) \ 46 _(nnc_aten_upsample_nearest2d) \ 47 _(nnc_aten_upsample_nearest2d_out) \ 48 _(nnc_prepacked_conv2d_clamp_run) \ 49 _(nnc_prepacked_linear_clamp_run) 50 51 #define DECLARE_EXTERNAL_FUNCTION(NAME) \ 52 TORCH_API void NAME( \ 53 int64_t bufs_num, \ 54 void** buf_data, \ 55 int64_t* buf_ranks, \ 56 int64_t* buf_dims, \ 57 int64_t* buf_strides, \ 58 int8_t* buf_dtypes, \ 59 int64_t args_num, \ 60 int64_t* extra_args); 61 62 namespace torch::jit::tensorexpr { 63 struct QIData final { 64 double scale; 65 int64_t zero; 66 c10::ScalarType scalarType; 67 }; 68 std::vector<at::Tensor> constructTensors( 69 int64_t bufs_num, 70 void** buf_data, 71 int64_t* buf_ranks, 72 int64_t* buf_dims, 73 int64_t* buf_strides, 74 int8_t* buf_dtypes, 75 std::optional<std::vector<std::pair<size_t, QIData>>> qdataArg = 76 std::nullopt); 77 78 std::vector<at::Tensor> constructTensors2( 79 int64_t bufs_in_num, 80 void** buf_data, 81 int64_t* buf_ranks, 82 int64_t* buf_dims, 83 int64_t* buf_strides, 84 int8_t* buf_dtypes, 85 std::optional<std::vector<std::pair<size_t, QIData>>> qdataArg = 86 std::nullopt, 87 size_t bufs_out_num = 0); 88 89 #ifdef C10_MOBILE 90 extern "C" { 91 #endif 92 void DispatchParallel( 93 int8_t* func, 94 int64_t start, 95 int64_t stop, 96 int8_t* packed_data) noexcept; 97 98 FOR_ALL_EXTERNAL_FUNCTIONS(DECLARE_EXTERNAL_FUNCTION) 99 #if AT_MKLDNN_ENABLED() 100 DECLARE_EXTERNAL_FUNCTION(nnc_mkldnn_prepacked_conv_run); 101 #endif 102 103 TORCH_API void nnc_aten_free(size_t bufs_num, void** ptrs) noexcept; 104 105 #ifdef C10_MOBILE 106 } // extern "C" 107 #endif 108 109 } // namespace torch::jit::tensorexpr 110 111 #undef DECLARE_EXTERNAL_FUNCTION 112