#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/FunctionOfAMatrixUtils.h>

#include <ATen/core/Tensor.h>
#include <ATen/TensorIterator.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_compute_linear_combination_native.h>
#include <ATen/ops/zeros.h>
#endif

namespace at::native {

DEFINE_DISPATCH(_compute_linear_combination_stub);

// If `coefficients` is a [m, n] Tensor and
// `input` is a [n, ...] Tensor, then the output
// `output` is a [m, ...] Tensor such that
// for i in range(m):
//    for j in range(n):
//        output[i, ...] += coefficients[i, j] * input[j, ...]
//
// Note: if input.dtype == scalar_t<T>, then coefficients.dtype == T.
// This is relevant when scalar_t<T> == complex<T>.
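//
// As a worked illustration (added here for clarity, not part of the original
// comment), take m = 2, n = 3 and a 1-D `input` of shape [3]:
//   coefficients = [[1., 0., 2.],
//                   [0., 1., 1.]]   (shape [2, 3])
//   input        = [10., 20., 30.]  (shape [3])
// then
//   output[0] = 1*10 + 0*20 + 2*30 = 70
//   output[1] = 0*10 + 1*20 + 1*30 = 50
// so `output` has shape [2], i.e. [m].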
Tensor _compute_linear_combination(const Tensor& input, const Tensor& coefficients) {
  TORCH_CHECK(input.ndimension() > 0 && input.numel() > 0, "Empty tensor not supported");
  auto output_first_dim_size = coefficients.size(0);

  auto output_sizes = input.sizes().vec();
  output_sizes[0] = output_first_dim_size;
  auto output = at::zeros(
    output_sizes,
    input.options().memory_format(at::MemoryFormat::Contiguous)
  );

  native::_compute_linear_combination_out(input, coefficients, output);

  return output;
}
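
// A minimal usage sketch (illustrative only, added here; not part of the
// original file), assuming the op is reachable through the generated at::
// API registered in native_functions.yaml:
//
//   at::Tensor input  = at::randn({3, 4, 4});  // n = 3, trailing dims 4x4
//   at::Tensor coeffs = at::randn({2, 3});     // m = 2, n = 3
//   at::Tensor out    = at::_compute_linear_combination(input, coeffs);
//   // out has shape [2, 4, 4]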

// Note: the function is implemented using the __restrict__ memory modifier,
// which means that if `output` is actually aliased by `input`, the result
// produced is undefined.
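// For example (illustrative sketch, not from the original file), a call that
// passes the same tensor as both input and output, such as
//   at::native::_compute_linear_combination_out(t, coeffs, t);
// is undefined behaviour, since the kernels read `input` and write `output`
// through pointers that are assumed not to alias.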
Tensor& _compute_linear_combination_out(const Tensor& input, const Tensor& coefficients, Tensor& output) {
  auto output_first_dim_size = coefficients.size(0);
  auto input_first_dim_size = coefficients.size(1);

  // Recall that `coefficients` is a [m, n] Tensor,
  // `input` is a [n, ...] Tensor, `output` is a [m, ...] Tensor.
  // We restride the Tensors to a common dim == input.dim() + 1, so that
  // coefficients.sizes() = [m, 1 (instead of n), 1 repeated (input.dim() - 1) times],
  // input.sizes() = [1, 1 (instead of n), ...],
  // output.sizes() = [m, 1 (instead of n), ...].
  // The second dimension of the newly restrided Tensors is traversed inside the kernels.
  // This is done to avoid synchronizations/atomic operations in the kernels
  // and also guarantees determinism, which is required by autograd.
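  //
  // Concrete sketch (added for illustration, not part of the original comment):
  // with m = 2, n = 3 and `input` of shape [3, 5] (so input.dim() == 2):
  //   coefficients_restrided: sizes [2, 1, 1], strides [coefficients.stride(0), 0, 0]
  //   input_restrided:        sizes [1, 1, 5], strides [<ignored>, 0, input.stride(1)]
  //                           (the first stride is irrelevant since that size is 1)
  //   output_restrided:       sizes [2, 1, 5], strides [output.stride(0), 0, output.stride(1)]
  // The hidden dimension of size n = 3 is then stepped over inside the kernel
  // using input.stride(0) and coefficients.stride(1), as set up below.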

  // restride output
  auto output_to_broadcasted_dim = output.unsqueeze(1);
  auto output_restrided_sizes = output_to_broadcasted_dim.sizes().vec();
  auto output_restrided_strides = output_to_broadcasted_dim.strides().vec();
  output_restrided_sizes[1] = 1;
  output_restrided_strides[1] = 0;
  auto output_restrided = output.as_strided(
    output_restrided_sizes,
    output_restrided_strides
  );

  // restride input
  auto input_to_broadcasted_dim = input.unsqueeze(0);
  auto input_restrided_sizes = input_to_broadcasted_dim.sizes().vec();
  auto input_restrided_strides = input_to_broadcasted_dim.strides().vec();
  input_restrided_sizes[1] = 1;
  input_restrided_strides[1] = 0;
  auto input_restrided = input.as_strided(
    input_restrided_sizes,
    input_restrided_strides
  );

  // restride coefficients
  auto coefficients_restrided_sizes = std::vector<int64_t>(input.dim() + 1, 1);
  coefficients_restrided_sizes[0] = output_first_dim_size;
  coefficients_restrided_sizes[1] = 1;
  auto coefficients_restrided_strides = std::vector<int64_t>(input.dim() + 1, 0);
  coefficients_restrided_strides[0] = coefficients.stride(0);
  coefficients_restrided_strides[1] = 0;
  auto coefficients_restrided = coefficients.as_strided(
    coefficients_restrided_sizes,
    coefficients_restrided_strides
  );

  auto iter = TensorIteratorConfig()
    .set_check_mem_overlap(false)  // Output is intentionally 0 strided above
    .check_all_same_dtype(false)
    .resize_outputs(false)
    .add_output(output_restrided)
    .add_input(input_restrided)
    .add_input(coefficients_restrided)
    .build();

  // The dimension of size n is traversed inside the kernels;
  // it is the first dimension of `input` and the second dimension of `coefficients`.
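  //
  // Conceptually, for each output element the dispatched kernel performs
  // something like the following (illustrative pseudocode, added here; the
  // pointer names are hypothetical and not from the actual kernels):
  //
  //   // accumulate into the (zero-initialized) output element
  //   for (int64_t j = 0; j < input_first_dim_size; ++j) {
  //     *out_ptr += coeff_ptr[j * coeff_stride] * input_ptr[j * input_stride];
  //   }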
  auto input_stride = input.stride(0);
  auto coeff_stride = coefficients.stride(1);
  _compute_linear_combination_stub(
    iter.device_type(),
    iter,
    input_stride,
    coeff_stride,
    input_first_dim_size
  );
  return output;
}

} // namespace at::native