#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Config.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#if AT_CUDNN_ENABLED()
#include <ATen/cudnn/Descriptors.h>
#endif

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_cudnn_ctc_loss.h>
#include <ATen/ops/_cudnn_ctc_loss_native.h>
#include <ATen/ops/_use_cudnn_ctc_loss.h>
#include <ATen/ops/_use_cudnn_ctc_loss_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#endif

#if (!AT_CUDNN_ENABLED())

namespace at {
namespace native {

// See Note [ATen preprocessor philosophy]

bool _use_cudnn_ctc_loss(
    const Tensor& log_probs,
    const Tensor& targets,
    IntArrayRef input_lengths,
    IntArrayRef target_lengths,
    int64_t BLANK) {
  return false;
}

bool _use_cudnn_ctc_loss_tensor(
    const Tensor& log_probs,
    const Tensor& targets,
    const Tensor& input_lengths,
    const Tensor& target_lengths,
    int64_t BLANK) {
  return false;
}

std::tuple<Tensor, Tensor> _cudnn_ctc_loss(
    const Tensor& log_probs,
    const Tensor& targets,
    IntArrayRef input_lengths,
    IntArrayRef target_lengths,
    int64_t BLANK,
    bool deterministic,
    bool zero_infinity) {
  AT_ERROR("cudnn_ctc_loss: ATen not compiled with cuDNN >= 7 support");
}

std::tuple<Tensor, Tensor> _cudnn_ctc_loss_tensor(
    const Tensor& log_probs,
    const Tensor& targets,
    const Tensor& input_lengths,
    const Tensor& target_lengths,
    int64_t BLANK,
    bool deterministic,
    bool zero_infinity) {
  AT_ERROR("cudnn_ctc_loss: ATen not compiled with cuDNN >= 8 support");
}

} // namespace native
} // namespace at

#else // AT_CUDNN_ENABLED

#include <ATen/cudnn/Descriptors.h>
#include <ATen/cudnn/Types.h>
#include <ATen/cudnn/Utils.h>

#include <ATen/TensorUtils.h>
#include <c10/util/irange.h>

namespace at {
namespace native {

namespace {
// "Cache" whether we've previously failed the target-lengths check, so the
// tensor-variant dispatch can still give an answer during CUDA graph capture,
// when the check itself cannot be rerun.
static bool tensor_failed_target_lengths_check = false;
} // namespace

bool _use_cudnn_ctc_loss(
    const Tensor& log_probs,
    const Tensor& targets,
    IntArrayRef input_lengths,
    IntArrayRef target_lengths,
    int64_t BLANK) {
  auto& ctx = at::globalContext();

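  // Eligibility check consulted by the CTC loss dispatcher: cuDNN handles the
  // case of blank == 0, a flat int32 target tensor on the CPU, and 3-D float
  // log_probs on CUDA; anything else is rejected here so the caller falls back
  // to the native implementation.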
  bool use_cudnn = ctx.userEnabledCuDNN() && (BLANK == 0) &&
      (targets.dim() == 1) && (log_probs.scalar_type() == at::kFloat) &&
      (targets.scalar_type() == at::kInt) &&
      (targets.device().type() == at::kCPU) && (targets.is_contiguous()) &&
      (log_probs.device().type() == at::kCUDA) && (log_probs.dim() == 3);

  if (use_cudnn) {
    // we don't know that input_lengths and target_lengths have the same size
    // (they should, but we didn't check yet)
    int64_t max_input_length = log_probs.size(0);
    for (const auto input_length : input_lengths) {
      use_cudnn = use_cudnn && (input_length == max_input_length);
    }
    for (const auto b : c10::irange(target_lengths.size())) {
      // target length < 256 is documented, but we see illegal memory accesses
      // when target lengths > input lengths for CuDNN
      use_cudnn = use_cudnn && (target_lengths[b] < 256) &&
          (target_lengths[b] <= input_lengths[b]);
    }
  }
  return use_cudnn;
}

bool _use_cudnn_ctc_loss_tensor(
    const Tensor& log_probs,
    const Tensor& targets,
    const Tensor& input_lengths,
    const Tensor& target_lengths,
    int64_t BLANK) {
  auto& ctx = at::globalContext();

  bool use_cudnn = ctx.userEnabledCuDNN() && (BLANK == 0) &&
      (targets.dim() == 1) && (log_probs.scalar_type() == at::kFloat) &&
      (targets.scalar_type() == at::kInt) &&
      (log_probs.device().type() == at::kCUDA) && (targets.is_contiguous()) &&
      (log_probs.dim() == 3) && (input_lengths.scalar_type() == at::kInt) &&
      (target_lengths.scalar_type() == at::kInt);

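  // The per-sample length checks below require the length tensors on the CPU,
  // i.e. a device-to-host copy, which is not allowed while a CUDA graph is
  // being captured. Under capture we therefore rely on the flag cached from
  // earlier eager runs (e.g. warmup) instead of rerunning the check.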
  if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
    Tensor ilc = input_lengths.to(Device(at::kCPU), at::kLong).contiguous();
    Tensor tlc = target_lengths.to(Device(at::kCPU), at::kLong).contiguous();
    IntArrayRef il(ilc.data_ptr<int64_t>(), ilc.numel());
    IntArrayRef tl(tlc.data_ptr<int64_t>(), tlc.numel());
    for (const auto b : c10::irange(tl.size())) {
      // target length < 256 is documented, but we see illegal memory accesses
      // when target lengths > input lengths for CuDNN
      use_cudnn = use_cudnn && (tl[b] < 256) && (tl[b] <= il[b]);
      if (!use_cudnn) {
        tensor_failed_target_lengths_check = true;
        break;
      }
    }
  } else {
    use_cudnn = use_cudnn && !tensor_failed_target_lengths_check;
    if (tensor_failed_target_lengths_check) {
      TORCH_WARN(
          "cuDNN max target length restriction < 256 cannot be checked during graph capture,"
          " but target length >= 256 was observed previously e.g., during warmup, so we"
          " presume it is unsafe to dispatch to cuDNN ctc_loss.");
    }
  }

  return use_cudnn;
}

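// Fixed-length variant: computes the per-sample CTC losses and the gradient
// w.r.t. log_probs in a single cuDNN call (cuDNN >= 7.6 gradient semantics,
// see the note below).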
std::tuple<Tensor, Tensor> _cudnn_ctc_loss(
    const Tensor& log_probs_t,
    const Tensor& targets_t,
    IntArrayRef input_lengths_,
    IntArrayRef target_lengths_,
    int64_t BLANK,
    bool deterministic,
    bool zero_infinity) {
  (void)zero_infinity; // only used for backward
  const CheckedFrom c = "cudnn_ctc_loss";
  const TensorArg log_probs{log_probs_t, "log_probs", 1};
  const TensorArg targets{targets_t, "targets", 2};
  checkDim(c, log_probs, 3);
  checkScalarType(c, log_probs, kFloat);
  checkDim(c, targets, 1);
  checkScalarType(c, targets, kInt);
  checkContiguous(c, targets); // ?
  checkBackend(c, {*log_probs}, Backend::CUDA);
  checkBackend(c, {*targets}, Backend::CPU);
  const auto batch_size = log_probs->size(1);
  TORCH_CHECK(
      static_cast<int64_t>(input_lengths_.size()) == batch_size,
      "input_lengths must have size equal to batch_size");
  TORCH_CHECK(
      static_cast<int64_t>(target_lengths_.size()) == batch_size,
      "target_lengths must have size equal to batch_size");

  std::vector<int> input_lengths(input_lengths_.begin(), input_lengths_.end());
  std::vector<int> target_lengths(
      target_lengths_.begin(), target_lengths_.end());

  TORCH_CHECK(BLANK == 0, "blank must be label 0 for cudnn_ctc_loss");
  // checked in dispatch:
  // assert other conditions for cudnnCTCLoss: all label lengths < 256,
  // all input lengths == logprob.size(0)

  const auto handle = getCudnnHandle();

  const cudnnCTCLossAlgo_t algo =
      (deterministic ? CUDNN_CTC_LOSS_ALGO_DETERMINISTIC
                     : CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC);

  CTCLossDescriptor ctc_loss_desc;

  // The cuDNN gradient semantics changed between 7.1 and 7.6; this code
  // targets cuDNN 7.6 only. See PyTorch 1.2 for the older cuDNN handling.
  ctc_loss_desc.setEx(
      CUDNN_DATA_FLOAT, CUDNN_LOSS_NORMALIZATION_SOFTMAX, CUDNN_PROPAGATE_NAN);
  TensorDescriptor log_probs_desc{log_probs_t};
  Tensor grad = at::empty_like(log_probs_t, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  TensorDescriptor grad_desc{grad};

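  // Query how much scratch space cudnnCTCLoss needs for this problem before
  // launching the kernel; in this path the labels and lengths live in host
  // memory (targets was checked to be a CPU tensor above).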
  size_t workspace_size;
  AT_CUDNN_CHECK(cudnnGetCTCLossWorkspaceSize(
      handle,
      log_probs_desc.desc(),
      grad_desc.desc(),
      targets->data_ptr<int>(),
      target_lengths.data(),
      input_lengths.data(),
      algo,
      ctc_loss_desc.desc(),
      &workspace_size));

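  // The workspace is a raw byte buffer on the same device as log_probs; costs
  // receives one loss value per batch element (log_probs is laid out T x N x C,
  // so size(1) is the batch dimension).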
  Tensor workspace =
      at::empty(workspace_size, log_probs->options().dtype(kByte));
  Tensor costs = at::empty({log_probs->size(1)}, log_probs->options());

  AT_CUDNN_CHECK(cudnnCTCLoss(
      handle,
      log_probs_desc.desc(),
      log_probs_t.data_ptr(),
      targets->data_ptr<int>(),
      target_lengths.data(),
      input_lengths.data(),
      costs.data_ptr(),
      grad_desc.desc(),
      grad.data_ptr(),
      algo,
      ctc_loss_desc.desc(),
      workspace.data_ptr(),
      workspace_size));
  return std::make_tuple(costs, grad);
}

std::tuple<Tensor, Tensor> _cudnn_ctc_loss_tensor(
    const Tensor& log_probs_t,
    const Tensor& targets_t,
    const Tensor& input_lengths,
    const Tensor& target_lengths,
    int64_t BLANK,
    bool deterministic,
    bool zero_infinity) {
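  // Tensor-lengths variant used with the cuDNN v8 CTC API (and thus usable
  // under CUDA graph capture). The rest of this path expects all inputs on the
  // GPU (see the checkBackend calls below), so CPU targets are copied over.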
  Tensor targets_t_ = targets_t;
  if (targets_t.device().type() == at::kCPU) {
    targets_t_ = targets_t.to(Device(at::kCUDA));
  }
  const CheckedFrom c = "cudnn_ctc_loss";
  const TensorArg log_probs{log_probs_t, "log_probs", 1};
  const TensorArg targets{targets_t_, "targets", 2};
  checkDim(c, log_probs, 3);
  checkScalarType(c, log_probs, kFloat);
  checkDim(c, targets, 1);
  checkScalarType(c, targets, kInt);
  checkContiguous(c, targets); // ?
  checkBackend(c, {*log_probs}, Backend::CUDA);
  checkBackend(c, {*targets}, Backend::CUDA);
  const auto batch_size = log_probs->size(1);
  int64_t input_lengths_size =
      input_lengths.sizes().size() ? input_lengths.size(0) : 1;
  int64_t target_lengths_size =
      target_lengths.sizes().size() ? target_lengths.size(0) : 1;
  TORCH_CHECK(
      input_lengths_size == batch_size,
      "input_lengths must have size equal to batch_size");
  TORCH_CHECK(
      target_lengths_size == batch_size,
      "target_lengths must have size equal to batch_size");

  TORCH_CHECK(BLANK == 0, "blank must be label 0 for cudnn_ctc_loss");
  // checked in dispatch:
  // assert other conditions for cudnnCTCLoss: all label lengths < 256,
  // all input lengths == logprob.size(0)

  const auto handle = getCudnnHandle();

  const cudnnCTCLossAlgo_t algo =
      (deterministic ? CUDNN_CTC_LOSS_ALGO_DETERMINISTIC
                     : CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC);

  CTCLossDescriptor ctc_loss_desc;

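  // Configure the v8/v9 CTC descriptor; the final argument is the maximum
  // label length hint (255, matching the < 256 limit enforced in
  // _use_cudnn_ctc_loss_tensor).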
  ctc_loss_desc.set_v8_v9(
      CUDNN_DATA_FLOAT,
      CUDNN_LOSS_NORMALIZATION_SOFTMAX,
      CUDNN_PROPAGATE_NAN,
      255);
  TensorDescriptor log_probs_desc{log_probs_t};
  Tensor grad = at::empty_like(log_probs_t, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  TensorDescriptor grad_desc{grad};

  size_t workspace_size;
  AT_CUDNN_CHECK(cudnnGetCTCLossWorkspaceSize_v8(
      handle,
      algo,
      ctc_loss_desc.desc(),
      log_probs_desc.desc(),
      grad_desc.desc(),
      &workspace_size));
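  // Unlike the v7 query above, the v8 workspace query works purely from the
  // descriptors; the labels and lengths are passed to cudnnCTCLoss_v8 itself
  // as tensor data pointers rather than host arrays.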
  Tensor workspace =
      at::empty(workspace_size, log_probs->options().dtype(kByte));
  Tensor costs = at::empty({log_probs->size(1)}, log_probs->options());

  AT_CUDNN_CHECK(cudnnCTCLoss_v8(
      handle,
      algo,
      ctc_loss_desc.desc(),
      log_probs_desc.desc(),
      log_probs_t.data_ptr(),
      targets_t_.data_ptr<int>(),
      target_lengths.data_ptr<int>(),
      input_lengths.data_ptr<int>(),
      costs.data_ptr(),
      grad_desc.desc(),
      grad.data_ptr(),
      workspace_size,
      workspace.data_ptr()));
  return std::make_tuple(costs, grad);
}

} // namespace native
} // namespace at

#endif