// aten/src/ATen/native/nested/NestedTensorBackward.cpp
#include <ATen/native/nested/NestedTensorMath.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/native/layer_norm.h>
#include <ATen/NestedTensorImpl.h>
#include <c10/core/DispatchKey.h>
#include <ATen/native/nested/NestedTensorUtils.h>
#include <c10/core/DeviceType.h>

#include <utility>

namespace at::native {

// See Note [nested tensor matmul] in NestedTensorMath.cpp
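// Backward of nested matmul. With C = A @ B, the gradients are
//   grad_A = grad_C @ B^T and grad_B = A^T @ grad_C,
// computed here with nested at::matmul on the transposed operands, so each
// constituent pair gets the usual dense matmul backward.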
std::tuple<Tensor, Tensor> matmul_backward_nested(
    const Tensor& grad,
    const Tensor& self,
    const Tensor& other,
    std::array<bool, 2> grad_input_mask) {
  if (!grad.defined()) {
    return std::make_tuple(Tensor(), Tensor());
  }
  Tensor grad_self, grad_other;
  if (grad_input_mask[0]) {
    grad_self = at::matmul(grad, other.transpose(-1, -2));
  }
  if (grad_input_mask[1]) {
    grad_other = at::matmul(self.transpose(-1, -2), grad);
  }
  return std::make_tuple(grad_self, grad_other);
}

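// Backward of at::linear for nested input with a dense weight/bias.
// Because a contiguous nested tensor stores all constituents back to back in
// one flat buffer, the backward reduces to dense mms on that buffer:
// e.g. with constituents of shapes (3, in) and (4, in), the input buffer is
// viewed as a (7, in) matrix and the grad_output buffer as a (7, out) matrix,
// so linear backward becomes two dense mms and a column sum for the bias.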
std::tuple<Tensor, Tensor, Tensor> nested_linear_backward(
    const Tensor& input,
    const Tensor& grad_output,
    const Tensor& weight,
    std::array<bool, 3> output_mask) {
  if (!grad_output.defined()) {
    return std::tuple<Tensor, Tensor, Tensor>{Tensor(), Tensor(), Tensor()};
  }
  Tensor grad_input, grad_weight, grad_bias;
  auto grad_output_contiguous = grad_output.contiguous();
  auto* nt_grad_output = get_nested_tensor_impl(grad_output_contiguous);
  auto* nt_input = get_nested_tensor_impl(input);
  TORCH_INTERNAL_ASSERT(nt_grad_output != nullptr);
  TORCH_INTERNAL_ASSERT(nt_input != nullptr);
  TORCH_INTERNAL_ASSERT(nested_tensor_impl_is_contiguous(nt_grad_output));
  auto grad_output_buffer = nt_grad_output->get_buffer();
  auto input_buffer = nt_input->get_buffer();

  auto reshaped_grad = grad_output_buffer.reshape({-1, weight.size(0)});

  if (output_mask[0]) {
    auto grad_input_buffer = at::mm(reshaped_grad, weight).view({-1});
    auto grad_input_nt_size = nt_input->get_nested_sizes().clone();
    grad_input = wrap_buffer(grad_input_buffer, grad_input_nt_size);
  }
  if (output_mask[1]) {
    grad_weight =
        at::mm(reshaped_grad.t(), input_buffer.reshape({-1, weight.size(1)}));
  }
  if (output_mask[2]) {
    grad_bias = reshaped_grad.sum(0);
  }
  return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
}

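// Backward of nested softmax: apply the dense _softmax_backward_data kernel to
// each constituent. Constituents obtained from unbind() no longer carry the
// leading nested (batch) dimension, so the softmax dim is shifted by one
// (positive_dim - 1) when dispatching per constituent.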
Tensor nested_softmax_backward(
    const Tensor& grad,
    const Tensor& output,
    int64_t dim,
    ScalarType input_dtype) {
  TORCH_INTERNAL_ASSERT(grad.is_nested(), "Should be nested grad")
  TORCH_INTERNAL_ASSERT(output.is_nested(), "Should be nested output")

  auto output_ptr = get_nested_tensor_impl(output);
  auto grad_ptr = get_nested_tensor_impl(grad);
  int64_t ntensors = output_ptr->size(0);
  if (ntensors == 0) {
    return grad.clone();
  }
  int64_t positive_dim = at::maybe_wrap_dim(dim, output_ptr->dim());

  // Get the info about the output
  const Tensor &output_buffer = output_ptr->get_buffer(),
               &output_sizemat = output_ptr->get_nested_sizes();

  // Get the info about the grad
  const Tensor& grad_sizemat = grad_ptr->get_nested_sizes();

  TORCH_INTERNAL_ASSERT(output_sizemat.equal(grad_sizemat));
  Tensor grad_output =
      wrap_buffer(at::empty_like(output_buffer), output_sizemat.clone());

  // Unbind nt into individual tensor slices for calculating the derivative
  std::vector<Tensor> grad_output_unbind{grad_output.unbind()},
      grad_unbind{grad.unbind()}, output_unbind{output.unbind()};

  for (const auto i : c10::irange(ntensors)) {
    at::_softmax_backward_data_out(
        grad_output_unbind[i],
        grad_unbind[i],
        output_unbind[i],
        positive_dim - 1,
        input_dtype);
  }
  return grad_output;
}

// Rudimentary sum backward assuming the conditions in #82387
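// The forward summed each constituent over its last dimension, so the backward
// broadcasts every element of the output gradient back across a segment of
// `segment_length` input elements. Both buffers are walked linearly, which is
// why the gradient nested tensor must be contiguous (see the note below).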
Tensor _nested_sum_backward_cpu(
  const Tensor& grad,
  const Tensor& nested_self,
  OptionalIntArrayRef opt_dims,
  bool keepdim) {
  auto nt_self = get_nested_tensor_impl(nested_self);
  auto nt_grad = get_nested_tensor_impl(grad);
  const Tensor& grad_buffer = nt_grad->get_buffer();
  const Tensor& self_buffer = nt_self->get_buffer();
  auto grad_sizes = nt_grad->get_nested_sizes();
  auto self_sizes = nt_self->get_nested_sizes();
  int64_t ntensors = nt_self->size(0);
  const Tensor& self_grad_buffer = self_buffer.new_empty(self_buffer.sizes());

  auto num_segments = at::prod(grad_sizes, -1);
  auto segment_lengths = self_sizes.select(1, -1);

  // This logic assumes for now that
  // (1) all the gradient nested tensors are contiguous
  // (2) the gradient nested tensors are stored contiguously in the buffer
  AT_DISPATCH_ALL_TYPES_AND2(
    ScalarType::Half, ScalarType::BFloat16, self_grad_buffer.scalar_type(), "nested_sum_dim_cpu", [&]() {
    auto* self_grad_data = self_grad_buffer.data_ptr<scalar_t>();
    const auto* output_grad_data = grad_buffer.const_data_ptr<scalar_t>();
    int64_t out_idx = 0, in_idx = 0;
    for (const auto i : c10::irange(ntensors)) {
      int64_t segments = num_segments[i].item<int64_t>();
      int64_t segment_length = segment_lengths[i].item<int64_t>();
      for (auto j = 0; j < segments; j++) {
        scalar_t output_grad = output_grad_data[out_idx];
        for (auto k = 0; k < segment_length; k++) {
          self_grad_data[in_idx] = output_grad;
          in_idx += 1;
        }
        out_idx += 1;
      }
    }
  });

  return wrap_buffer(self_grad_buffer, self_sizes);
}

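// Backward of select on a nested tensor: build a zero-filled nested tensor
// shaped like self and copy the incoming gradient into the selected slice.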
Tensor _nested_select_backward_symint(
  const Tensor& grad,
  const Tensor& nested_self,
  int64_t dim,
  // NOLINTNEXTLINE(performance-unnecessary-value-param)
  c10::SymInt index) {
  auto nt_self = get_nested_tensor_impl(nested_self);
  const Tensor& self_buffer = nt_self->get_buffer();
  const auto self_sizes = nt_self->get_nested_sizes();
  const Tensor& self_grad_buffer = self_buffer.new_zeros(self_buffer.sizes());

  auto nt_grad = wrap_buffer(self_grad_buffer, self_sizes);
  nt_grad.select_symint(dim, std::move(index)).copy_(grad);

  return nt_grad;
}

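// Element-wise activation backwards on nested tensors are implemented by
// mapping the dense backward op over corresponding constituents of the two
// nested inputs via map_nt_binary. For gelu, at::gelu_backward is applied to
// each (grad, self) pair with the given approximation mode.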
Tensor gelu_backwards_nested(const Tensor& grad, const Tensor& self, c10::string_view approximate){
    auto partial_gelu_backward = [approximate](auto && PH1, auto && PH2) { return at::gelu_backward(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2), approximate); };
    return map_nt_binary(grad, self, partial_gelu_backward);
}

// Naming convention: threshold_backward is the backward used for relu
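// Per constituent, at::threshold_backward passes the gradient through where
// input > threshold and zeroes it elsewhere.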
Tensor threshold_backwards_nested(const Tensor& grad_output, const Tensor& input, const Scalar& threshold){
    auto partial_relu_backward = [threshold](auto && PH1, auto && PH2) { return at::threshold_backward(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2), threshold); };
    return map_nt_binary(grad_output, input, partial_relu_backward);
}

// silu_backward(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)
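// at::silu_backward computes grad * sigmoid(x) * (1 + x * (1 - sigmoid(x)))
// for each constituent pair; the nested variant maps it over the constituents.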
Tensor silu_backward_nested(const Tensor& grad_output, const Tensor& self){
    auto partial_silu_backward = [](auto && PH1, auto && PH2) { return at::silu_backward(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2)); };
    return map_nt_binary(grad_output, self, partial_silu_backward);
}

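// Backward of nested layer_norm. weight and bias are regular (dense) tensors;
// only the input and the gradient are nested. The computation runs on the flat
// contiguous buffers and dispatches to the same LayerNormBackwardKernel used
// for dense tensors, with M rows of length N as computed by
// _check_nested_layer_norm_inputs.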
std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_nested(
    const Tensor& grad,
    const Tensor& input,
    IntArrayRef normalized_shape,
    const Tensor& mean,
    const Tensor& rstd,
    const std::optional<Tensor>& weight_opt /* optional */,
    const std::optional<Tensor>& bias_opt /* optional */,
    std::array<bool, 3> grad_input_mask) {
  // For NestedTensors weight and bias are non-nested.
  auto* nt_impl_grad = get_nested_tensor_impl(grad);
  auto* nt_impl_input = get_nested_tensor_impl(input);
  const auto& weight = *weight_opt;
  const auto& bias = *bias_opt;
  const auto& sizes = nt_impl_input->get_nested_sizes();
  auto M_N = _check_nested_layer_norm_inputs(
      *nt_impl_input, normalized_shape, weight, bias);
  auto M = M_N.first;
  auto N = M_N.second;

  auto gamma = weight.expect_contiguous();
  auto beta = bias.expect_contiguous();

  Tensor dInput;
  Tensor dgamma;
  Tensor dbeta;
  auto input_buffer = nt_impl_input->get_buffer();
  auto grad_buffer = nt_impl_grad->get_buffer();
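  // Buffers that the LayerNormBackwardKernel call below will populate are
  // allocated uninitialized (empty_like); the other branches use zeros_like so
  // the outputs are well defined even when the kernel does not run (M == 0) or
  // a gradient is not requested.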
  // NOLINTNEXTLINE(bugprone-branch-clone)
  if (grad_input_mask[0]) {
    dInput = at::native::empty_like(
        input_buffer,
        std::nullopt /* dtype */,
        std::nullopt /* layout */,
        std::nullopt /* device */,
        std::nullopt /* pin_memory */,
        at::MemoryFormat::Contiguous);
  } else {
    dInput = at::native::zeros_like(
        input_buffer,
        std::nullopt /* dtype */,
        std::nullopt /* layout */,
        std::nullopt /* device */,
        std::nullopt /* pin_memory */,
        at::MemoryFormat::Contiguous);
  }
  if (grad_input_mask[1]) {
    dgamma = M > 0 ? at::native::empty_like(
                         *gamma,
                         std::nullopt /* dtype */,
                         std::nullopt /* layout */,
                         std::nullopt /* device */,
                         std::nullopt /* pin_memory */,
                         at::MemoryFormat::Contiguous)
                   : at::native::zeros_like(
                         *gamma,
                         std::nullopt /* dtype */,
                         std::nullopt /* layout */,
                         std::nullopt /* device */,
                         std::nullopt /* pin_memory */,
                         at::MemoryFormat::Contiguous);
  }
  if (grad_input_mask[2]) {
    dbeta = M > 0 ? at::native::empty_like(
                        *beta,
                        std::nullopt /* dtype */,
                        std::nullopt /* layout */,
                        std::nullopt /* device */,
                        std::nullopt /* pin_memory */,
                        at::MemoryFormat::Contiguous)
                  : at::native::zeros_like(
                        *beta,
                        std::nullopt /* dtype */,
                        std::nullopt /* layout */,
                        std::nullopt /* device */,
                        std::nullopt /* pin_memory */,
                        at::MemoryFormat::Contiguous);
  }
  if (M > 0) {
    LayerNormBackwardKernel(
        input_buffer.is_cuda() ? kCUDA : kCPU,
        grad_buffer,
        input_buffer,
        mean,
        rstd,
        *gamma,
        M,
        N,
        &dInput,
        &dgamma,
        &dbeta);
  }
  return std::make_tuple(
      wrap_buffer(dInput, sizes), std::move(dgamma), std::move(dbeta));
}

} // namespace at::native