#include <ATen/native/nested/NestedTensorMath.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/native/layer_norm.h>
#include <ATen/NestedTensorImpl.h>
#include <c10/core/DispatchKey.h>
#include <ATen/native/nested/NestedTensorUtils.h>
#include <c10/core/DeviceType.h>

#include <utility>

namespace at::native {

// See Note [nested tensor matmul] in NestedTensorMath.cpp
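// For C = matmul(A, B) the standard gradients are
//   dA = matmul(dC, B^T) and dB = matmul(A^T, dC),
// computed here directly on the nested operands.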
std::tuple<Tensor, Tensor> matmul_backward_nested(
    const Tensor& grad,
    const Tensor& self,
    const Tensor& other,
    std::array<bool, 2> grad_input_mask) {
  if (!grad.defined()) {
    return std::make_tuple(Tensor(), Tensor());
  }
  Tensor grad_self, grad_other;
  if (grad_input_mask[0]) {
    grad_self = at::matmul(grad, other.transpose(-1, -2));
  }
  if (grad_input_mask[1]) {
    grad_other = at::matmul(self.transpose(-1, -2), grad);
  }
  return std::make_tuple(grad_self, grad_other);
}

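// Backward of at::linear (y = x W^T + b) for nested input. The computation is
// carried out on the packed buffers: grad_output is viewed as a 2-D matrix of
// shape [total_rows, out_features] and input as [total_rows, in_features], so
//   grad_input  = grad_output @ W
//   grad_weight = grad_output^T @ input
//   grad_bias   = grad_output.sum(0)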
std::tuple<Tensor, Tensor, Tensor> nested_linear_backward(
    const Tensor& input,
    const Tensor& grad_output,
    const Tensor& weight,
    std::array<bool, 3> output_mask) {
  if (!grad_output.defined()) {
    return std::tuple<Tensor, Tensor, Tensor>{Tensor(), Tensor(), Tensor()};
  }
  Tensor grad_input, grad_weight, grad_bias;
  auto grad_output_contiguous = grad_output.contiguous();
  auto* nt_grad_output = get_nested_tensor_impl(grad_output_contiguous);
  auto* nt_input = get_nested_tensor_impl(input);
  TORCH_INTERNAL_ASSERT(nt_grad_output != nullptr);
  TORCH_INTERNAL_ASSERT(nt_input != nullptr);
  TORCH_INTERNAL_ASSERT(nested_tensor_impl_is_contiguous(nt_grad_output));
  auto grad_output_buffer = nt_grad_output->get_buffer();
  auto input_buffer = nt_input->get_buffer();

  auto reshaped_grad = grad_output_buffer.reshape({-1, weight.size(0)});

  if (output_mask[0]) {
    auto grad_input_buffer = at::mm(reshaped_grad, weight).view({-1});
    auto grad_input_nt_size = nt_input->get_nested_sizes().clone();
    grad_input = wrap_buffer(grad_input_buffer, grad_input_nt_size);
  }
  if (output_mask[1]) {
    grad_weight =
        at::mm(reshaped_grad.t(), input_buffer.reshape({-1, weight.size(1)}));
  }
  if (output_mask[2]) {
    grad_bias = reshaped_grad.sum(0);
  }
  return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
}

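// Softmax backward for nested tensors. Each constituent is handled separately
// via at::_softmax_backward_data_out, which computes the usual identity
//   grad_input = (grad - (grad * output).sum(dim, keepdim=true)) * output.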
Tensor nested_softmax_backward(
    const Tensor& grad,
    const Tensor& output,
    int64_t dim,
    ScalarType input_dtype) {
  TORCH_INTERNAL_ASSERT(grad.is_nested(), "Should be nested grad");
  TORCH_INTERNAL_ASSERT(output.is_nested(), "Should be nested output");

  auto output_ptr = get_nested_tensor_impl(output);
  auto grad_ptr = get_nested_tensor_impl(grad);
  int64_t ntensors = output_ptr->size(0);
  if (ntensors == 0) {
    return grad.clone();
  }
  int64_t positive_dim = at::maybe_wrap_dim(dim, output_ptr->dim());

  // Get the info about the output
  const Tensor &output_buffer = output_ptr->get_buffer(),
               &output_sizemat = output_ptr->get_nested_sizes();

  // Get the info about the grad
  const Tensor& grad_sizemat = grad_ptr->get_nested_sizes();

  TORCH_INTERNAL_ASSERT(output_sizemat.equal(grad_sizemat));
  Tensor grad_output =
      wrap_buffer(at::empty_like(output_buffer), output_sizemat.clone());

  // Unbind nt into individual tensor slices for calculating the derivative
  std::vector<Tensor> grad_output_unbind{grad_output.unbind()},
      grad_unbind{grad.unbind()}, output_unbind{output.unbind()};

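  // unbind() strips the leading nested (batch) dimension, so the softmax dim
  // must be shifted down by one when applied to each constituent below.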
  for (const auto i : c10::irange(ntensors)) {
    at::_softmax_backward_data_out(
        grad_output_unbind[i],
        grad_unbind[i],
        output_unbind[i],
        positive_dim - 1,
        input_dtype);
  }
  return grad_output;
}

// Rudimentary sum backward assuming the conditions in #82387
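// The reduction here is a sum over the last (dense) dimension of each
// constituent, so the backward broadcasts each reduced gradient value back
// across that dimension.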
Tensor _nested_sum_backward_cpu(
    const Tensor& grad,
    const Tensor& nested_self,
    OptionalIntArrayRef opt_dims,
    bool keepdim) {
  auto nt_self = get_nested_tensor_impl(nested_self);
  auto nt_grad = get_nested_tensor_impl(grad);
  const Tensor& grad_buffer = nt_grad->get_buffer();
  const Tensor& self_buffer = nt_self->get_buffer();
  auto grad_sizes = nt_grad->get_nested_sizes();
  auto self_sizes = nt_self->get_nested_sizes();
  int64_t ntensors = nt_self->size(0);
  const Tensor& self_grad_buffer = self_buffer.new_empty(self_buffer.sizes());

  auto num_segments = at::prod(grad_sizes, -1);
  auto segment_lengths = self_sizes.select(1, -1);
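  // For constituent i, num_segments[i] is the number of elements in its
  // reduced gradient and segment_lengths[i] is the size of the summed-over
  // last dimension, i.e. how many input elements share each gradient value.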

  // This logic assumes for now that
  // (1) all the gradient nested tensors are contiguous
  // (2) the gradient nested tensors are stored contiguously in the buffer
  AT_DISPATCH_ALL_TYPES_AND2(
      ScalarType::Half,
      ScalarType::BFloat16,
      self_grad_buffer.scalar_type(),
      "nested_sum_dim_cpu",
      [&]() {
        auto* self_grad_data = self_grad_buffer.data_ptr<scalar_t>();
        const auto* output_grad_data = grad_buffer.const_data_ptr<scalar_t>();
        int64_t out_idx = 0, in_idx = 0;
        for (const auto i : c10::irange(ntensors)) {
          int64_t segments = num_segments[i].item<int64_t>();
          int64_t segment_length = segment_lengths[i].item<int64_t>();
          for (auto j = 0; j < segments; j++) {
            scalar_t output_grad = output_grad_data[out_idx];
            for (auto k = 0; k < segment_length; k++) {
              self_grad_data[in_idx] = output_grad;
              in_idx += 1;
            }
            out_idx += 1;
          }
        }
      });

  return wrap_buffer(self_grad_buffer, self_sizes);
}

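// Backward of select on a nested tensor: allocate a zero gradient with the
// same nested structure as self and copy grad into the selected slice.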
Tensor _nested_select_backward_symint(
    const Tensor& grad,
    const Tensor& nested_self,
    int64_t dim,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    c10::SymInt index) {
  auto nt_self = get_nested_tensor_impl(nested_self);
  const Tensor& self_buffer = nt_self->get_buffer();
  const auto self_sizes = nt_self->get_nested_sizes();
  const Tensor& self_grad_buffer = self_buffer.new_zeros(self_buffer.sizes());

  auto nt_grad = wrap_buffer(self_grad_buffer, self_sizes);
  nt_grad.select_symint(dim, std::move(index)).copy_(grad);

  return nt_grad;
}

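// The activation backwards below (gelu, threshold/relu, silu) are elementwise,
// so they are applied through map_nt_binary, which maps the binary function
// over the two nested tensors' underlying data and rewraps the result with the
// original nested sizes.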
Tensor gelu_backwards_nested(const Tensor& grad, const Tensor& self, c10::string_view approximate) {
  auto partial_gelu_backward = [approximate](auto&& PH1, auto&& PH2) {
    return at::gelu_backward(
        std::forward<decltype(PH1)>(PH1),
        std::forward<decltype(PH2)>(PH2),
        approximate);
  };
  return map_nt_binary(grad, self, partial_gelu_backward);
}

// threshold_backward is the backward op used by relu, hence the relu-style naming.
Tensor threshold_backwards_nested(const Tensor& grad_output, const Tensor& input, const Scalar& threshold) {
  auto partial_relu_backward = [threshold](auto&& PH1, auto&& PH2) {
    return at::threshold_backward(
        std::forward<decltype(PH1)>(PH1),
        std::forward<decltype(PH2)>(PH2),
        threshold);
  };
  return map_nt_binary(grad_output, input, partial_relu_backward);
}

// silu_backward(Tensor grad_output, Tensor self) -> Tensor
Tensor silu_backward_nested(const Tensor& grad_output, const Tensor& self) {
  auto partial_silu_backward = [](auto&& PH1, auto&& PH2) {
    return at::silu_backward(
        std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2));
  };
  return map_nt_binary(grad_output, self, partial_silu_backward);
}

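// Layer norm backward for nested input. weight and bias (and their gradients)
// are regular, non-nested tensors; the kernel operates on the packed buffers
// viewed as M rows of N normalized elements, and dInput is rewrapped into a
// nested tensor with the input's nested sizes at the end.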
std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_nested(
    const Tensor& grad,
    const Tensor& input,
    IntArrayRef normalized_shape,
    const Tensor& mean,
    const Tensor& rstd,
    const std::optional<Tensor>& weight_opt /* optional */,
    const std::optional<Tensor>& bias_opt /* optional */,
    std::array<bool, 3> grad_input_mask) {
  // For NestedTensors, weight and bias are non-nested.
  auto* nt_impl_grad = get_nested_tensor_impl(grad);
  auto* nt_impl_input = get_nested_tensor_impl(input);
  const auto& weight = *weight_opt;
  const auto& bias = *bias_opt;
  const auto& sizes = nt_impl_input->get_nested_sizes();
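  // M is the total number of normalized rows across all constituents; N is the
  // number of elements per row, i.e. the product of normalized_shape.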
  auto M_N = _check_nested_layer_norm_inputs(
      *nt_impl_input, normalized_shape, weight, bias);
  auto M = M_N.first;
  auto N = M_N.second;

  auto gamma = weight.expect_contiguous();
  auto beta = bias.expect_contiguous();

  Tensor dInput;
  Tensor dgamma;
  Tensor dbeta;
  auto input_buffer = nt_impl_input->get_buffer();
  auto grad_buffer = nt_impl_grad->get_buffer();
  // NOLINTNEXTLINE(bugprone-branch-clone)
  if (grad_input_mask[0]) {
    dInput = at::native::empty_like(
        input_buffer,
        std::nullopt /* dtype */,
        std::nullopt /* layout */,
        std::nullopt /* device */,
        std::nullopt /* pin_memory */,
        at::MemoryFormat::Contiguous);
  } else {
    dInput = at::native::zeros_like(
        input_buffer,
        std::nullopt /* dtype */,
        std::nullopt /* layout */,
        std::nullopt /* device */,
        std::nullopt /* pin_memory */,
        at::MemoryFormat::Contiguous);
  }
  if (grad_input_mask[1]) {
    dgamma = M > 0 ? at::native::empty_like(
                         *gamma,
                         std::nullopt /* dtype */,
                         std::nullopt /* layout */,
                         std::nullopt /* device */,
                         std::nullopt /* pin_memory */,
                         at::MemoryFormat::Contiguous)
                   : at::native::zeros_like(
                         *gamma,
                         std::nullopt /* dtype */,
                         std::nullopt /* layout */,
                         std::nullopt /* device */,
                         std::nullopt /* pin_memory */,
                         at::MemoryFormat::Contiguous);
  }
  if (grad_input_mask[2]) {
    dbeta = M > 0 ? at::native::empty_like(
                        *beta,
                        std::nullopt /* dtype */,
                        std::nullopt /* layout */,
                        std::nullopt /* device */,
                        std::nullopt /* pin_memory */,
                        at::MemoryFormat::Contiguous)
                  : at::native::zeros_like(
                        *beta,
                        std::nullopt /* dtype */,
                        std::nullopt /* layout */,
                        std::nullopt /* device */,
                        std::nullopt /* pin_memory */,
                        at::MemoryFormat::Contiguous);
  }
  if (M > 0) {
    LayerNormBackwardKernel(
        input_buffer.is_cuda() ? kCUDA : kCPU,
        grad_buffer,
        input_buffer,
        mean,
        rstd,
        *gamma,
        M,
        N,
        &dInput,
        &dgamma,
        &dbeta);
  }
  return std::make_tuple(
      wrap_buffer(dInput, sizes), std::move(dgamma), std::move(dbeta));
}

} // namespace at::native