1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/FunctionOfAMatrixUtils.h>
3
4 #include <ATen/core/Tensor.h>
5 #include <ATen/TensorIterator.h>
6
7 #ifndef AT_PER_OPERATOR_HEADERS
8 #include <ATen/Functions.h>
9 #include <ATen/NativeFunctions.h>
10 #else
11 #include <ATen/ops/_compute_linear_combination_native.h>
12 #include <ATen/ops/zeros.h>
13 #endif
14
15 namespace at::native {
16
17 DEFINE_DISPATCH(_compute_linear_combination_stub);
18
19 // If `coefficients` is a [m, n] Tensor and
20 // `input` is a [n, ...] Tensor, then the output
21 // `output` is going to be a [m, ...] Tensor such that
22 // for i in range(m):
23 // for j in range(n):
24 // output[i, ...] += coefficients[i, j] * input[j, ...]
25 //
26 // Note: if input.dtype == scalar_t<T>, then coefficients.dtype == T.
27 // This is relevant when scalar_t<T> == complex<T>.
_compute_linear_combination(const Tensor & input,const Tensor & coefficients)28 Tensor _compute_linear_combination(const Tensor& input, const Tensor& coefficients) {
29 TORCH_CHECK(input.ndimension() > 0 && input.numel() > 0, "Empty tensor not supported");
30 auto output_first_dim_size = coefficients.size(0);
31
32 auto output_sizes = input.sizes().vec();
33 output_sizes[0] = output_first_dim_size;
34 auto output = at::zeros(
35 output_sizes,
36 input.options().memory_format(at::MemoryFormat::Contiguous)
37 );
38
39 native::_compute_linear_combination_out(input, coefficients, output);
40
41 return output;
42 }
43
44 // Note: the function is implemented using the __restrict__ memory modifier,
45 // which means that if `output` actually is aliased by `input`, the result
46 // produced is undefined.
_compute_linear_combination_out(const Tensor & input,const Tensor & coefficients,Tensor & output)47 Tensor& _compute_linear_combination_out(const Tensor& input, const Tensor& coefficients, Tensor& output) {
48 auto output_first_dim_size = coefficients.size(0);
49 auto input_first_dim_size = coefficients.size(1);
50
51 // Recall that `coefficients` is a [m, n] Tensor,
52 // `input` is a [n, ...] Tensor, `output` is a [m, ...] Tensor.
53 // We restride Tensors to the common dim == input.dim() + 1, so that
54 // coefficients.sizes() = [m, 1 (instead of n), 1 repeated (input.dim() - 1) times],
55 // input.sizes() = [1, 1 (instead of n), ...],
56 // output.sizes() = [m, 1 (instead of n), ...].
57 // The second dimension in newly restrided Tensors is traversed inside the kernels.
58 // This is done to avoid synchronizations/atomic operations in the kernels
59 // and also guarantees determinism, required by the autograd.
60
61 // restride output
62 auto output_to_broadcasted_dim = output.unsqueeze(1);
63 auto output_restrided_sizes = output_to_broadcasted_dim.sizes().vec();
64 auto output_restrided_strides = output_to_broadcasted_dim.strides().vec();
65 output_restrided_sizes[1] = 1;
66 output_restrided_strides[1] = 0;
67 auto output_restrided = output.as_strided(
68 output_restrided_sizes,
69 output_restrided_strides
70 );
71
72 // restride input
73 auto input_to_broadcasted_dim = input.unsqueeze(0);
74 auto input_restrided_sizes = input_to_broadcasted_dim.sizes().vec();
75 auto input_restrided_strides = input_to_broadcasted_dim.strides().vec();
76 input_restrided_sizes[1] = 1;
77 input_restrided_strides[1] = 0;
78 auto input_restrided = input.as_strided(
79 input_restrided_sizes,
80 input_restrided_strides
81 );
82
83 // restride coefficients
84 auto coefficients_restrided_sizes = std::vector<int64_t>(input.dim() + 1, 1);
85 coefficients_restrided_sizes[0] = output_first_dim_size;
86 coefficients_restrided_sizes[1] = 1;
87 auto coefficients_restrided_strides = std::vector<int64_t>(input.dim() + 1, 0);
88 coefficients_restrided_strides[0] = coefficients.stride(0);
89 coefficients_restrided_strides[1] = 0;
90 auto coefficients_restrided = coefficients.as_strided(
91 coefficients_restrided_sizes,
92 coefficients_restrided_strides
93 );
94
95 auto iter = TensorIteratorConfig()
96 .set_check_mem_overlap(false) // Output is intentionally 0 strided above
97 .check_all_same_dtype(false)
98 .resize_outputs(false)
99 .add_output(output_restrided)
100 .add_input(input_restrided)
101 .add_input(coefficients_restrided)
102 .build();
103
104 // The dimension of size n is traversed inside the kernels,
105 // it is the first dimension of `input` and the second of `coefficients`
106 auto input_stride = input.stride(0);
107 auto coeff_stride = coefficients.stride(1);
108 _compute_linear_combination_stub(
109 iter.device_type(),
110 iter,
111 input_stride,
112 coeff_stride,
113 input_first_dim_size
114 );
115 return output;
116 }
117
118 } // namespace at::native
119