xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/external_functions.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <ATen/Config.h>
#include <ATen/Functions.h>
#include <c10/macros/Macros.h>
#include <torch/csrc/Export.h>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>
9 
// X-macro enumerating every NNC external-call entry point declared by this
// header. Expand it as FOR_ALL_EXTERNAL_FUNCTIONS(M): M(name) is applied to
// each entry in the order listed below, with nothing emitted between the
// expansions, so M must supply its own terminator (';', ',', ...) if needed.
#define FOR_ALL_EXTERNAL_FUNCTIONS(_)                                        \
  _(nnc_aten_adaptive_avg_pool2d)                                            \
  _(nnc_aten_addmm)                                                          \
  _(nnc_aten_conv2d)                                                         \
  _(nnc_aten_conv1d)                                                         \
  _(nnc_aten_conv1d_out)                                                     \
  _(nnc_aten_dequantize)                                                     \
  _(nnc_aten_dequantize_out)                                                 \
  _(nnc_aten_embedding)                                                      \
  _(nnc_aten_matmul)                                                         \
  _(nnc_aten_mv)                                                             \
  _(nnc_aten_mm)                                                             \
  _(nnc_aten_mean)                                                           \
  _(nnc_aten_max_red)                                                        \
  _(nnc_aten_max_red_out)                                                    \
  _(nnc_aten_quantized_conv1d)                                               \
  _(nnc_aten_quantized_conv1d_out)                                           \
  _(nnc_aten_quantized_conv2d)                                               \
  _(nnc_aten_quantized_conv2d_out)                                           \
  _(nnc_aten_quantized_conv2d_relu)                                          \
  _(nnc_aten_quantized_conv2d_relu_out)                                      \
  _(nnc_aten_quantized_linear)                                               \
  _(nnc_aten_quantized_linear_out)                                           \
  _(nnc_aten_quantized_linear_relu)                                          \
  _(nnc_aten_quantized_add)                                                  \
  _(nnc_aten_quantized_cat)                                                  \
  _(nnc_aten_quantized_mul)                                                  \
  _(nnc_aten_quantized_mul_out)                                              \
  _(nnc_aten_quantized_mul_scalar)                                           \
  _(nnc_aten_quantized_mul_scalar_out)                                       \
  _(nnc_aten_quantized_relu)                                                 \
  _(nnc_aten_quantized_sigmoid)                                              \
  _(nnc_aten_quantized_sigmoid_out)                                          \
  _(nnc_aten_quantize_per_tensor)                                            \
  _(nnc_aten_quantize_per_tensor_out)                                        \
  _(nnc_aten_triangular_solve)                                               \
  _(nnc_aten_upsample_nearest2d)                                             \
  _(nnc_aten_upsample_nearest2d_out)                                         \
  _(nnc_prepacked_conv2d_clamp_run)                                          \
  _(nnc_prepacked_linear_clamp_run)
50 
// Declares one external-call entry point called NAME, using the calling
// convention shared by every entry of FOR_ALL_EXTERNAL_FUNCTIONS: a buffer
// count, then arrays that (presumably) describe each buffer's data pointer,
// rank, sizes, strides and dtype, then a count plus an array of extra int64_t
// arguments.
// NOTE(review): the exact layout of buf_dims / buf_strides is not visible in
// this header — confirm against the definitions in external_functions.cpp.
// The expansion already ends in ';', so an invocation needs no trailing ';'.
#define DECLARE_EXTERNAL_FUNCTION(NAME)                                      \
  TORCH_API void NAME(                                                       \
      int64_t bufs_num, void** buf_data, int64_t* buf_ranks,                 \
      int64_t* buf_dims, int64_t* buf_strides, int8_t* buf_dtypes,           \
      int64_t args_num, int64_t* extra_args);
61 
62 namespace torch::jit::tensorexpr {
63 struct QIData final {
64   double scale;
65   int64_t zero;
66   c10::ScalarType scalarType;
67 };
68 std::vector<at::Tensor> constructTensors(
69     int64_t bufs_num,
70     void** buf_data,
71     int64_t* buf_ranks,
72     int64_t* buf_dims,
73     int64_t* buf_strides,
74     int8_t* buf_dtypes,
75     std::optional<std::vector<std::pair<size_t, QIData>>> qdataArg =
76         std::nullopt);
77 
78 std::vector<at::Tensor> constructTensors2(
79     int64_t bufs_in_num,
80     void** buf_data,
81     int64_t* buf_ranks,
82     int64_t* buf_dims,
83     int64_t* buf_strides,
84     int8_t* buf_dtypes,
85     std::optional<std::vector<std::pair<size_t, QIData>>> qdataArg =
86         std::nullopt,
87     size_t bufs_out_num = 0);
88 
89 #ifdef C10_MOBILE
90 extern "C" {
91 #endif
92 void DispatchParallel(
93     int8_t* func,
94     int64_t start,
95     int64_t stop,
96     int8_t* packed_data) noexcept;
97 
98 FOR_ALL_EXTERNAL_FUNCTIONS(DECLARE_EXTERNAL_FUNCTION)
99 #if AT_MKLDNN_ENABLED()
100 DECLARE_EXTERNAL_FUNCTION(nnc_mkldnn_prepacked_conv_run);
101 #endif
102 
103 TORCH_API void nnc_aten_free(size_t bufs_num, void** ptrs) noexcept;
104 
105 #ifdef C10_MOBILE
106 } // extern "C"
107 #endif
108 
109 } // namespace torch::jit::tensorexpr
110 
111 #undef DECLARE_EXTERNAL_FUNCTION
112