#include <torch/csrc/jit/runtime/symbolic_script.h>

#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/runtime/operator.h>

namespace torch::jit {
namespace {
std::mutex lock;
const std::vector<std::string> functions = {
    R"(
11 #### HELPER FUNCTIONS ###
12 #### PREFIX: AD_ ###
13 #### SCHEMA NOT SAVED IN CACHE ###
14
15 def AD_unsqueeze_multiple(t,
16 dims: List[int],
17 n_dims: int):
18 seen = [False] * n_dims
19 for i in range(len(dims)):
20 seen[dims[i]] = True
21
22 for d in range(n_dims):
23 if seen[d]:
24 t = t.unsqueeze(d)
25 return t
26
27 def AD_sum_backward(grad,
28 sizes: List[int],
29 dims: Optional[List[int]],
30 keepdim: bool):
31 if not keepdim and len(sizes) > 0:
32 if dims is None:
33 return grad.expand(sizes)
34 elif len(dims) == 1:
35 return grad.unsqueeze(dims[0]).expand(sizes)
36 else:
37 res = AD_unsqueeze_multiple(grad, dims, len(sizes))
38 return res.expand(sizes)
39 else:
40 return grad.expand(sizes)
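        # Worked example (added for clarity; values are illustrative, not from the
        # original): with sizes = [2, 3, 4], dims = [0, 2], keepdim = False, grad has
        # shape [3]; AD_unsqueeze_multiple restores the reduced dims to give [1, 3, 1],
        # and expand(sizes) then broadcasts the gradient back to [2, 3, 4].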
41
42 def AD_logsumexp_backward(grad, self, result,
43 dim: List[int],
44 keepdim: bool):
45 if not keepdim and self.dim() != 0:
46 n_dims = len(self.size())
47 grad = AD_unsqueeze_multiple(grad, dim, n_dims)
48 result = AD_unsqueeze_multiple(result, dim, n_dims)
49 return grad * (self - result).exp()
50
51 def mean_0(self, *, dtype: Optional[int]):
52 self_size = self.size()
53 self_numel = self.numel()
54 self_scalar_type = self.dtype
55 def backward(grad_output):
56 return grad_output.expand(self_size).to(self_scalar_type) / self_numel, None
57
58 return torch.mean(self, dtype=dtype), backward
59
60 def mean_1(self,
61 dim: Optional[List[int]],
62 keepdim: bool,
63 *,
64 dtype: Optional[int]):
65 self_size = self.size()
66 self_scalar_type = self.dtype
67 def backward(grad_output):
68 grad_self = AD_sum_backward(grad_output, self_size, dim, keepdim).to(self_scalar_type) / AD_safe_size(self_size, dim)
69 return grad_self, None, None, None
70
71 return torch.mean(self, dim, keepdim, dtype=dtype), backward
72
73 def logsumexp(self,
74 dim: List[int],
75 keepdim: bool):
76 result = torch.logsumexp(self, dim, keepdim)
77 self_dim = self.dim()
78 def backward(grad_output):
79 grad_self = AD_logsumexp_backward(grad_output, self, result, dim, keepdim)
80 return grad_self, None, None
81
82 return result, backward
83
84 def AD_bool_to_int(b: bool):
85 # FIXME: torchscript: int - bool
86 if b:
87 i = 1
88 else:
89 i = 0
90 return i
91
92 def AD_var_backward_0(grad, self, correction: number):
93 # FIXME: torchscript: div(float, float)
94 return grad * (self - self.mean()) * 2.0 / (self.numel() - correction)
95
96 def AD_safe_size(sizes: List[int],
97 dims: Optional[List[int]]):
98 if len(sizes) == 0:
99 return 1
100
101 size = 1
102
103 if dims is None:
104 for s in sizes:
105 size *= s
106
107 else:
108 for i in range(len(dims)):
109 d = dims[i]
110 size *= sizes[d]
111
112 return size
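        # Example (illustrative): AD_safe_size([2, 3, 4], [0, 2]) == 2 * 4 == 8, the
        # number of elements reduced over, which mean_1 uses to rescale the summed
        # gradient into a mean gradient.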
113
114 def AD_var_backward_1(grad,
115 self,
116 dim: List[int],
117 correction: number,
118 keepdim: bool):
119 if self.dim() == 0:
120 return AD_var_backward_0(grad, self, correction)
121 self_size = self.size()
122 if not keepdim and self.dim() > 1:
123 grad = AD_unsqueeze_multiple(grad, dim, len(self_size))
124
125 # FIXME: torchscript: div(float, float)
126 return grad * (self - self.mean(dim, True)) * 2.0 / (AD_safe_size(self_size, dim) - correction)
127
128 def AD_var_backward_2(grad,
129 self,
130 dim: Optional[List[int]],
131 correction: Optional[number],
132 keepdim: bool):
133 if correction is None:
134 correction = 1
135 if self.dim() == 0 or dim is None:
136 return AD_var_backward_0(grad, self, correction)
137
138 return AD_var_backward_1(grad, self, dim, correction, keepdim)
139
140 def std_0(self,
141 unbiased: bool=True):
142 std_out = torch.std(self, unbiased)
143 def backward(grad_output):
144 correction = AD_bool_to_int(unbiased)
145 grad_self = AD_var_backward_0(grad_output / (std_out * 2), self, correction)
146 return grad_self, None
147
148 return std_out, backward
149
150 def std_1(self,
151 dim: Optional[List[int]],
152 unbiased: bool,
153 keepdim: bool):
154 std_out = torch.std(self, dim, unbiased, keepdim)
155 def backward(grad_output):
156 correction = AD_bool_to_int(unbiased)
157 grad_self = AD_var_backward_2(grad_output / (std_out * 2), self, dim, correction, keepdim)
158 return grad_self, None, None, None
159
160 return std_out, backward
161
162 def std_2(self,
163 dim: Optional[List[int]],
164 *,
165 correction: Optional[number],
166 keepdim: bool):
167 std_out = torch.std(self, dim, correction=correction, keepdim=keepdim)
168 def backward(grad_output):
169 grad_self = AD_var_backward_2(grad_output / (std_out * 2), self, dim, correction, keepdim)
170 return grad_self, None, None, None
171
172 return std_out, backward
173
174 def var_0(self,
175 unbiased: bool=True):
176 def backward(grad_output):
177 correction = AD_bool_to_int(unbiased)
178 grad_self = AD_var_backward_0(grad_output, self, correction)
179 return grad_self, None
180
181 return torch.var(self, unbiased), backward
182
183 def var_1(self,
184 dim: Optional[List[int]],
185 unbiased: bool,
186 keepdim: bool):
187 def backward(grad_output):
188 correction = AD_bool_to_int(unbiased)
189 grad_self = AD_var_backward_2(grad_output, self, dim, correction, keepdim)
190 return grad_self, None, None, None
191
192 return torch.var(self, dim, unbiased, keepdim), backward
193
194 def var_2(self,
195 dim: Optional[List[int]],
196 *,
197 correction: Optional[number],
198 keepdim: bool):
199 def backward(grad_output):
200 grad_self = AD_var_backward_2(grad_output, self, dim, correction, keepdim)
201 return grad_self, None, None, None
202
203 return torch.var(self, dim, correction=correction, keepdim=keepdim), backward
204
205 def tanh(self):
206 output = torch.tanh(self)
207 def backward(grad_output):
208 return grad_output * (1 - output * output)
209
210 return output, backward
211
212 def AD_index_select_backward(grad,
213 dim: int,
214 indices,
215 sizes: List[int],
216 keepdim: bool):
217 if not keepdim and len(sizes) > 0:
218 grad = grad.unsqueeze(dim)
219 indices = indices.unsqueeze(dim)
220
221 # FIXME: torchscript: torch.zeros(sizes, grad.options())
222 return torch.zeros(sizes).to(grad).scatter_(dim, indices, grad)
223
224 # def topk(self,
225 # k: int,
226 # dim: int = -1,
227 # largest: bool = True,
228 # sorted: bool = True):
229 # result0, result1 = torch.topk(self, k, dim, largest, sorted)
230 # self_size = self.size()
231 # def backward(grad_output):
232 # grad_self = AD_index_select_backward(grad_output, dim, result1, self_size, True)
233 # return grad_self, None, None, None, None
234
235 # return result0, result1, backward
236
237 # def kthvalue(self,
238 # k: int,
239 # dim: int,
240 # keepdim: bool):
241 # result0, result1 = torch.kthvalue(self, k, dim, keepdim)
242 # self_size = self.size()
243 # def backward(grad_output):
244 # grad_self = AD_index_select_backward(grad_output, dim, result1, self_size, keepdim)
245 # return grad_self, None, None, None
246
247 # return result0, result1, backward
248
249 def AD_mm_backward_self(grad, mat2):
250 return grad.mm(mat2.t())
251
252 def AD_mm_backward_mat2(grad, self):
253 return self.t().mm(grad)
254
255 def mm(self, mat2):
256 def backward(grad_output):
257 grad_self = AD_mm_backward_self(grad_output, mat2)
258 grad_mat2 = AD_mm_backward_mat2(grad_output, self)
259 return grad_self, grad_mat2
260
261 return torch.mm(self, mat2), backward
262
263 def AD_permute_backward(grad,
264 fwd_dims: List[int]):
265 ndims = len(fwd_dims)
266 dims = [0] * ndims
267
268 for i in range(ndims):
269 dims[fwd_dims[i]] = i
270
271 return grad.permute(dims)
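        # Example (illustrative): for fwd_dims = [2, 0, 1] the loop above builds the
        # inverse permutation [1, 2, 0], so permuting the gradient with it undoes the
        # forward permute.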
272
273 def permute(self,
274 dims: List[int]):
275 def backward(grad_output):
276 grad_self = AD_permute_backward(grad_output, dims)
277 return grad_self, None
278
279 return torch.permute(self, dims), backward
280
281 def AD_select_backward(grad,
282 input_sizes: List[int],
283 dim: int,
284 index: int):
285 # FIXME: torchscript: torch.zeros(sizes, grad.options())
286 grad_input = torch.zeros(input_sizes).to(grad)
287 grad_input.select(dim, index).copy_(grad)
288 return grad_input
289
290 # TODO: fix torch.zeros(sizes, grad.options()) before enabling select, topk, kthvalue
291 # def select(self,
292 # dim: int,
293 # index: int):
294 # self_size = self.size()
295 # def backward(grad_output):
296 # grad_self = AD_select_backward(grad_output, self_size, dim, index)
297 # return grad_self, None, None
298
299 # return torch.select(self, dim, index), backward
300
301 def AD_slice_backward(grad,
302 input_sizes: List[int],
303 dim: int,
304 start: int,
305 end: int,
306 step: int):
307 # FIXME: torchscript: torch.zeros(sizes, grad.options())
308 grad_input = torch.zeros(input_sizes).to(grad)
309 grad_input.slice(dim, start, end, step).copy_(grad)
310 return grad_input
311
312 # DON'T enable slice unless we can correctly handle view ops in graph executor.
313 # It triggers failure of TestJit.test_sample in test_distributions.py.
314 # def slice(self,
315 # dim: int=0,
316 # start: int=0,
317 # end: int=9223372036854775807,
318 # step: int=1):
319 # def backward(grad_output):
320 # grad_self = AD_slice_backward(grad_output, self.size(), dim, start, end, step)
321 # return grad_self, None, None, None, None
322
323 # return torch.slice(self, dim, start, end, step), backward
324
325 def AD_unsqueeze_to_0(self,
326 sizes: List[int]):
327 ndims = len(sizes)
328 for i in range(ndims):
329 if sizes[i] == 1:
330 self = self.unsqueeze(i)
331
332 return self
333
334 def AD_unsqueeze_to_1(self,
335 dim: int,
336 sizes: List[int]):
337 if len(sizes) > 0 and sizes[dim] == 1:
338 return self.unsqueeze(dim)
339 return self
340
341 def squeeze_0(self):
342 self_size = self.size()
343 def backward(grad_output):
344 grad_self = AD_unsqueeze_to_0(grad_output, self_size)
345 return grad_self
346
347 return torch.squeeze(self), backward
348
349 def squeeze_1(self,
350 dim: int):
351 self_size = self.size()
352 def backward(grad_output):
353 grad_self = AD_unsqueeze_to_1(grad_output, dim, self_size)
354 return grad_self, None
355
356 return torch.squeeze(self, dim), backward
357
358 def AD_infer_size(a: List[int],
359 b: List[int]):
360 dimsA = len(a)
361 dimsB = len(b)
362
363 ndim = dimsA if dimsA > dimsB else dimsB
364 expand_sizes = [0] * ndim
365
            for i in range(ndim):
                idx = - i - 1
                sizeA = a[idx] if (dimsA + idx) >= 0 else 1
                sizeB = b[idx] if (dimsB + idx) >= 0 else 1

                # Assert sizeA == sizeB or sizeA == 1 or sizeB == 1
                expand_sizes[ndim + idx] = sizeB if sizeA == 1 else sizeA
373
374 return expand_sizes
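        # Example (illustrative): AD_infer_size([3, 1], [2, 1, 5]) aligns both shapes
        # from the right and returns the broadcast shape [2, 3, 5].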
375
376 def AD_bmm_backward_self(grad, mat2):
377 return grad.bmm(mat2.transpose(1, 2))
378
379 def AD_bmm_backward_mat2(grad, self):
380 return self.transpose(1, 2).bmm(grad)
381
382 def bmm(self, mat2):
383 def backward(grad_output):
384 grad_self = AD_bmm_backward_self(grad_output, mat2)
385 grad_mat2 = AD_bmm_backward_mat2(grad_output, self)
386 return grad_self, grad_mat2
387 return torch.bmm(self, mat2), backward
388 )",
389 R"(
390 def AD_mat_transpose(mat):
391 dim = mat.dim()
392 if dim == 1:
393 out = mat
394 elif dim == 2:
395 out = mat.t()
396 else:
397 dims = rangelist(dim)
398 dims[-1] = dim - 2
399 dims[-2] = dim - 1
400 out = mat.permute(dims)
401 return out
402
403 # In matmul backward case of [b, m, n] * [b, n, p] => [m, p],
404 # instead of doing [b, m, p] and then reduce to [m, p]
405 # which potentially uses large intermediate of size b*m*p,
406 # we do [m, bn] * [bn, p] to avoid having the large
407 # intermediate, thus reduces max memory usage.
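        # Shape trace for the fold below (added illustration; shapes are examples):
        # mat1 = [b, m, n] is transposed to [b, n, m] and folded to [bn, m],
        # mat2 = [b, n, p] is folded to [bn, p], and mat1_fold.t().mm(mat2_fold)
        # computes [m, bn] x [bn, p] = [m, p].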
408 def AD_matmul_bw_special_fold(mat1, mat2):
409 mat1_transpose = AD_mat_transpose(mat1)
410 mat1_fold = mat1_transpose.reshape(-1, mat1_transpose.size()[-1])
411 mat2_fold = mat2.reshape(-1, mat2.size()[-1])
412 return mat1_fold.t().mm(mat2_fold)
413
414 def AD_matmul_bw_size(mat1, mat2,
415 out_size: List[int]):
416 dim1 = mat1.dim()
417 dim2 = mat2.dim()
418 dim_out = len(out_size)
419 if dim1 == 0 or dim2 == 0:
420 out = mat1 * mat2
421 elif dim_out == 2 and dim1 == dim2 and dim1 >=3:
422 out = AD_matmul_bw_special_fold(mat1, mat2)
423 elif dim_out == 1 and dim1 - dim2 == 1 and dim1 >= 3:
424 mat2_unsqueeze = mat2.unsqueeze(-1)
425 out = AD_matmul_bw_special_fold(mat1, mat2_unsqueeze)
426 out = out.squeeze(-1)
427 elif dim1 + dim2 == dim_out:
428 if dim2 == 1:
429 target_dim2 = 0
430 else:
431 target_dim2 = -2
432 out = torch.matmul(mat1.unsqueeze(dim1), mat2.unsqueeze(target_dim2))
433 elif dim_out == dim1 - dim2:
434 out = torch.matmul(mat1, mat2.unsqueeze(dim2)).squeeze(-1)
435 elif dim_out == dim2 - dim1:
436 out = torch.matmul(mat1.unsqueeze(-2), mat2).squeeze(-2)
437 else:
438 out = torch.matmul(mat1, mat2)
439 return out
440
441 def matmul(self, other):
442 def backward(grad_output):
443 self_size = self.size()
444 other_size = other.size()
445 grad_self = AD_matmul_bw_size(grad_output, AD_mat_transpose(other), self_size)._grad_sum_to_size(self_size)
446 grad_other = AD_matmul_bw_size(AD_mat_transpose(self), grad_output, other_size)._grad_sum_to_size(other_size)
447 return grad_self, grad_other
448
449 return torch.matmul(self, other), backward
450
451 def linear(input : Tensor,
452 weight : Tensor,
453 bias : Optional[Tensor]):
454 result = torch.linear(input, weight, bias)
455
456 def backward(grad_output):
457 if bias is not None:
458 grad_bias = grad_output._grad_sum_to_size(bias.size())
459 else:
460 grad_bias = None
461
462 weight_size = weight.size()
463 grad_input = torch.matmul(grad_output, weight)
464 grad_weight = torch.matmul(grad_output.reshape(-1, weight_size[0]).t(), input.reshape(-1, weight_size[1]))
                # Note: calling unchecked_unwrap_optional is only safe when we
                #       return grad_bias directly back to bias, because in the
                #       case where `bias is None` the unwrapped grad_bias would
                #       just be pruned away.
469 return grad_input, grad_weight, grad_bias.unchecked_unwrap_optional
470 return result, backward
471 )",
472 R"(
473 def addcmul(self,
474 tensor1,
475 tensor2,
476 *,
477 value: number):
478 result = torch.addcmul(self, tensor1, tensor2, value=value)
479 self_size = torch._size_if_not_equal(self.size(), result.size())
480 tensor1_size = torch._size_if_not_equal(tensor1.size(), result.size())
481 tensor2_size = torch._size_if_not_equal(tensor2.size(), result.size())
482 def backward(grad_output):
483 grad = grad_output * value
484 grad_tensor1 = (grad * tensor2)._grad_sum_to_size(tensor1_size)
485 grad_tensor2 = (grad * tensor1)._grad_sum_to_size(tensor2_size)
486 return grad_output._grad_sum_to_size(self_size), grad_tensor1, grad_tensor2, None
487 return result, backward
488
489 def _autocast_to_full_precision(self, cuda_enabled : bool, cpu_enabled : bool):
490 self_dtype = self.dtype
491 def backward(grad_output):
492 return grad_output.to(self_dtype), None, None
493
494 return torch._autocast_to_full_precision(self, cuda_enabled, cpu_enabled), backward
495
496 def _autocast_to_reduced_precision(self,
497 cuda_enabled : bool,
498 cpu_enabled : bool,
499 cuda_dtype : int,
500 cpu_dtype : int):
501 self_dtype = self.dtype
502 def backward(grad_output):
503 return grad_output.to(self_dtype), None, None, None, None
504
505 return torch._autocast_to_reduced_precision(self, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype), backward
506
507 def _dim_arange(like,
508 dim: int):
509 def backward(grad_output):
510 return None, None
511
512 return torch._dim_arange(like, dim), backward
513
514 def contiguous(self, *, memory_format: int=0):
515 def backward(grad_output):
516 return grad_output, None
517
518 return self.contiguous(memory_format=memory_format), backward
519
520 def dot(self, tensor):
521 def backward(grad_output):
522 return grad_output * tensor, grad_output * self
523
524 return torch.dot(self, tensor), backward
525
526 def erf(self):
527 def backward(grad_output):
528 # Precomputed constant C = 2.0 / math.sqrt(math.pi)
529 C = 1.1283791670955126
530 return C * torch.exp(- self * self) * grad_output
531
532 return torch.erf(self), backward
533
534 def expand(self,
535 size: List[int],
536 *,
537 implicit: bool=False):
538 result = torch.expand(self, size, implicit=implicit)
539 self_size = torch._size_if_not_equal(self.size(), result.size())
540
541 def backward(grad_output):
542 return grad_output._grad_sum_to_size(self_size), None, None
543
544 return result, backward
545
546 def expand_as(self, other):
547 result = torch.expand_as(self, other)
548 self_size = torch._size_if_not_equal(self.size(), result.size())
549
550 def backward(grad_output):
551 return grad_output._grad_sum_to_size(self_size), None
552
553 return result, backward
554
555 def full_like(self,
556 fill_value: float):
557 def backward(grad_output):
558 return None, None
559
560 return torch.full_like(self, fill_value, memory_format=1), backward
561
562 def lerp_0(self,
563 end,
564 weight: number):
565 result = torch.lerp(self, end, weight)
566 self_size = torch._size_if_not_equal(self.size(), result.size())
567 end_size = torch._size_if_not_equal(end.size(), result.size())
568
569 def backward(grad_output):
570 grad_self = (grad_output * (1 - float(weight)))._grad_sum_to_size(self_size)
571 grad_end = (grad_output * float(weight))._grad_sum_to_size(end_size)
572 return grad_self, grad_end, None
573 return result, backward
574
575 def lerp_1(self,
576 end,
577 weight):
578 result = torch.lerp(self, end, weight)
579 self_size = torch._size_if_not_equal(self.size(), result.size())
580 end_size = torch._size_if_not_equal(end.size(), result.size())
581 weight_size = torch._size_if_not_equal(weight.size(), result.size())
582
583 def backward(grad_output):
584 grad_self = (grad_output * (1 - weight))._grad_sum_to_size(self_size)
585 grad_end = (grad_output * weight)._grad_sum_to_size(end_size)
586 grad_weight = (grad_output * (end - self))._grad_sum_to_size(weight_size)
587 return grad_self, grad_end, grad_weight
588
589 return result, backward
590
591 def reshape(self,
592 shape: List[int]):
593 self_size = self.size()
594
595 def backward(grad_output):
596 return grad_output.reshape(self_size), None
597
598 return torch.reshape(self, shape), backward
599
600 def split(self,
601 split_size: int,
602 dim: int):
603 def backward(grad_outputs: List[Tensor]):
604 grad_self = torch.cat(grad_outputs, dim)
605 return grad_self, None, None
606
607 return torch.split(self, split_size, dim), backward
608
609 def split_with_sizes(self,
610 split_sizes: List[int],
611 dim: int):
612 def backward(grad_outputs: List[Tensor]):
613 size = len(grad_outputs)
614 grad_self = torch.cat(grad_outputs, dim)
615 return grad_self, None, None
616
617 return torch.split_with_sizes(self, split_sizes, dim), backward
618
619 def stack(tensors: List[Tensor],
620 dim: int=0):
621 def backward(grad_output):
622 grad_tensors = torch.unbind(grad_output, dim)
623 return grad_tensors, None
624
625 return torch.stack(tensors, dim), backward
626
627 def unbind(self,
628 dim: int):
629 def backward(grad_outputs: List[Tensor]):
630 grad_self = torch.stack(grad_outputs, dim)
631 return grad_self, None
632
633 return torch.unbind(self, dim), backward
634
635 def cat(tensors: List[Tensor],
636 dim: int):
637 size = len(tensors)
638 split_sizes = [0] * size
639 for i in range(size):
640 if tensors[i].size() != [0]:
641 split_sizes[i] = tensors[i].size()[dim]
642
643 def backward(grad_output):
644 grad_tensors = torch.split_with_sizes(grad_output, split_sizes, dim)
645 return grad_tensors, None
646
647 return torch.cat(tensors, dim), backward
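        # Example (illustrative): cat([A, B], 0) with A of size [2, 3] and B of size
        # [4, 3] records split_sizes = [2, 4], so backward splits the [6, 3] gradient
        # back into chunks of sizes [2, 3] and [4, 3].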
648
649 def index(self,
650 indices: List[Tensor]):
651 def backward(grad_output):
652 grad_self = torch.zeros_like(self, memory_format=1).index_put_(indices, grad_output, True)
653 return grad_self, None
654
655 return torch.index(self, indices), backward
656
657 def meshgrid(tensors: List[Tensor]):
658 size = len(tensors)
659 sizes = [0] * size
660 for i in range(size):
661 if tensors[i].dim() != 0:
662 sizes[i] = tensors[i].size()[0]
663 def backward(grad_outputs: List[Tensor]):
664 grads_tensors = []
665 for i in range(size):
666 view_shape = [1] * size
667 if sizes[i] == 0:
668 view_shape[i] = 1
669 grads_tensors.append((grad_outputs[i]._grad_sum_to_size(view_shape)).reshape(()))
670 else:
671 view_shape[i] = sizes[i]
672 grads_tensors.append((grad_outputs[i]._grad_sum_to_size(view_shape)).reshape([sizes[i]]))
673 return grads_tensors
674 return torch.meshgrid(tensors), backward
675
676 def mv(self, vec):
677 def backward(grad_output):
678 return grad_output.ger(vec), self.t().mv(grad_output)
679
680 return torch.mv(self, vec), backward
681
682 def nonzero(self):
683 def backward(grad_output):
684 return None
685
686 return torch.nonzero(self), backward
687
688 def ones_like(self):
689 def backward(grad_output):
690 return None
691
692 return torch.ones_like(self, memory_format=1), backward
693
694 def pow_0(self,
695 exponent: number):
696 def backward(grad_output):
697 if float(exponent) == 0.0:
698 grad_self = torch.zeros_like(self, memory_format=1)
699 else:
700 grad_self = grad_output * exponent * torch.pow(self, float(exponent) - 1)
701 return grad_self, None
702
703 return torch.pow(self, exponent), backward
704
705 def pow_1(self, exponent):
706 result = torch.pow(self, exponent)
707 self_size = torch._size_if_not_equal(self.size(), result.size())
708 exponent_size = torch._size_if_not_equal(exponent.size(), result.size())
709
710 def backward(grad_output):
711 grad_self = torch.where(exponent == 0.0, torch.zeros_like(self, memory_format=1), grad_output * exponent * torch.pow(self, exponent - 1))._grad_sum_to_size(self_size)
712 grad_exponent = (grad_output * torch.pow(self, exponent) * torch.log(self))._grad_sum_to_size(exponent_size)
713 return grad_self, grad_exponent
714
715 return result, backward
716
717 def pow_2(self: number,
718 exponent):
719 def backward(grad_output):
720 grad_exponent = grad_output * torch.pow(self, exponent) * torch.log(float(self))
721 return None, grad_exponent
722
723 return torch.pow(self, exponent), backward
724
725 def rsub_0(self,
726 other,
727 alpha: number):
728 result = torch.rsub(self, other, alpha=alpha)
729 self_size = torch._size_if_not_equal(self.size(), result.size())
730 other_size = torch._size_if_not_equal(other.size(), result.size())
731 def backward(grad_output):
732 grad_self = (- grad_output * alpha)._grad_sum_to_size(self_size)
733 grad_other = (grad_output)._grad_sum_to_size(other_size)
734 return grad_self, grad_other, None
735
736 return result, backward
737
738 def rsub_1(self,
739 other: number,
740 alpha: number):
741 def backward(grad_output):
742 grad_self = (- grad_output * alpha)
743 return grad_self, None, None
744
745 return torch.rsub(self, other, alpha), backward
746
747 def sqrt(self):
748 result = torch.sqrt(self)
749 def backward(grad_output):
750 return grad_output / (2 * result)
751
752 return result, backward
753
754 def t(self):
755 def backward(grad_output):
756 return torch.t(grad_output)
757
758 return torch.t(self), backward
759
760 def to_0(self,
761 device: Optional[Device],
762 dtype: Optional[int],
763 non_blocking: bool,
764 copy: bool):
765 self_device = self.device
766 self_dtype = self.dtype
767 if device is not None:
768 result = self.to(device, dtype=dtype, non_blocking=non_blocking, copy=copy)
769 else:
770 result = self.to(dtype, non_blocking=non_blocking, copy=copy)
771 def backward(grad_output):
772 grad_self = grad_output.to(self_device, dtype=self_dtype, non_blocking=non_blocking, copy=copy)
773 return grad_self, None, None, None, None
774
775 return result, backward
776
777
778 def to_1(self,
779 dtype: int,
780 non_blocking: bool,
781 copy: bool):
782 self_dtype = self.dtype
783 def backward(grad_output):
784 grad_self = grad_output.to(self_dtype, non_blocking, copy)
785 return grad_self, None, None, None
786
787 return self.to(dtype=dtype, non_blocking=non_blocking, copy=copy), backward
788
789 def to_2(self,
790 other,
791 non_blocking: bool,
792 copy: bool):
793 def backward(grad_output):
794 grad_self = grad_output.to(self, non_blocking, copy)
795 return grad_self, None, None, None
796
797 return self.to(other, non_blocking=non_blocking, copy=copy), backward
798
799 def transpose(self,
800 dim0: int,
801 dim1: int):
802 def backward(grad_output):
803 return torch.transpose(grad_output, dim0, dim1), None, None
804
805 return torch.transpose(self, dim0, dim1), backward
806
807 def view(self,
808 size: List[int]):
809 self_size = self.size()
810 def backward(grad_output):
811 return grad_output.reshape(self_size), None
812
813 return torch.view(self, size), backward
814 )",
815 R"(
816 def AD_sizes_if_not_equal_multi_0(t1, t2, res):
817 return torch._size_if_not_equal(t1.size(), res.size()), torch._size_if_not_equal(t2.size(), res.size())
818
819 def mul_0(self, other):
820 result = self * other
821 self_size, other_size = AD_sizes_if_not_equal_multi_0(self, other, result)
822
823 def backward(grad_output):
824 grad_self = (grad_output * other)._grad_sum_to_size(self_size)
825 grad_other = (grad_output * self)._grad_sum_to_size(other_size)
826 return grad_self, grad_other
827
828 return result, backward
829
830 def mul_1(self, other: number):
831 def backward(grad_output):
832 return grad_output * other, None
833 return self * other, backward
834
835 def div_0(self, other):
836 result = self / other
837 self_size, other_size = AD_sizes_if_not_equal_multi_0(self, other, result)
838
839 def backward(grad_output):
840 grad_self = (grad_output / other)._grad_sum_to_size(self_size)
841 grad_other = (-grad_output * self / (other * other))._grad_sum_to_size(other_size)
842 return grad_self, grad_other
843
844 return result, backward
845
846 def div_1(self, other: number):
847 def backward(grad_output):
848 return grad_output / other, None
849 return self / other, backward
850
851 def div_2(self, other, *, rounding_mode: Optional[str]):
852 result = torch.div(self, other, rounding_mode=rounding_mode)
853 self_size, other_size = AD_sizes_if_not_equal_multi_0(self, other, result)
854 def backward(grad_output):
855 if rounding_mode is None:
856 grad_self = (grad_output / other)._grad_sum_to_size(self_size)
857 grad_other = (-grad_output * self / (other * other))._grad_sum_to_size(other_size)
858 else:
859 grad_self = torch.zeros_like(self)
860 grad_other = torch.zeros_like(other)
861
862 return grad_self, grad_other, None
863
864 return result, backward
865
866 def div_3(self, other: number, *, rounding_mode: Optional[str]):
867 result = torch.div(self, other, rounding_mode=rounding_mode)
868 def backward(grad_output):
869 if rounding_mode is None:
870 grad_self = (grad_output / other)
871 else:
872 grad_self = torch.zeros_like(self, memory_format=1)
873 return grad_self, None, None
874 return result, backward
875
876 def max(self, other):
877 result = torch.max(self, other)
878 self_size, other_size = AD_sizes_if_not_equal_multi_0(self, other, result)
879
880 def backward(grad_output):
881 grad_self = (grad_output * (self > other).type_as(grad_output))._grad_sum_to_size(self_size)
882 grad_other = (grad_output * (other > self).type_as(grad_output))._grad_sum_to_size(other_size)
883 return grad_self, grad_other
884
885 return result, backward
886
887 def min(self, other):
888 def backward(grad_output):
889 grad_self = (grad_output * (self < other).type_as(grad_output))._grad_sum_to_size(self.size())
890 grad_other = (grad_output * (other < self).type_as(grad_output))._grad_sum_to_size(other.size())
891 return grad_self, grad_other
892
893 return torch.min(self, other), backward
894
895 def sigmoid(self):
896 result = torch.sigmoid(self)
897 def backward(grad_output):
898 return (1 - result) * result * grad_output
899
900 return result, backward
901
902 # Share backward with threshold
903 def relu(self):
904 result = torch.relu(self)
905 def backward(grad_output):
906 return grad_output * (result > 0).type_as(result)
907
908 return result, backward
909
910 def relu6(self):
911 result = torch.relu6(self)
912 def backward(grad_output):
913 return grad_output * ((result > 0) & (result < 6.0))
914
915 return result, backward
916
917 def leaky_relu(self, negative_slope: number):
918 result = torch.leaky_relu(self, negative_slope)
919 def backward(grad_output):
920 return grad_output * torch.where(self > 0, 1.0, negative_slope).type_as(result), None
921 return result, backward
922
923 def gelu(self : Tensor, *, approximate : str):
924 result = torch.gelu(self, approximate=approximate)
925 def backward(grad_output):
926 return torch.gelu_backward(grad_output, self, approximate=approximate), None
927 return result, backward
928
929 def silu(self):
930 result = torch.silu(self)
931 def backward(grad_output):
932 input_sigmoid = torch.sigmoid(self)
933 return grad_output * (input_sigmoid * (1 + self * (1 - input_sigmoid)))
934 return result, backward
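        # Derivation note (added): d/dx (x * sigmoid(x))
        #   = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
        #   = sigmoid(x) * (1 + x * (1 - sigmoid(x))),
        # which is the expression used in the backward above.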
935
936 def hardswish(self):
937 result = torch.hardswish(self)
938 def backward(grad_output):
939 m = (self > 3.).type_as(result)
940 m = torch.where((self >= -3.) & (self <= 3.), self / 3. + .5, m)
941 return grad_output * m
942 return result, backward
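        # Derivation note (added): hardswish(x) = x * relu6(x + 3) / 6, so its
        # derivative is 0 for x < -3, 1 for x > 3, and x / 3 + 0.5 in between,
        # which is the mask built above.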
943
944 def hardsigmoid(self):
945 result = torch.hardsigmoid(self)
946 def backward(grad_output):
947 m = (self > -3.) & (self < 3.)
948 lhs = grad_output * (1.0 / 6.0)
949 return torch.where(m, lhs, m.type_as(self))
950 return result, backward
951
952 def erfc(self):
953 def backward(grad_output):
954 # Precomputed constant C = -2.0 / math.sqrt(math.pi)
955 C = -1.1283791670955126
956 return C * torch.exp(-self * self) * grad_output
957
958 return torch.erfc(self), backward
959
960 def exp(self):
961 result = torch.exp(self)
962 def backward(grad_output):
963 return grad_output * result
964
965 return result, backward
966
967 def neg(self):
968 def backward(grad_output):
969 return grad_output.neg()
970
971 return torch.neg(self), backward
972
973 def where(condition, self, other):
974 result = torch.where(condition, self, other)
975 self_size, other_size = AD_sizes_if_not_equal_multi_0(self, other, result)
976 def backward(grad_output):
977 grad_self = (grad_output * condition.type_as(grad_output))._grad_sum_to_size(self_size)
978 grad_other = (grad_output * (condition.bitwise_not()).type_as(grad_output))._grad_sum_to_size(other_size)
979 return None, grad_self, grad_other
980
981 return result, backward
982
983 def type_as(self, other):
984 def backward(grad_output):
985 return grad_output.type_as(self), None
986
987 return torch.type_as(self, other), backward
988
989 def unsqueeze(self, dim: int):
990 def backward(grad_output):
991 return grad_output.squeeze(dim), None
992
993 return torch.unsqueeze(self, dim), backward
994
995 def abs(self):
996 def backward(grad_output):
997 return grad_output * self.sign()
998
999 return torch.abs(self), backward
1000
1001 def acos(self):
1002 def backward(grad_output):
1003 return grad_output * -((-self * self + 1).rsqrt())
1004
1005 return torch.acos(self), backward
1006
1007 def asin(self):
1008 def backward(grad_output):
1009 return grad_output * (-self * self + 1).rsqrt()
1010
1011 return torch.asin(self), backward
1012
1013 def atan(self):
1014 def backward(grad_output):
1015 return grad_output / (self * self + 1)
1016
1017 return torch.atan(self), backward
1018
1019 def ceil(self):
1020 def backward(grad_output):
1021 return torch.zeros_like(grad_output, memory_format=1)
1022
1023 return torch.ceil(self), backward
1024
1025 def cos(self):
1026 def backward(grad_output):
1027 return grad_output * -self.sin()
1028
1029 return torch.cos(self), backward
1030
1031 def cosh(self):
1032 def backward(grad_output):
1033 return grad_output * self.sinh()
1034
1035 return torch.cosh(self), backward
1036
1037 def expm1(self):
1038 result = torch.expm1(self)
1039 def backward(grad_output):
1040 return grad_output * (result + 1)
1041
1042 return result, backward
1043
1044 def floor(self):
1045 def backward(grad_output):
1046 return torch.zeros_like(grad_output, memory_format=1)
1047
1048 return torch.floor(self), backward
1049
1050 def frac(self):
1051 def backward(grad_output):
1052 return grad_output
1053
1054 return torch.frac(self), backward
1055
1056 def log(self):
1057 def backward(grad_output):
1058 return grad_output.div(self)
1059
1060 return torch.log(self), backward
1061
1062 def log10(self):
1063 def backward(grad_output):
1064 return grad_output / (self * 2.3025850929940456)
1065
1066 return torch.log10(self), backward
1067
1068 def log1p(self):
1069 def backward(grad_output):
1070 return grad_output / (self + 1)
1071
1072 return torch.log1p(self), backward
1073
1074 def log2(self):
1075 def backward(grad_output):
1076 return grad_output / (self * 0.6931471805599453)
1077
1078 return torch.log2(self), backward
1079
1080 # TODO: Fix rand_like to match expected format
1081 # def rand_like(self, *, memory_format: Optional[int]):
1082 # def backward(grad_output):
1083 # return None
1084
1085 # return torch.rand_like(self, memory_format=memory_format), backward
1086
1087 def reciprocal(self):
1088 result = torch.reciprocal(self)
1089 def backward(grad_output):
1090 return -grad_output * result * result
1091
1092 return result, backward
1093
1094 def round(self):
1095 def backward(grad_output):
1096 return torch.zeros_like(grad_output, memory_format=1)
1097
1098 return torch.round(self), backward
1099
1100 def rsqrt(self):
1101 result = torch.rsqrt(self)
1102 def backward(grad_output):
1103 return -grad_output * result * result * result / 2
1104
1105 return result, backward
1106
1107 def sin(self):
1108 def backward(grad_output):
1109 return grad_output * self.cos()
1110
1111 return torch.sin(self), backward
1112
1113 def sinh(self):
1114 def backward(grad_output):
1115 return grad_output * self.cosh()
1116
1117 return torch.sinh(self), backward
1118
1119 def tan(self):
1120 result = torch.tan(self)
1121 def backward(grad_output):
1122 return grad_output * (1. + result * result)
1123
1124 return result, backward
1125
1126 def trunc(self):
1127 def backward(grad_output):
1128 return torch.zeros_like(grad_output, memory_format=1)
1129
1130 return torch.trunc(self), backward
1131
1132 def _grad_sum_to_size(self,
1133 size: Optional[List[int]]):
1134 result = torch._grad_sum_to_size(self, size)
1135 self_size = torch._size_if_not_equal(self.size(), result.size())
1136
1137 def backward(grad_output):
1138 if self_size is None:
1139 grad_input = grad_output
1140 else:
1141 grad_input = grad_output.expand(self_size)
1142 return grad_input, None
1143
1144 return result, backward
1145 )",
1146 R"(
1147 def batch_norm(input : Tensor,
1148 weight : Optional[Tensor],
1149 bias : Optional[Tensor],
1150 running_mean : Optional[Tensor],
1151 running_var : Optional[Tensor],
1152 training : bool,
1153 momentum : float,
1154 eps : float,
1155 cudnn_enabled : bool):
1156
1157 output, save1, save2, reserve, impl_idx = torch._batch_norm_impl_index(
1158 input, weight, bias, running_mean, running_var, training,
1159 momentum, eps, cudnn_enabled)
1160 has_weight = weight is not None
1161 has_bias = bias is not None
1162
1163 def backward(grad_output):
1164 dinput, dweight, dbias = torch._batch_norm_impl_index_backward(
1165 impl_idx, input, grad_output, weight, running_mean, running_var,
1166 save1, save2, training, eps, [True, has_weight, has_bias], reserve)
1167 return dinput, dweight, dbias, None, None, None, None, None, None
1168
1169 return output, backward
1170
1171 def layer_norm(input : Tensor,
1172 normalized_shape : List[int],
1173 weight : Optional[Tensor],
1174 bias : Optional[Tensor],
1175 eps : float,
1176 cudnn_enable : bool):
1177
1178 output, mean, rstd = torch.native_layer_norm(input, normalized_shape, weight, bias, eps)
1179
1180 def backward(grad_output):
1181 output_mask = [True, weight is not None, bias is not None]
1182 grad_input, grad_weight, grad_bias = torch.native_layer_norm_backward(grad_output, input, normalized_shape, mean, rstd, weight, bias, output_mask)
1183 return grad_input, None, grad_weight, grad_bias, None, None
1184 return output, backward
1185
1186 def dropout(input,
1187 p: float,
1188 train: bool):
1189 # if `train == false` we need to set `p1m` to 0 so `scale == 1`
1190 p1m = (1. - p) * float(train)
1191 scale = 1. / (float(p1m == 0.) + p1m)
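            # Worked example (added; numbers are illustrative): p = 0.2, train = True
            # gives p1m = 0.8 and scale = 1.25; train = False gives p1m = 0. and
            # scale = 1., so the gradient is passed through unscaled.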
1192 res,mask = torch.native_dropout(input, p, train)
1193
1194 def backward(grad_output):
1195 grad_input = torch.native_dropout_backward(grad_output, mask, scale)
1196 return grad_input, None, None
1197 return res, backward
1198
1199 def embedding(weight,
1200 indices,
1201 padding_idx: int,
1202 scale_grad_by_freq: bool,
1203 sparse: bool):
1204 weight_size_0 = weight.size()[0]
1205 def backward(grad_output):
1206 grad_weight = torch.embedding_backward(grad_output, indices, weight_size_0, padding_idx, scale_grad_by_freq, sparse)
1207 return grad_weight, None, None, None, None
1208
1209 return torch.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse), backward
1210
1211 def log_softmax(self, dim: int, dtype: Optional[int]):
1212 result = torch.log_softmax(self, dim, dtype)
1213 def backward(grad_output):
1214 grad_self = torch._log_softmax_backward_data(grad_output, result, dim, self.dtype)
1215 return grad_self, None, None
1216
1217 return result, backward
1218
1219 def nll_loss(self, target, weight: Optional[Tensor], reduction: int, ignore_index: int):
1220 result, total_weight = torch.nll_loss_forward(self, target, weight, reduction, ignore_index)
1221 def backward(grad):
1222 return torch.nll_loss_backward(grad, self, target, weight, reduction, ignore_index, total_weight), None, None, None, None
1223 return result, backward
1224
1225 def softmax(self, dim: int, dtype: Optional[int]):
1226 result = torch.softmax(self, dim, dtype)
1227 def backward(grad_output):
1228 grad_self = torch._softmax_backward_data(grad_output, result, dim, self.dtype)
1229 return grad_self, None, None
1230
1231 return result, backward
1232 )",
1233 R"(
1234 def AD_adaptive_avg_pool3d_backward(grad,
1235 self,
1236 output_size: List[int]):
1237 if output_size[0] == 1 and output_size[1] == 1 and output_size[2] == 1:
1238 self_size = self.size()
1239 grad_self = grad.expand(self.size()) / (self_size[-1] * self_size[-2] * self_size[-3])
1240 else:
1241 grad_self = torch._adaptive_avg_pool3d_backward(grad, self)
1242
1243 return grad_self
1244
1245 def AD_adaptive_avg_pool2d_backward(grad,
1246 self,
1247 output_size: List[int]):
1248 if output_size[0] == 1 and output_size[1] == 1:
1249 self_size = self.size()
1250 grad_self = grad.expand(self.size()) / (self_size[-1] * self_size[-2])
1251 else:
1252 grad_self = torch._adaptive_avg_pool2d_backward(grad, self)
1253
1254 return grad_self
1255
1256 def AD_adaptive_avg_pool1d_backward(grad,
1257 input,
1258 output_size: List[int]):
1259 output_size_2d = [1, output_size[0]]
1260 grad_input = AD_adaptive_avg_pool2d_backward(grad.unsqueeze(2), input.unsqueeze(2), output_size_2d).squeeze(2)
1261 return grad_input
1262
1263 def adaptive_avg_pool1d(self,
1264 output_size: List[int]):
1265 def backward(grad_output):
1266 grad_self = AD_adaptive_avg_pool1d_backward(grad_output, self, output_size)
1267 return grad_self, None
1268
1269 return torch.adaptive_avg_pool1d(self, output_size), backward
1270
1271 def adaptive_avg_pool2d(self,
1272 output_size: List[int]):
1273 def backward(grad_output):
1274 # self is used in backward, no need to pass in its size explicitly
1275 grad_self = AD_adaptive_avg_pool2d_backward(grad_output, self, output_size)
1276 return grad_self, None
1277 return torch.adaptive_avg_pool2d(self, output_size), backward
1278
1279 def adaptive_avg_pool3d(self,
1280 output_size: List[int]):
1281 def backward(grad_output):
1282 grad_self = AD_adaptive_avg_pool3d_backward(grad_output, self, output_size)
1283 return grad_self, None
1284
1285 return torch.adaptive_avg_pool3d(self, output_size), backward
1286
1287 def avg_pool2d(self,
1288 kernel_size: List[int],
1289 stride: List[int],
1290 padding: List[int],
1291 ceil_mode: bool,
1292 count_include_pad: bool,
1293 divisor_override: Optional[int]):
1294 def backward(grad_output):
1295 grad_self = torch.avg_pool2d_backward(grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
1296 return grad_self, None, None, None, None, None, None
1297
1298 return torch.avg_pool2d(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override), backward
1299
1300 def max_pool2d(self,
1301 kernel_size: List[int],
1302 stride: List[int],
1303 padding: List[int],
1304 dilation: List[int],
1305 ceil_mode: bool):
1306 output, indices = torch.max_pool2d_with_indices(self, kernel_size, stride, padding, dilation, ceil_mode)
1307 def backward(grad_output):
1308 grad_self = torch.max_pool2d_with_indices_backward(grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices)
1309 return grad_self, None, None, None, None, None
1310 return output, backward
1311
1312 def max_pool2d_with_indices(self,
1313 kernel_size: List[int],
1314 stride: List[int],
1315 padding: List[int],
1316 dilation: List[int],
1317 ceil_mode: bool):
1318 output, indices = torch.max_pool2d_with_indices(self, kernel_size, stride, padding, dilation, ceil_mode)
1319 def backward(grad_output):
1320 grad_self = torch.max_pool2d_with_indices_backward(grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices)
1321 return grad_self, None, None, None, None, None
1322 return output, indices, backward
1323 )",
1324 R"(
1325 def AD_sizes_if_not_equal_multi_1(t1, t2, res):
1326 return torch._size_if_not_equal(t1.size(), res.size()), torch._size_if_not_equal(t2.size(), res.size())
1327
1328 def add_0(self,
1329 other,
1330 *,
1331 alpha: number):
1332 result = torch.add(self, other, alpha=alpha)
1333 self_size, other_size = AD_sizes_if_not_equal_multi_1(self, other, result)
1334 def backward(grad_output):
1335 grad_other = (grad_output * alpha)._grad_sum_to_size(other_size)
1336 grad_self = (grad_output)._grad_sum_to_size(self_size)
1337 return grad_self, grad_other, None
1338 return result, backward
1339
1340 def add_1(self,
1341 other: number,
1342 alpha: number):
1343 def backward(grad_output):
1344 return grad_output, None, None
1345 return torch.add(self, other, alpha=alpha), backward
1346
1347 def sub_0(self,
1348 other,
1349 *,
1350 alpha: number):
1351 result = torch.sub(self, other, alpha=alpha)
1352 self_size, other_size = AD_sizes_if_not_equal_multi_1(self, other, result)
1353 def backward(grad_output):
1354 grad_other = (-grad_output * alpha)._grad_sum_to_size(other_size)
1355 grad_self = (grad_output)._grad_sum_to_size(self_size)
1356 return grad_self, grad_other, None
1357 return result , backward
1358
1359 def sub_1(self,
1360 other: number,
1361 alpha: number):
1362 def backward(grad_output):
1363 return grad_output, None, None
1364 return torch.sub(self, other, alpha=alpha), backward
1365
1366 def threshold(self,
1367 threshold: number,
1368 value: number):
1369 def backward(grad_output):
1370 mask = (self >= threshold).type_as(self)
1371 return grad_output * mask, None, None
1372 return torch.threshold(self, threshold, value), backward
1373
1374 def softplus(self,
1375 beta: number,
1376 threshold: number):
1377 result = torch.softplus(self, beta, threshold)
1378 def backward(grad_output):
1379 z = torch.exp(result * beta)
1380 return torch.where((result * beta) > threshold, grad_output, grad_output * (z - 1.) / z), None, None
1381 return result, backward
1382
1383 def fmod(self,
1384 other: number):
1385 def backward(grad_output):
1386 return grad_output, None
1387 return torch.fmod(self, other), backward
1388
1389 def remainder(self,
1390 other: number):
1391 def backward(grad_output):
1392 return grad_output, None
1393 return torch.remainder(self, other), backward
1394
1395 def addmm(self,
1396 mat1,
1397 mat2,
1398 *,
1399 beta: number,
1400 alpha: number):
1401 result = torch.addmm(self, mat1, mat2, beta=beta, alpha=alpha)
1402 self_size = torch._size_if_not_equal(self.size(), result.size())
1403 def backward(grad_output):
1404 self_grad = (grad_output * beta)._grad_sum_to_size(self_size)
1405 mat1_grad = grad_output.mm(mat2.t()) * alpha
1406 mat2_grad = mat1.t().mm(grad_output) * alpha
1407 return self_grad, mat1_grad, mat2_grad, None, None
1408 return result, backward
1409
1410 # Comparison operators
1411 def lt(self, other: number):
1412 def backward(grad_output):
1413 return None, None
1414 return torch.lt(self, other), backward
1415
1416 def le(self, other: number):
1417 def backward(grad_output):
1418 return None, None
1419 return torch.le(self, other), backward
1420
1421 def gt(self, other: number):
1422 def backward(grad_output):
1423 return None, None
1424 return torch.gt(self, other), backward
1425
1426 def ge(self, other: number):
1427 def backward(grad_output):
1428 return None, None
1429 return torch.ge(self, other), backward
1430
1431 def eq(self, other: number):
1432 def backward(grad_output):
1433 return None, None
1434 return torch.eq(self, other), backward
1435
1436 def ne(self, other: number):
1437 def backward(grad_output):
1438 return None, None
1439 return torch.ne(self, other), backward
1440
1441 def hardshrink(self, lambd: number):
1442 def backward(grad_output):
1443 mask = ((self > lambd) | (self < -lambd))
1444 return grad_output * mask, None
1445 return torch.hardshrink(self, lambd=lambd), backward
1446
1447 def hardtanh(self, min_val: number, max_val: number):
1448 def backward(grad_output):
1449 mask = ((self >= min_val) * (self <= max_val))
1450 return grad_output * mask, None, None
1451 return torch.hardtanh(self, min_val=min_val, max_val=max_val), backward
1452
1453 def clamp_1(self,
1454 min: Optional[number],
1455 max: Optional[number]):
1456 def backward(grad_output):
1457 if min is not None and max is not None:
1458 mask = ((self >= float(min)) * (self <= float(max))).type_as(self)
1459 return grad_output * mask, None, None
1460 elif min is not None:
1461 mask = (self >= float(min)).type_as(self)
1462 return grad_output * mask, None, None
1463 elif max is not None:
1464 mask = (self <= float(max)).type_as(self)
1465 return grad_output * mask, None, None
1466 else: #min is None and max is None
1467 return grad_output, None, None
1468 return torch.clamp(self, min=min, max=max), backward
1469
1470 def clamp_2(self,
1471 min: Optional[Tensor],
1472 max: Optional[Tensor]):
1473 def backward(grad_output):
1474 if min is not None and max is not None:
1475 mask = ((self >= min) * (self <= max)).type_as(self)
1476 return grad_output * mask, None, None
1477 elif min is not None:
1478 mask = (self >= min).type_as(self)
1479 return grad_output * mask, None, None
1480 elif max is not None:
1481 mask = (self <= max).type_as(self)
1482 return grad_output * mask, None, None
1483 else: #min is None and max is None
1484 return grad_output, None, None
1485 return torch.clamp(self, min=min, max=max), backward
1486 )"};

std::unordered_map<std::string, GradientPair> schema_to_graphs;

// This map is a workaround to cache compiled gradient pairs. Ideally these
// graphs should be compiled only once and saved in the Operator structure.
// This should be done along with merging into native_functions.yaml.
std::unordered_map<const FunctionSchema*, GradientPair> cached_gradient_pairs;

// CompilationUnit that holds all these Functions and keeps them alive.
CompilationUnit compilation_unit;
} // anonymous namespace

static std::pair<std::shared_ptr<Graph>, Value*> extractClosure(
    Value* closure) {
  TORCH_CHECK(
      closure->node()->kind() == prim::TupleConstruct,
      "closure must be a literal tuple construct");
  Value* fn = closure->node()->inputs().at(0);
  Value* context = closure->node()->inputs().at(1);

  TORCH_CHECK(
      fn->node()->kind() == prim::Closure,
      "closure tuple must contain a prim::Closure");
  return std::make_pair(fn->node()->g(attr::Subgraph), context);
}

static Argument originalReturnType(const TupleTypePtr& tup) {
  TORCH_CHECK(tup->elements().size() > 1);
  if (tup->elements().size() == 2)
    return Argument("", tup->elements().at(0));
  std::vector<TypePtr> types = tup->elements().vec();
  types.pop_back();
  return Argument("", TupleType::create(std::move(types)));
}

// In TorchScript AD formulas, we define {func_0, func_1, ...} as
// overloaded functions of `func`.
// Remove the numeric suffix before adding the schema string to the
// schema_to_graphs map.
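// For example (illustrative only), the schema produced for the `mul_0` formula
// is stored under the canonical schema string of `mul`, so lookups by the real
// ATen schema succeed.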
static std::string overloadedSchemaString(const FunctionSchema& schema) {
  const auto& schema_name = schema.name();
  auto pos = schema_name.find_last_of('_');
  auto schema_name_suffix = schema_name.substr(pos + 1);
  std::string schema_string = canonicalSchemaString(schema);
  if (!schema_name_suffix.empty() &&
      schema_name_suffix.find_first_not_of("0123456789") == std::string::npos) {
    schema_string.replace(
        schema_string.find(schema_name),
        schema_name.length(),
        schema_name.substr(0, pos));
  }

  return schema_string;
}

static bool isHelperFunction(const std::string& method_name) {
  std::string helper_prefix = "AD_";
  return method_name.compare(0, helper_prefix.length(), helper_prefix) == 0;
}

static void loadModule(const CompilationUnit& module) {
  for (const auto& method : module.get_functions()) {
    if (isHelperFunction(method->name()))
      continue;

    GradientPair pair;
    pair.forward = toGraphFunction(*method).graph();

    // lookup the backward function
    Node* forward_tuple = pair.forward->outputs().at(0)->node();

    if (forward_tuple->kind() != prim::TupleConstruct) {
      throw(
          ErrorReport(forward_tuple->sourceRange())
          << "gradient must return a literal tuple");
    }

    Value* context = nullptr;
    std::tie(pair.backward, context) =
        extractClosure(forward_tuple->inputs().back());

    // checks that num forward graph inputs equals num backward graph outputs
    TORCH_CHECK(
        pair.forward->inputs().size() ==
            unpackOutputs(pair.backward->outputs().vec()).size(),
        "The autodiff implementation of ",
        method->name(),
        " backward() returns an incorrect number of values: ",
        unpackOutputs(pair.backward->outputs().vec()).size(),
        " instead of ",
        pair.forward->inputs().size());

    // do surgery on the forward function to remove the closure tuple and
    // replace it with the context variable:
    //  backward = (<lambda>, context_tuple)
    //  return original, backward
    //  -----
    //  return original, context_tuple
    std::vector<Value*> new_inputs = forward_tuple->inputs().vec();
    new_inputs.back() = context;
    Value* new_tuple =
        pair.forward->appendNode(pair.forward->createTuple(new_inputs))
            ->output();
    pair.forward->eraseOutput(0);
    pair.forward->registerOutput(new_tuple);
    forward_tuple->destroy();

    // derive schema from original function's schema:
    const FunctionSchema& loaded_schema = method->getSchema();
    FunctionSchema actual_schema(
        Symbol::aten(loaded_schema.name()),
        loaded_schema.overload_name(),
        loaded_schema.arguments(),
        {originalReturnType(new_tuple->type()->expect<TupleType>())});

    // modify canonical string for function overloading
    // prefer not to modify the schema name
    auto schema_string = overloadedSchemaString(actual_schema);

    schema_to_graphs[schema_string] = std::move(pair);
  }
}

static void loadFunctions() {
  for (const std::string& str : functions) {
    compilation_unit.define(std::nullopt, str, nativeResolver(), nullptr);
  }
  loadModule(compilation_unit);
}

std::optional<GradientPair> gradientInfoForSchema(
    const FunctionSchema& schema) {
  std::lock_guard<std::mutex> guard(lock);
  if (schema_to_graphs.empty()) {
    loadFunctions();
  }
  auto cache_it = cached_gradient_pairs.find(&schema);
  if (cache_it != cached_gradient_pairs.end()) {
    return cache_it->second;
  } else {
    auto schema_str = canonicalSchemaString(schema);
    // For debugging AD change:
    // std::cout << "Looking for " << schema_str << std::endl;
    auto sym_script_it = schema_to_graphs.find(schema_str);

    if (sym_script_it != schema_to_graphs.end()) {
      cached_gradient_pairs.emplace_hint(
          cache_it, &schema, sym_script_it->second);
      return sym_script_it->second;
    }
  }
  return std::nullopt;
}

bool hasGradientInfoForSchema(const FunctionSchema& schema) {
  return gradientInfoForSchema(schema).has_value();
}

} // namespace torch::jit