// aten/src/ATen/native/cudnn/LossCTC.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Config.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#if AT_CUDNN_ENABLED()
#include <ATen/cudnn/Descriptors.h>
#endif

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_cudnn_ctc_loss.h>
#include <ATen/ops/_cudnn_ctc_loss_native.h>
#include <ATen/ops/_use_cudnn_ctc_loss.h>
#include <ATen/ops/_use_cudnn_ctc_loss_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#endif

#if (!AT_CUDNN_ENABLED())

namespace at {
namespace native {

// See Note [ATen preprocessor philosophy]

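// Without cuDNN these entry points still exist: the _use_* predicates report
// false so dispatch falls back to the native CTC implementation, and the
// loss kernels themselves error out if reached.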
bool _use_cudnn_ctc_loss(
    const Tensor& log_probs,
    const Tensor& targets,
    IntArrayRef input_lengths,
    IntArrayRef target_lengths,
    int64_t BLANK) {
  return false;
}

bool _use_cudnn_ctc_loss_tensor(
    const Tensor& log_probs,
    const Tensor& targets,
    const Tensor& input_lengths,
    const Tensor& target_lengths,
    int64_t BLANK) {
  return false;
}

std::tuple<Tensor, Tensor> _cudnn_ctc_loss(
    const Tensor& log_probs,
    const Tensor& targets,
    IntArrayRef input_lengths,
    IntArrayRef target_lengths,
    int64_t BLANK,
    bool deterministic,
    bool zero_infinity) {
  AT_ERROR("cudnn_ctc_loss: ATen not compiled with cuDNN >= 7 support");
}

std::tuple<Tensor, Tensor> _cudnn_ctc_loss_tensor(
    const Tensor& log_probs,
    const Tensor& targets,
    const Tensor& input_lengths,
    const Tensor& target_lengths,
    int64_t BLANK,
    bool deterministic,
    bool zero_infinity) {
  AT_ERROR("cudnn_ctc_loss: ATen not compiled with cuDNN >= 8 support");
}

} // namespace native
} // namespace at

#else // AT_CUDNN_ENABLED

#include <ATen/cudnn/Descriptors.h>
#include <ATen/cudnn/Types.h>
#include <ATen/cudnn/Utils.h>

#include <ATen/TensorUtils.h>
#include <c10/util/irange.h>

namespace at {
namespace native {

namespace {
// "cache" whether we've previously failed the target lengths check
static bool tensor_failed_target_lengths_check = false;
} // namespace
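// This flag supports CUDA graph capture: while capturing we cannot copy the
// length tensors to the host to validate them, so any target-length check
// failure observed in eager mode (e.g., during warmup) is recorded here and
// consulted by _use_cudnn_ctc_loss_tensor during capture.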

bool _use_cudnn_ctc_loss(
    const Tensor& log_probs,
    const Tensor& targets,
    IntArrayRef input_lengths,
    IntArrayRef target_lengths,
    int64_t BLANK) {
  auto& ctx = at::globalContext();

  bool use_cudnn = ctx.userEnabledCuDNN() && (BLANK == 0) &&
      (targets.dim() == 1) && (log_probs.scalar_type() == at::kFloat) &&
      (targets.scalar_type() == at::kInt) &&
      (targets.device().type() == at::kCPU) && (targets.is_contiguous()) &&
      (log_probs.device().type() == at::kCUDA) && (log_probs.dim() == 3);

  if (use_cudnn) {
    // we don't know that input_lengths and target_lengths have the same size
    // (they should, but we didn't check yet)
    int64_t max_input_length = log_probs.size(0);
    for (const auto input_length : input_lengths) {
      use_cudnn = use_cudnn && (input_length == max_input_length);
    }
    for (const auto b : c10::irange(target_lengths.size())) {
      // target length < 256 is documented, but we see illegal memory accesses
      // when target lengths > input lengths for CuDNN
      use_cudnn = use_cudnn && (target_lengths[b] < 256) &&
          (target_lengths[b] <= input_lengths[b]);
    }
  }
  return use_cudnn;
}
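// Illustrative only: callers are expected to gate on this predicate before
// dispatching to the cuDNN kernel. The generic ctc_loss wrapper does roughly
// the following (hypothetical sketch, not this file's code):
//
//   if (at::_use_cudnn_ctc_loss(
//           log_probs, targets, input_lengths, target_lengths, BLANK)) {
//     auto [loss, grad] = at::_cudnn_ctc_loss(
//         log_probs, targets, input_lengths, target_lengths, BLANK,
//         at::globalContext().deterministicCuDNN(), zero_infinity);
//   }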

bool _use_cudnn_ctc_loss_tensor(
    const Tensor& log_probs,
    const Tensor& targets,
    const Tensor& input_lengths,
    const Tensor& target_lengths,
    int64_t BLANK) {
  auto& ctx = at::globalContext();

  bool use_cudnn = ctx.userEnabledCuDNN() && (BLANK == 0) &&
      (targets.dim() == 1) && (log_probs.scalar_type() == at::kFloat) &&
      (targets.scalar_type() == at::kInt) &&
      (log_probs.device().type() == at::kCUDA) && (targets.is_contiguous()) &&
      (log_probs.dim() == 3) && (input_lengths.scalar_type() == at::kInt) &&
      (target_lengths.scalar_type() == at::kInt);

  if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
    // Not capturing: copy the lengths to the host once and validate them.
    Tensor ilc = input_lengths.to(Device(at::kCPU), at::kLong).contiguous();
    Tensor tlc = target_lengths.to(Device(at::kCPU), at::kLong).contiguous();
    IntArrayRef il(ilc.data_ptr<int64_t>(), ilc.numel());
    IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
    for (const auto b : c10::irange(tl.size())) {
      // target length < 256 is documented, but we see illegal memory accesses
      // when target lengths > input lengths for CuDNN
      use_cudnn = use_cudnn && (tl[b] < 256) && (tl[b] <= il[b]);
      if (!use_cudnn) {
        tensor_failed_target_lengths_check = true;
        break;
      }
    }
  } else {
    use_cudnn = use_cudnn && !tensor_failed_target_lengths_check;
    if (tensor_failed_target_lengths_check) {
      TORCH_WARN(
          "cuDNN max target length restriction < 256 cannot be checked during graph capture,"
          " but target length >= 256 was observed previously e.g., during warmup, so we"
          " presume it is unsafe to dispatch to cuDNN ctc_loss.");
    }
  }

  return use_cudnn;
}

std::tuple<Tensor, Tensor> _cudnn_ctc_loss(
    const Tensor& log_probs_t,
    const Tensor& targets_t,
    IntArrayRef input_lengths_,
    IntArrayRef target_lengths_,
    int64_t BLANK,
    bool deterministic,
    bool zero_infinity) {
  (void)zero_infinity; // only used for backward
  const CheckedFrom c = "cudnn_ctc_loss";
  const TensorArg log_probs{log_probs_t, "log_probs", 1};
  const TensorArg targets{targets_t, "targets", 2};
  checkDim(c, log_probs, 3);
  checkScalarType(c, log_probs, kFloat);
  checkDim(c, targets, 1);
  checkScalarType(c, targets, kInt);
  checkContiguous(c, targets); // ?
  checkBackend(c, {*log_probs}, Backend::CUDA);
  checkBackend(c, {*targets}, Backend::CPU);
  const auto batch_size = log_probs->size(1);
  TORCH_CHECK(
      static_cast<int64_t>(input_lengths_.size()) == batch_size,
      "input_lengths must have size equal to batch_size");
  TORCH_CHECK(
      static_cast<int64_t>(target_lengths_.size()) == batch_size,
      "target_lengths must have size equal to batch_size");

  std::vector<int> input_lengths(input_lengths_.begin(), input_lengths_.end());
  std::vector<int> target_lengths(
      target_lengths_.begin(), target_lengths_.end());

  TORCH_CHECK(BLANK == 0, "blank must be label 0 for cudnn_ctc_loss");
  // checked in dispatch:
  // assert other conditions for cudnnCTCLoss: all label lengths < 256,
  // all input lengths = logprob.size(0)

  const auto handle = getCudnnHandle();

  const cudnnCTCLossAlgo_t algo =
      (deterministic ? CUDNN_CTC_LOSS_ALGO_DETERMINISTIC
                     : CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC);

  CTCLossDescriptor ctc_loss_desc;

  // CuDNN's gradient semantics changed between 7.1 and 7.6; this code is for
  // CuDNN 7.6 only. See PyTorch 1.2 for older CuDNN support.
  ctc_loss_desc.setEx(
      CUDNN_DATA_FLOAT, CUDNN_LOSS_NORMALIZATION_SOFTMAX, CUDNN_PROPAGATE_NAN);
  TensorDescriptor log_probs_desc{log_probs_t};
  Tensor grad = at::empty_like(log_probs_t, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  TensorDescriptor grad_desc{grad};

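  // Query how much scratch space cuDNN needs for this problem shape, then
  // allocate it as a byte tensor on the same device as log_probs.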
  size_t workspace_size;
  AT_CUDNN_CHECK(cudnnGetCTCLossWorkspaceSize(
      handle,
      log_probs_desc.desc(),
      grad_desc.desc(),
      targets->data_ptr<int>(),
      target_lengths.data(),
      input_lengths.data(),
      algo,
      ctc_loss_desc.desc(),
      &workspace_size));

  Tensor workspace =
      at::empty(workspace_size, log_probs->options().dtype(kByte));
  Tensor costs = at::empty({log_probs->size(1)}, log_probs->options());

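  // A single call computes both the per-sequence losses ("costs") and the
  // gradient with respect to log_probs.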
  AT_CUDNN_CHECK(cudnnCTCLoss(
      handle,
      log_probs_desc.desc(),
      log_probs_t.data_ptr(),
      targets->data_ptr<int>(),
      target_lengths.data(),
      input_lengths.data(),
      costs.data_ptr(),
      grad_desc.desc(),
      grad.data_ptr(),
      algo,
      ctc_loss_desc.desc(),
      workspace.data_ptr(),
      workspace_size));
  return std::make_tuple(costs, grad);
}
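// Note that zero_infinity is deliberately unused above: the forward returns
// the raw cuDNN gradient alongside the costs, and infinite losses are only
// handled when autograd consumes that gradient in the backward.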
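// Tensor-lengths variant: input_lengths and target_lengths stay on device
// and are fed to the cuDNN v8 API, so this path can run under CUDA graph
// capture without a host synchronization.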
std::tuple<Tensor, Tensor> _cudnn_ctc_loss_tensor(
    const Tensor& log_probs_t,
    const Tensor& targets_t,
    const Tensor& input_lengths,
    const Tensor& target_lengths,
    int64_t BLANK,
    bool deterministic,
    bool zero_infinity) {
  Tensor targets_t_ = targets_t;
  if (targets_t.device().type() == at::kCPU) {
    targets_t_ = targets_t.to(Device(at::kCUDA));
  }
  const CheckedFrom c = "cudnn_ctc_loss";
  const TensorArg log_probs{log_probs_t, "log_probs", 1};
  const TensorArg targets{targets_t_, "targets", 2};
  checkDim(c, log_probs, 3);
  checkScalarType(c, log_probs, kFloat);
  checkDim(c, targets, 1);
  checkScalarType(c, targets, kInt);
  checkContiguous(c, targets); // ?
  checkBackend(c, {*log_probs}, Backend::CUDA);
  checkBackend(c, {*targets}, Backend::CUDA);
  const auto batch_size = log_probs->size(1);
  int64_t input_lengths_size =
      input_lengths.sizes().size() ? input_lengths.size(0) : 1;
  int64_t target_lengths_size =
      target_lengths.sizes().size() ? target_lengths.size(0) : 1;
  TORCH_CHECK(
      input_lengths_size == batch_size,
      "input_lengths must have size equal to batch_size");
  TORCH_CHECK(
      target_lengths_size == batch_size,
      "target_lengths must have size equal to batch_size");

  TORCH_CHECK(BLANK == 0, "blank must be label 0 for cudnn_ctc_loss");
  // checked in dispatch:
  // assert other conditions for cudnnCTCLoss: all label lengths < 256,
  // all input lengths = logprob.size(0)

  const auto handle = getCudnnHandle();

  const cudnnCTCLossAlgo_t algo =
      (deterministic ? CUDNN_CTC_LOSS_ALGO_DETERMINISTIC
                     : CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC);

  CTCLossDescriptor ctc_loss_desc;

  ctc_loss_desc.set_v8_v9(
      CUDNN_DATA_FLOAT,
      CUDNN_LOSS_NORMALIZATION_SOFTMAX,
      CUDNN_PROPAGATE_NAN,
      255);
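  // The final argument (255) is the maximum label length, consistent with
  // the dispatch-time requirement that every target length be < 256.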
  TensorDescriptor log_probs_desc{log_probs_t};
  Tensor grad = at::empty_like(log_probs_t, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  TensorDescriptor grad_desc{grad};

  size_t workspace_size;
  AT_CUDNN_CHECK(cudnnGetCTCLossWorkspaceSize_v8(
      handle,
      algo,
      ctc_loss_desc.desc(),
      log_probs_desc.desc(),
      grad_desc.desc(),
      &workspace_size));
  Tensor workspace =
      at::empty(workspace_size, log_probs->options().dtype(kByte));
  Tensor costs = at::empty({log_probs->size(1)}, log_probs->options());

  AT_CUDNN_CHECK(cudnnCTCLoss_v8(
      handle,
      algo,
      ctc_loss_desc.desc(),
      log_probs_desc.desc(),
      log_probs_t.data_ptr(),
      targets_t_.data_ptr<int>(),
      target_lengths.data_ptr<int>(),
      input_lengths.data_ptr<int>(),
      costs.data_ptr(),
      grad_desc.desc(),
      grad.data_ptr(),
      workspace_size,
      workspace.data_ptr()));
  return std::make_tuple(costs, grad);
}

} // namespace native
} // namespace at

#endif