xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/symbolic_script.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/runtime/symbolic_script.h>
2 
3 #include <torch/csrc/jit/frontend/ir_emitter.h>
4 #include <torch/csrc/jit/runtime/operator.h>
5 
6 namespace torch::jit {
7 namespace {
8 std::mutex lock;
9 const std::vector<std::string> functions = {
10     R"(
11         ####     HELPER FUNCTIONS           ###
12         ####     PREFIX: AD_                ###
13         ####     SCHEMA NOT SAVED IN CACHE  ###
14 
15         def AD_unsqueeze_multiple(t,
16                                   dims: List[int],
17                                   n_dims: int):
18             seen = [False] * n_dims
19             for i in range(len(dims)):
20                 seen[dims[i]] = True
21 
22             for d in range(n_dims):
23                 if seen[d]:
24                     t = t.unsqueeze(d)
25             return t
26 
27         def AD_sum_backward(grad,
28                             sizes: List[int],
29                             dims: Optional[List[int]],
30                             keepdim: bool):
31             if not keepdim and len(sizes) > 0:
32                 if dims is None:
33                     return grad.expand(sizes)
34                 elif len(dims) == 1:
35                     return grad.unsqueeze(dims[0]).expand(sizes)
36                 else:
37                     res = AD_unsqueeze_multiple(grad, dims, len(sizes))
38                     return res.expand(sizes)
39             else:
40                 return grad.expand(sizes)
41 
42         def AD_logsumexp_backward(grad, self, result,
43                                   dim: List[int],
44                                   keepdim: bool):
45             if not keepdim and self.dim() != 0:
46                 n_dims = len(self.size())
47                 grad = AD_unsqueeze_multiple(grad, dim, n_dims)
48                 result = AD_unsqueeze_multiple(result, dim, n_dims)
49             return grad * (self - result).exp()
50 
51         def mean_0(self, *, dtype: Optional[int]):
52             self_size = self.size()
53             self_numel = self.numel()
54             self_scalar_type = self.dtype
55             def backward(grad_output):
56                 return grad_output.expand(self_size).to(self_scalar_type) / self_numel, None
57 
58             return torch.mean(self, dtype=dtype), backward
59 
60         def mean_1(self,
61                    dim: Optional[List[int]],
62                    keepdim: bool,
63                    *,
64                    dtype: Optional[int]):
65             self_size = self.size()
66             self_scalar_type = self.dtype
67             def backward(grad_output):
68                 grad_self = AD_sum_backward(grad_output, self_size, dim, keepdim).to(self_scalar_type) / AD_safe_size(self_size, dim)
69                 return grad_self, None, None, None
70 
71             return torch.mean(self, dim, keepdim, dtype=dtype), backward
72 
73         def logsumexp(self,
74                       dim: List[int],
75                       keepdim: bool):
76             result = torch.logsumexp(self, dim, keepdim)
77             self_dim = self.dim()
78             def backward(grad_output):
79                 grad_self = AD_logsumexp_backward(grad_output, self, result, dim, keepdim)
80                 return grad_self, None, None
81 
82             return result, backward
83 
84         def AD_bool_to_int(b: bool):
85             # FIXME: torchscript: int - bool
86             if b:
87                 i = 1
88             else:
89                 i = 0
90             return i
91 
92         def AD_var_backward_0(grad, self, correction: number):
93             # FIXME: torchscript: div(float, float)
94             return  grad * (self - self.mean()) * 2.0 / (self.numel() - correction)
95 
96         def AD_safe_size(sizes: List[int],
97                          dims: Optional[List[int]]):
98             if len(sizes) == 0:
99                 return 1
100 
101             size = 1
102 
103             if dims is None:
104               for s in sizes:
105                 size *= s
106 
107             else:
108               for i in range(len(dims)):
109                   d = dims[i]
110                   size *= sizes[d]
111 
112             return size
113 
114         def AD_var_backward_1(grad,
115                               self,
116                               dim: List[int],
117                               correction: number,
118                               keepdim: bool):
119             if self.dim() == 0:
120                 return AD_var_backward_0(grad, self, correction)
121             self_size = self.size()
122             if not keepdim and self.dim() > 1:
123                 grad = AD_unsqueeze_multiple(grad, dim, len(self_size))
124 
125             # FIXME: torchscript: div(float, float)
126             return grad * (self - self.mean(dim, True)) * 2.0 / (AD_safe_size(self_size, dim) - correction)
127 
128         def AD_var_backward_2(grad,
129                               self,
130                               dim: Optional[List[int]],
131                               correction: Optional[number],
132                               keepdim: bool):
133             if correction is None:
134                 correction = 1
135             if self.dim() == 0 or dim is None:
136               return AD_var_backward_0(grad, self, correction)
137 
138             return AD_var_backward_1(grad, self, dim, correction, keepdim)
139 
140         def std_0(self,
141                   unbiased: bool=True):
142             std_out = torch.std(self, unbiased)
143             def backward(grad_output):
144                 correction = AD_bool_to_int(unbiased)
145                 grad_self = AD_var_backward_0(grad_output / (std_out * 2), self, correction)
146                 return grad_self, None
147 
148             return std_out, backward
149 
150         def std_1(self,
151                   dim: Optional[List[int]],
152                   unbiased: bool,
153                   keepdim: bool):
154             std_out = torch.std(self, dim, unbiased, keepdim)
155             def backward(grad_output):
156                 correction = AD_bool_to_int(unbiased)
157                 grad_self = AD_var_backward_2(grad_output / (std_out * 2), self, dim, correction, keepdim)
158                 return grad_self, None, None, None
159 
160             return std_out, backward
161 
162         def std_2(self,
163                   dim: Optional[List[int]],
164                   *,
165                   correction: Optional[number],
166                   keepdim: bool):
167             std_out = torch.std(self, dim, correction=correction, keepdim=keepdim)
168             def backward(grad_output):
169                 grad_self = AD_var_backward_2(grad_output / (std_out * 2), self, dim, correction, keepdim)
170                 return grad_self, None, None, None
171 
172             return std_out, backward
173 
174         def var_0(self,
175                   unbiased: bool=True):
176             def backward(grad_output):
177                 correction = AD_bool_to_int(unbiased)
178                 grad_self = AD_var_backward_0(grad_output, self, correction)
179                 return grad_self, None
180 
181             return torch.var(self, unbiased), backward
182 
183         def var_1(self,
184                   dim: Optional[List[int]],
185                   unbiased: bool,
186                   keepdim: bool):
187             def backward(grad_output):
188                 correction = AD_bool_to_int(unbiased)
189                 grad_self = AD_var_backward_2(grad_output, self, dim, correction, keepdim)
190                 return grad_self, None, None, None
191 
192             return torch.var(self, dim, unbiased, keepdim), backward
193 
194         def var_2(self,
195                   dim: Optional[List[int]],
196                   *,
197                   correction: Optional[number],
198                   keepdim: bool):
199             def backward(grad_output):
200                 grad_self = AD_var_backward_2(grad_output, self, dim, correction, keepdim)
201                 return grad_self, None, None, None
202 
203             return torch.var(self, dim, correction=correction, keepdim=keepdim), backward
204 
205         def tanh(self):
206             output = torch.tanh(self)
207             def backward(grad_output):
208                 return grad_output * (1 - output * output)
209 
210             return output, backward
211 
212         def AD_index_select_backward(grad,
213                                      dim: int,
214                                      indices,
215                                      sizes: List[int],
216                                      keepdim: bool):
217             if not keepdim and len(sizes) > 0:
218                 grad = grad.unsqueeze(dim)
219                 indices = indices.unsqueeze(dim)
220 
221             # FIXME: torchscript: torch.zeros(sizes, grad.options())
222             return torch.zeros(sizes).to(grad).scatter_(dim, indices, grad)
223 
224         # def topk(self,
225         #          k: int,
226         #          dim: int = -1,
227         #          largest: bool = True,
228         #          sorted: bool = True):
229         #     result0, result1 = torch.topk(self, k, dim, largest, sorted)
230         #     self_size = self.size()
231         #     def backward(grad_output):
232         #         grad_self = AD_index_select_backward(grad_output, dim, result1, self_size, True)
233         #         return grad_self, None, None, None, None
234 
235         #     return result0, result1, backward
236 
237         # def kthvalue(self,
238         #              k: int,
239         #              dim: int,
240         #              keepdim: bool):
241         #     result0, result1 = torch.kthvalue(self, k, dim, keepdim)
242         #     self_size = self.size()
243         #     def backward(grad_output):
244         #         grad_self = AD_index_select_backward(grad_output, dim, result1, self_size, keepdim)
245         #         return grad_self, None, None, None
246 
247         #     return result0, result1, backward
248 
249         def AD_mm_backward_self(grad, mat2):
250             return grad.mm(mat2.t())
251 
252         def AD_mm_backward_mat2(grad, self):
253             return self.t().mm(grad)
254 
255         def mm(self, mat2):
256             def backward(grad_output):
257                 grad_self = AD_mm_backward_self(grad_output, mat2)
258                 grad_mat2 = AD_mm_backward_mat2(grad_output, self)
259                 return grad_self, grad_mat2
260 
261             return torch.mm(self, mat2), backward
262 
263         def AD_permute_backward(grad,
264                                 fwd_dims: List[int]):
265             ndims = len(fwd_dims)
266             dims = [0] * ndims
267 
268             for i in range(ndims):
269                 dims[fwd_dims[i]] = i
270 
271             return grad.permute(dims)
272 
273         def permute(self,
274                     dims: List[int]):
275             def backward(grad_output):
276                 grad_self = AD_permute_backward(grad_output, dims)
277                 return grad_self, None
278 
279             return torch.permute(self, dims), backward
280 
281         def AD_select_backward(grad,
282                                input_sizes: List[int],
283                                dim: int,
284                                index: int):
285             # FIXME: torchscript: torch.zeros(sizes, grad.options())
286             grad_input = torch.zeros(input_sizes).to(grad)
287             grad_input.select(dim, index).copy_(grad)
288             return grad_input
289 
290         # TODO: fix torch.zeros(sizes, grad.options()) before enabling select, topk, kthvalue
291         # def select(self,
292         #            dim: int,
293         #            index: int):
294         #     self_size = self.size()
295         #     def backward(grad_output):
296         #         grad_self = AD_select_backward(grad_output, self_size, dim, index)
297         #         return grad_self, None, None
298 
299         #     return torch.select(self, dim, index), backward
300 
301         def AD_slice_backward(grad,
302                               input_sizes: List[int],
303                               dim: int,
304                               start: int,
305                               end: int,
306                               step: int):
307             # FIXME: torchscript: torch.zeros(sizes, grad.options())
308             grad_input = torch.zeros(input_sizes).to(grad)
309             grad_input.slice(dim, start, end, step).copy_(grad)
310             return grad_input
311 
312         # DON'T enable slice unless we can correctly handle view ops in graph executor.
313         # It triggers failure of TestJit.test_sample in test_distributions.py.
314         # def slice(self,
315         #           dim: int=0,
316         #           start: int=0,
317         #           end: int=9223372036854775807,
318         #           step: int=1):
319         #     def backward(grad_output):
320         #         grad_self = AD_slice_backward(grad_output, self.size(), dim, start, end, step)
321         #         return grad_self, None, None, None, None
322 
323         #     return torch.slice(self, dim, start, end, step), backward
324 
325         def AD_unsqueeze_to_0(self,
326                               sizes: List[int]):
327             ndims = len(sizes)
328             for i in range(ndims):
329                 if sizes[i] == 1:
330                     self = self.unsqueeze(i)
331 
332             return self
333 
334         def AD_unsqueeze_to_1(self,
335                               dim: int,
336                               sizes: List[int]):
337             if len(sizes) > 0 and sizes[dim] == 1:
338                 return self.unsqueeze(dim)
339             return self
340 
341         def squeeze_0(self):
342             self_size = self.size()
343             def backward(grad_output):
344                 grad_self = AD_unsqueeze_to_0(grad_output, self_size)
345                 return grad_self
346 
347             return torch.squeeze(self), backward
348 
349         def squeeze_1(self,
350                       dim: int):
351             self_size = self.size()
352             def backward(grad_output):
353                 grad_self = AD_unsqueeze_to_1(grad_output, dim, self_size)
354                 return grad_self, None
355 
356             return torch.squeeze(self, dim), backward
357 
358         def AD_infer_size(a: List[int],
359                           b: List[int]):
360             dimsA = len(a)
361             dimsB = len(b)
362 
363             ndim = dimsA if dimsA > dimsB else dimsB
364             expand_sizes = [0] * ndim
365 
366             for i in range(ndim):
367                 idx = - i + ndim - 1
368                 sizeA = a[i] if dimsA + i >= 0 else 1
369                 sizeB = b[i] if dimsB + i >= 0 else 1
370 
371                 # Assert sizeA == sizeB or sizeA == 1 or sizeB == 1
372                 expand_sizes[i] = sizeB if sizeA == 1 else sizeA
373 
374             return expand_sizes
375 
376         def AD_bmm_backward_self(grad, mat2):
377             return grad.bmm(mat2.transpose(1, 2))
378 
379         def AD_bmm_backward_mat2(grad, self):
380             return self.transpose(1, 2).bmm(grad)
381 
382         def bmm(self, mat2):
383             def backward(grad_output):
384                 grad_self = AD_bmm_backward_self(grad_output, mat2)
385                 grad_mat2 = AD_bmm_backward_mat2(grad_output, self)
386                 return grad_self, grad_mat2
387             return torch.bmm(self, mat2), backward
388     )",
389     R"(
390         def AD_mat_transpose(mat):
391             dim = mat.dim()
392             if dim == 1:
393                 out = mat
394             elif dim == 2:
395                 out = mat.t()
396             else:
397                 dims = rangelist(dim)
398                 dims[-1] = dim - 2
399                 dims[-2] = dim - 1
400                 out = mat.permute(dims)
401             return out
402 
403         # In matmul backward case of [b, m, n] * [b, n, p] => [m, p],
404         # instead of doing [b, m, p] and then reduce to [m, p]
405         # which potentially uses large intermediate of size b*m*p,
406         # we do [m, bn] * [bn, p] to avoid having the large
407         # intermediate, thus reduces max memory usage.
408         def AD_matmul_bw_special_fold(mat1, mat2):
409             mat1_transpose = AD_mat_transpose(mat1)
410             mat1_fold = mat1_transpose.reshape(-1, mat1_transpose.size()[-1])
411             mat2_fold = mat2.reshape(-1, mat2.size()[-1])
412             return mat1_fold.t().mm(mat2_fold)
413 
414         def AD_matmul_bw_size(mat1, mat2,
415                            out_size: List[int]):
416             dim1 = mat1.dim()
417             dim2 = mat2.dim()
418             dim_out = len(out_size)
419             if dim1 == 0 or dim2 == 0:
420                 out = mat1 * mat2
421             elif dim_out == 2 and dim1 == dim2 and dim1 >=3:
422                 out = AD_matmul_bw_special_fold(mat1, mat2)
423             elif dim_out == 1 and dim1 - dim2 == 1 and dim1 >= 3:
424                 mat2_unsqueeze = mat2.unsqueeze(-1)
425                 out = AD_matmul_bw_special_fold(mat1, mat2_unsqueeze)
426                 out = out.squeeze(-1)
427             elif dim1 + dim2 == dim_out:
428                 if dim2 == 1:
429                     target_dim2 = 0
430                 else:
431                     target_dim2 = -2
432                 out = torch.matmul(mat1.unsqueeze(dim1), mat2.unsqueeze(target_dim2))
433             elif dim_out == dim1 - dim2:
434                 out = torch.matmul(mat1, mat2.unsqueeze(dim2)).squeeze(-1)
435             elif dim_out == dim2 - dim1:
436                 out = torch.matmul(mat1.unsqueeze(-2), mat2).squeeze(-2)
437             else:
438                 out = torch.matmul(mat1, mat2)
439             return out
440 
441         def matmul(self, other):
442             def backward(grad_output):
443                 self_size = self.size()
444                 other_size = other.size()
445                 grad_self = AD_matmul_bw_size(grad_output, AD_mat_transpose(other), self_size)._grad_sum_to_size(self_size)
446                 grad_other = AD_matmul_bw_size(AD_mat_transpose(self), grad_output, other_size)._grad_sum_to_size(other_size)
447                 return grad_self, grad_other
448 
449             return torch.matmul(self, other), backward
450 
451         def linear(input : Tensor,
452                    weight : Tensor,
453                    bias : Optional[Tensor]):
454             result = torch.linear(input, weight, bias)
455 
456             def backward(grad_output):
457                 if bias is not None:
458                    grad_bias = grad_output._grad_sum_to_size(bias.size())
459                 else:
460                    grad_bias = None
461 
462                 weight_size = weight.size()
463                 grad_input = torch.matmul(grad_output, weight)
464                 grad_weight = torch.matmul(grad_output.reshape(-1, weight_size[0]).t(), input.reshape(-1, weight_size[1]))
465                 # Note: calling unchecked_unwrap_optional is only safe, when we
466                 #       directly return grad_bias directly back to bias.
467                 #       Because in the case where `bias is None`, unwrapped
468                 #       grad_bias would just be pruned away.
469                 return grad_input, grad_weight, grad_bias.unchecked_unwrap_optional
470             return result, backward
471     )",
472     R"(
473         def addcmul(self,
474                     tensor1,
475                     tensor2,
476                     *,
477                     value: number):
478             result = torch.addcmul(self, tensor1, tensor2, value=value)
479             self_size = torch._size_if_not_equal(self.size(), result.size())
480             tensor1_size = torch._size_if_not_equal(tensor1.size(), result.size())
481             tensor2_size = torch._size_if_not_equal(tensor2.size(), result.size())
482             def backward(grad_output):
483                 grad = grad_output * value
484                 grad_tensor1 = (grad * tensor2)._grad_sum_to_size(tensor1_size)
485                 grad_tensor2 = (grad * tensor1)._grad_sum_to_size(tensor2_size)
486                 return grad_output._grad_sum_to_size(self_size), grad_tensor1, grad_tensor2, None
487             return result, backward
488 
489         def _autocast_to_full_precision(self, cuda_enabled : bool, cpu_enabled : bool):
490             self_dtype = self.dtype
491             def backward(grad_output):
492                 return grad_output.to(self_dtype), None, None
493 
494             return torch._autocast_to_full_precision(self, cuda_enabled, cpu_enabled), backward
495 
496         def _autocast_to_reduced_precision(self,
497                                           cuda_enabled : bool,
498                                           cpu_enabled : bool,
499                                           cuda_dtype : int,
500                                           cpu_dtype : int):
501             self_dtype = self.dtype
502             def backward(grad_output):
503                 return grad_output.to(self_dtype), None, None, None, None
504 
505             return torch._autocast_to_reduced_precision(self, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype), backward
506 
507         def _dim_arange(like,
508                         dim: int):
509             def backward(grad_output):
510                 return None, None
511 
512             return torch._dim_arange(like, dim), backward
513 
514         def contiguous(self, *, memory_format: int=0):
515             def backward(grad_output):
516                 return grad_output, None
517 
518             return self.contiguous(memory_format=memory_format), backward
519 
520         def dot(self, tensor):
521             def backward(grad_output):
522                 return grad_output * tensor, grad_output * self
523 
524             return torch.dot(self, tensor), backward
525 
526         def erf(self):
527             def backward(grad_output):
528                 # Precomputed constant C = 2.0 / math.sqrt(math.pi)
529                 C = 1.1283791670955126
530                 return C * torch.exp(- self * self) * grad_output
531 
532             return torch.erf(self), backward
533 
534         def expand(self,
535                    size: List[int],
536                    *,
537                    implicit: bool=False):
538             result = torch.expand(self, size, implicit=implicit)
539             self_size = torch._size_if_not_equal(self.size(), result.size())
540 
541             def backward(grad_output):
542                 return grad_output._grad_sum_to_size(self_size), None, None
543 
544             return result, backward
545 
546         def expand_as(self, other):
547             result = torch.expand_as(self, other)
548             self_size = torch._size_if_not_equal(self.size(), result.size())
549 
550             def backward(grad_output):
551                 return grad_output._grad_sum_to_size(self_size), None
552 
553             return result, backward
554 
555         def full_like(self,
556                       fill_value: float):
557             def backward(grad_output):
558                 return None, None
559 
560             return torch.full_like(self, fill_value, memory_format=1), backward
561 
562         def lerp_0(self,
563                    end,
564                    weight: number):
565             result = torch.lerp(self, end, weight)
566             self_size = torch._size_if_not_equal(self.size(), result.size())
567             end_size = torch._size_if_not_equal(end.size(), result.size())
568 
569             def backward(grad_output):
570                 grad_self = (grad_output * (1 - float(weight)))._grad_sum_to_size(self_size)
571                 grad_end = (grad_output * float(weight))._grad_sum_to_size(end_size)
572                 return grad_self, grad_end, None
573             return result, backward
574 
575         def lerp_1(self,
576                    end,
577                    weight):
578             result = torch.lerp(self, end, weight)
579             self_size = torch._size_if_not_equal(self.size(), result.size())
580             end_size = torch._size_if_not_equal(end.size(), result.size())
581             weight_size = torch._size_if_not_equal(weight.size(), result.size())
582 
583             def backward(grad_output):
584                 grad_self = (grad_output * (1 - weight))._grad_sum_to_size(self_size)
585                 grad_end = (grad_output * weight)._grad_sum_to_size(end_size)
586                 grad_weight = (grad_output * (end - self))._grad_sum_to_size(weight_size)
587                 return grad_self, grad_end, grad_weight
588 
589             return result, backward
590 
591         def reshape(self,
592                     shape: List[int]):
593             self_size = self.size()
594 
595             def backward(grad_output):
596                 return grad_output.reshape(self_size), None
597 
598             return torch.reshape(self, shape), backward
599 
600         def split(self,
601                   split_size: int,
602                   dim: int):
603             def backward(grad_outputs: List[Tensor]):
604                 grad_self = torch.cat(grad_outputs, dim)
605                 return grad_self, None, None
606 
607             return torch.split(self, split_size, dim), backward
608 
609         def split_with_sizes(self,
610                              split_sizes: List[int],
611                              dim: int):
612             def backward(grad_outputs: List[Tensor]):
613                 size = len(grad_outputs)
614                 grad_self = torch.cat(grad_outputs, dim)
615                 return grad_self, None, None
616 
617             return torch.split_with_sizes(self, split_sizes, dim), backward
618 
619         def stack(tensors: List[Tensor],
620                   dim: int=0):
621             def backward(grad_output):
622                 grad_tensors = torch.unbind(grad_output, dim)
623                 return grad_tensors, None
624 
625             return torch.stack(tensors, dim), backward
626 
627         def unbind(self,
628                    dim: int):
629             def backward(grad_outputs: List[Tensor]):
630                 grad_self = torch.stack(grad_outputs, dim)
631                 return grad_self, None
632 
633             return torch.unbind(self, dim), backward
634 
635         def cat(tensors: List[Tensor],
636                 dim: int):
637             size = len(tensors)
638             split_sizes = [0] * size
639             for i in range(size):
640                 if tensors[i].size() != [0]:
641                     split_sizes[i] = tensors[i].size()[dim]
642 
643             def backward(grad_output):
644                 grad_tensors = torch.split_with_sizes(grad_output, split_sizes, dim)
645                 return grad_tensors, None
646 
647             return torch.cat(tensors, dim), backward
648 
649         def index(self,
650                   indices: List[Tensor]):
651             def backward(grad_output):
652                 grad_self = torch.zeros_like(self, memory_format=1).index_put_(indices, grad_output, True)
653                 return grad_self, None
654 
655             return torch.index(self, indices), backward
656 
657         def meshgrid(tensors: List[Tensor]):
658             size = len(tensors)
659             sizes = [0] * size
660             for i in range(size):
661                 if tensors[i].dim() != 0:
662                     sizes[i] = tensors[i].size()[0]
663             def backward(grad_outputs: List[Tensor]):
664                 grads_tensors = []
665                 for i in range(size):
666                     view_shape = [1] * size
667                     if sizes[i] == 0:
668                         view_shape[i] = 1
669                         grads_tensors.append((grad_outputs[i]._grad_sum_to_size(view_shape)).reshape(()))
670                     else:
671                         view_shape[i] = sizes[i]
672                         grads_tensors.append((grad_outputs[i]._grad_sum_to_size(view_shape)).reshape([sizes[i]]))
673                 return grads_tensors
674             return torch.meshgrid(tensors), backward
675 
676         def mv(self, vec):
677             def backward(grad_output):
678                 return grad_output.ger(vec), self.t().mv(grad_output)
679 
680             return torch.mv(self, vec), backward
681 
682         def nonzero(self):
683             def backward(grad_output):
684                 return None
685 
686             return torch.nonzero(self), backward
687 
688         def ones_like(self):
689             def backward(grad_output):
690                 return None
691 
692             return torch.ones_like(self, memory_format=1), backward
693 
694         def pow_0(self,
695                   exponent: number):
696             def backward(grad_output):
697                 if float(exponent) == 0.0:
698                     grad_self = torch.zeros_like(self, memory_format=1)
699                 else:
700                     grad_self = grad_output * exponent * torch.pow(self, float(exponent) - 1)
701                 return grad_self, None
702 
703             return torch.pow(self, exponent), backward
704 
705         def pow_1(self, exponent):
706             result = torch.pow(self, exponent)
707             self_size = torch._size_if_not_equal(self.size(), result.size())
708             exponent_size = torch._size_if_not_equal(exponent.size(), result.size())
709 
710             def backward(grad_output):
711                 grad_self = torch.where(exponent == 0.0, torch.zeros_like(self, memory_format=1), grad_output * exponent * torch.pow(self, exponent - 1))._grad_sum_to_size(self_size)
712                 grad_exponent = (grad_output * torch.pow(self, exponent) * torch.log(self))._grad_sum_to_size(exponent_size)
713                 return grad_self, grad_exponent
714 
715             return result, backward
716 
717         def pow_2(self: number,
718                   exponent):
719             def backward(grad_output):
720                 grad_exponent = grad_output * torch.pow(self, exponent) * torch.log(float(self))
721                 return None, grad_exponent
722 
723             return torch.pow(self, exponent), backward
724 
725         def rsub_0(self,
726                    other,
727                    alpha: number):
728             result = torch.rsub(self, other, alpha=alpha)
729             self_size = torch._size_if_not_equal(self.size(), result.size())
730             other_size = torch._size_if_not_equal(other.size(), result.size())
731             def backward(grad_output):
732                 grad_self = (- grad_output * alpha)._grad_sum_to_size(self_size)
733                 grad_other = (grad_output)._grad_sum_to_size(other_size)
734                 return grad_self, grad_other, None
735 
736             return result, backward
737 
738         def rsub_1(self,
739                    other: number,
740                    alpha: number):
741             def backward(grad_output):
742                 grad_self = (- grad_output * alpha)
743                 return grad_self, None, None
744 
745             return torch.rsub(self, other, alpha), backward
746 
747         def sqrt(self):
748             result = torch.sqrt(self)
749             def backward(grad_output):
750                 return grad_output / (2 * result)
751 
752             return result, backward
753 
754         def t(self):
755             def backward(grad_output):
756                 return torch.t(grad_output)
757 
758             return torch.t(self), backward
759 
760         def to_0(self,
761                  device: Optional[Device],
762                  dtype: Optional[int],
763                  non_blocking: bool,
764                  copy: bool):
765             self_device = self.device
766             self_dtype = self.dtype
767             if device is not None:
768                 result = self.to(device, dtype=dtype, non_blocking=non_blocking, copy=copy)
769             else:
770                 result = self.to(dtype, non_blocking=non_blocking, copy=copy)
771             def backward(grad_output):
772                 grad_self = grad_output.to(self_device, dtype=self_dtype, non_blocking=non_blocking, copy=copy)
773                 return grad_self, None, None, None, None
774 
775             return result, backward
776 
777 
778         def to_1(self,
779                  dtype: int,
780                  non_blocking: bool,
781                  copy: bool):
782             self_dtype = self.dtype
783             def backward(grad_output):
784                 grad_self = grad_output.to(self_dtype, non_blocking, copy)
785                 return grad_self, None, None, None
786 
787             return self.to(dtype=dtype, non_blocking=non_blocking, copy=copy), backward
788 
789         def to_2(self,
790                  other,
791                  non_blocking: bool,
792                  copy: bool):
793             def backward(grad_output):
794                 grad_self = grad_output.to(self, non_blocking, copy)
795                 return grad_self, None, None, None
796 
797             return self.to(other, non_blocking=non_blocking, copy=copy), backward
798 
799         def transpose(self,
800                       dim0: int,
801                       dim1: int):
802             def backward(grad_output):
803                 return torch.transpose(grad_output, dim0, dim1), None, None
804 
805             return torch.transpose(self, dim0, dim1), backward
806 
807         def view(self,
808                  size: List[int]):
809             self_size = self.size()
810             def backward(grad_output):
811                 return grad_output.reshape(self_size), None
812 
813             return torch.view(self, size), backward
814     )",
815     R"(
816         def AD_sizes_if_not_equal_multi_0(t1, t2, res):
817             return torch._size_if_not_equal(t1.size(), res.size()), torch._size_if_not_equal(t2.size(), res.size())
818 
819         def mul_0(self, other):
820             result = self * other
821             self_size, other_size = AD_sizes_if_not_equal_multi_0(self, other, result)
822 
823             def backward(grad_output):
824                 grad_self = (grad_output * other)._grad_sum_to_size(self_size)
825                 grad_other = (grad_output * self)._grad_sum_to_size(other_size)
826                 return grad_self, grad_other
827 
828             return result, backward
829 
830         def mul_1(self, other: number):
831             def backward(grad_output):
832                 return grad_output * other, None
833             return self * other, backward
834 
835         def div_0(self, other):
836             result = self / other
837             self_size, other_size = AD_sizes_if_not_equal_multi_0(self, other, result)
838 
839             def backward(grad_output):
840                 grad_self = (grad_output / other)._grad_sum_to_size(self_size)
841                 grad_other = (-grad_output * self / (other * other))._grad_sum_to_size(other_size)
842                 return grad_self, grad_other
843 
844             return result, backward
845 
846         def div_1(self, other: number):
847             def backward(grad_output):
848                 return grad_output / other, None
849             return self / other, backward
850 
851         def div_2(self, other, *, rounding_mode: Optional[str]):
852             result = torch.div(self, other, rounding_mode=rounding_mode)
853             self_size, other_size = AD_sizes_if_not_equal_multi_0(self, other, result)
854             def backward(grad_output):
855                 if rounding_mode is None:
856                     grad_self = (grad_output / other)._grad_sum_to_size(self_size)
857                     grad_other = (-grad_output * self / (other * other))._grad_sum_to_size(other_size)
858                 else:
859                     grad_self = torch.zeros_like(self)
860                     grad_other = torch.zeros_like(other)
861 
862                 return grad_self, grad_other, None
863 
864             return result, backward
865 
866         def div_3(self, other: number, *, rounding_mode: Optional[str]):
867             result = torch.div(self, other, rounding_mode=rounding_mode)
868             def backward(grad_output):
869                 if rounding_mode is None:
870                     grad_self = (grad_output / other)
871                 else:
872                     grad_self = torch.zeros_like(self, memory_format=1)
873                 return grad_self, None, None
874             return result, backward
875 
876         def max(self, other):
877             result = torch.max(self, other)
878             self_size, other_size = AD_sizes_if_not_equal_multi_0(self, other, result)
879 
880             def backward(grad_output):
881                 grad_self = (grad_output * (self > other).type_as(grad_output))._grad_sum_to_size(self_size)
882                 grad_other = (grad_output * (other > self).type_as(grad_output))._grad_sum_to_size(other_size)
883                 return grad_self, grad_other
884 
885             return result, backward
886 
887         def min(self, other):
888             def backward(grad_output):
889                 grad_self = (grad_output * (self < other).type_as(grad_output))._grad_sum_to_size(self.size())
890                 grad_other = (grad_output * (other < self).type_as(grad_output))._grad_sum_to_size(other.size())
891                 return grad_self, grad_other
892 
893             return torch.min(self, other), backward
894 
895         def sigmoid(self):
896             result = torch.sigmoid(self)
897             def backward(grad_output):
898                 return (1 - result) * result * grad_output
899 
900             return result, backward
901 
902         # Share backward with threshold
903         def relu(self):
904             result = torch.relu(self)
905             def backward(grad_output):
906                 return grad_output * (result > 0).type_as(result)
907 
908             return result, backward
909 
910         def relu6(self):
911             result = torch.relu6(self)
912             def backward(grad_output):
913                 return grad_output * ((result > 0) & (result < 6.0))
914 
915             return result, backward
916 
917         def leaky_relu(self, negative_slope: number):
918             result = torch.leaky_relu(self, negative_slope)
919             def backward(grad_output):
920                 return grad_output * torch.where(self > 0, 1.0, negative_slope).type_as(result), None
921             return result, backward
922 
923         def gelu(self : Tensor, *, approximate : str):
924             result = torch.gelu(self, approximate=approximate)
925             def backward(grad_output):
926                 return torch.gelu_backward(grad_output, self, approximate=approximate), None
927             return result, backward
928 
929         def silu(self):
930             result = torch.silu(self)
931             def backward(grad_output):
932                 input_sigmoid = torch.sigmoid(self)
933                 return grad_output * (input_sigmoid * (1 + self * (1 - input_sigmoid)))
934             return result, backward
935 
936         def hardswish(self):
937             result = torch.hardswish(self)
938             def backward(grad_output):
939                 m = (self > 3.).type_as(result)
940                 m = torch.where((self >= -3.) & (self <= 3.),  self / 3. + .5, m)
941                 return grad_output * m
942             return result, backward
943 
944         def hardsigmoid(self):
945             result = torch.hardsigmoid(self)
946             def backward(grad_output):
947                 m = (self > -3.) & (self < 3.)
948                 lhs = grad_output * (1.0 / 6.0)
949                 return torch.where(m, lhs, m.type_as(self))
950             return result, backward
951 
952         def erfc(self):
953             def backward(grad_output):
954                 # Precomputed constant C = -2.0 / math.sqrt(math.pi)
955                 C = -1.1283791670955126
956                 return C * torch.exp(-self * self) * grad_output
957 
958             return torch.erfc(self), backward
959 
960         def exp(self):
961             result = torch.exp(self)
962             def backward(grad_output):
963                 return grad_output * result
964 
965             return result, backward
966 
967         def neg(self):
968             def backward(grad_output):
969                 return grad_output.neg()
970 
971             return torch.neg(self), backward
972 
973         def where(condition, self, other):
974             result = torch.where(condition, self, other)
975             self_size, other_size = AD_sizes_if_not_equal_multi_0(self, other, result)
976             def backward(grad_output):
977                 grad_self = (grad_output * condition.type_as(grad_output))._grad_sum_to_size(self_size)
978                 grad_other = (grad_output * (condition.bitwise_not()).type_as(grad_output))._grad_sum_to_size(other_size)
979                 return None, grad_self, grad_other
980 
981             return result, backward
982 
983         def type_as(self, other):
984             def backward(grad_output):
985                 return grad_output.type_as(self), None
986 
987             return torch.type_as(self, other), backward
988 
989         def unsqueeze(self, dim: int):
990             def backward(grad_output):
991                 return grad_output.squeeze(dim), None
992 
993             return torch.unsqueeze(self, dim), backward
994 
995         def abs(self):
996             def backward(grad_output):
997                 return grad_output * self.sign()
998 
999             return torch.abs(self), backward
1000 
1001         def acos(self):
1002             def backward(grad_output):
1003                 return grad_output * -((-self * self + 1).rsqrt())
1004 
1005             return torch.acos(self), backward
1006 
1007         def asin(self):
1008             def backward(grad_output):
1009                 return grad_output * (-self * self + 1).rsqrt()
1010 
1011             return torch.asin(self), backward
1012 
1013         def atan(self):
1014             def backward(grad_output):
1015                 return grad_output / (self * self + 1)
1016 
1017             return torch.atan(self), backward
1018 
1019         def ceil(self):
1020             def backward(grad_output):
1021                 return torch.zeros_like(grad_output, memory_format=1)
1022 
1023             return torch.ceil(self), backward
1024 
1025         def cos(self):
1026             def backward(grad_output):
1027                 return grad_output * -self.sin()
1028 
1029             return torch.cos(self), backward
1030 
1031         def cosh(self):
1032             def backward(grad_output):
1033                 return grad_output * self.sinh()
1034 
1035             return torch.cosh(self), backward
1036 
1037         def expm1(self):
1038             result = torch.expm1(self)
1039             def backward(grad_output):
1040                 return grad_output * (result + 1)
1041 
1042             return result, backward
1043 
1044         def floor(self):
1045             def backward(grad_output):
1046                 return torch.zeros_like(grad_output, memory_format=1)
1047 
1048             return torch.floor(self), backward
1049 
1050         def frac(self):
1051             def backward(grad_output):
1052                 return grad_output
1053 
1054             return torch.frac(self), backward
1055 
1056         def log(self):
1057             def backward(grad_output):
1058                 return grad_output.div(self)
1059 
1060             return torch.log(self), backward
1061 
1062         def log10(self):
1063             def backward(grad_output):
1064                 return grad_output / (self * 2.3025850929940456)
1065 
1066             return torch.log10(self), backward
1067 
1068         def log1p(self):
1069             def backward(grad_output):
1070                 return grad_output / (self + 1)
1071 
1072             return torch.log1p(self), backward
1073 
1074         def log2(self):
1075             def backward(grad_output):
1076                 return grad_output / (self * 0.6931471805599453)
1077 
1078             return torch.log2(self), backward
1079 
1080         # TODO: Fix rand_like to match expected format
1081         # def rand_like(self, *, memory_format: Optional[int]):
1082         #    def backward(grad_output):
1083         #        return None
1084 
1085         #    return torch.rand_like(self, memory_format=memory_format), backward
1086 
1087         def reciprocal(self):
1088             result = torch.reciprocal(self)
1089             def backward(grad_output):
1090                 return -grad_output * result * result
1091 
1092             return result, backward
1093 
1094         def round(self):
1095             def backward(grad_output):
1096                 return torch.zeros_like(grad_output, memory_format=1)
1097 
1098             return torch.round(self), backward
1099 
1100         def rsqrt(self):
1101             result = torch.rsqrt(self)
1102             def backward(grad_output):
1103                 return -grad_output * result * result * result / 2
1104 
1105             return result, backward
1106 
1107         def sin(self):
1108             def backward(grad_output):
1109                 return grad_output * self.cos()
1110 
1111             return torch.sin(self), backward
1112 
1113         def sinh(self):
1114             def backward(grad_output):
1115                 return grad_output * self.cosh()
1116 
1117             return torch.sinh(self), backward
1118 
1119         def tan(self):
1120             result = torch.tan(self)
1121             def backward(grad_output):
1122                 return grad_output * (1. + result * result)
1123 
1124             return result, backward
1125 
1126         def trunc(self):
1127             def backward(grad_output):
1128                 return torch.zeros_like(grad_output, memory_format=1)
1129 
1130             return torch.trunc(self), backward
1131 
1132         def _grad_sum_to_size(self,
1133                               size: Optional[List[int]]):
1134             result = torch._grad_sum_to_size(self, size)
1135             self_size = torch._size_if_not_equal(self.size(), result.size())
1136 
1137             def backward(grad_output):
1138                 if self_size is None:
1139                     grad_input = grad_output
1140                 else:
1141                     grad_input = grad_output.expand(self_size)
1142                 return grad_input, None
1143 
1144             return result, backward
1145     )",
1146     R"(
1147         def batch_norm(input : Tensor,
1148                        weight : Optional[Tensor],
1149                        bias : Optional[Tensor],
1150                        running_mean : Optional[Tensor],
1151                        running_var : Optional[Tensor],
1152                        training : bool,
1153                        momentum : float,
1154                        eps : float,
1155                        cudnn_enabled : bool):
1156 
1157             output, save1, save2, reserve, impl_idx = torch._batch_norm_impl_index(
1158                 input, weight, bias, running_mean, running_var, training,
1159                 momentum, eps, cudnn_enabled)
1160             has_weight = weight is not None
1161             has_bias = bias is not None
1162 
1163             def backward(grad_output):
1164                 dinput, dweight, dbias = torch._batch_norm_impl_index_backward(
1165                     impl_idx, input, grad_output, weight, running_mean, running_var,
1166                     save1, save2, training, eps, [True, has_weight, has_bias], reserve)
1167                 return dinput, dweight, dbias, None, None, None, None, None, None
1168 
1169             return output, backward
1170 
1171         def layer_norm(input : Tensor,
1172                        normalized_shape : List[int],
1173                        weight : Optional[Tensor],
1174                        bias : Optional[Tensor],
1175                        eps : float,
1176                        cudnn_enable : bool):
1177 
1178             output, mean, rstd = torch.native_layer_norm(input, normalized_shape, weight, bias, eps)
1179 
1180             def backward(grad_output):
1181                 output_mask = [True, weight is not None, bias is not None]
1182                 grad_input, grad_weight, grad_bias = torch.native_layer_norm_backward(grad_output, input, normalized_shape, mean, rstd, weight, bias, output_mask)
1183                 return grad_input, None, grad_weight, grad_bias, None, None
1184             return output, backward
1185 
1186         def dropout(input,
1187                     p: float,
1188                     train: bool):
1189             # if `train == false` we need to set `p1m` to 0 so `scale == 1`
1190             p1m = (1. - p) * float(train)
1191             scale = 1. / (float(p1m == 0.) + p1m)
1192             res,mask = torch.native_dropout(input, p, train)
1193 
1194             def backward(grad_output):
1195                 grad_input = torch.native_dropout_backward(grad_output, mask, scale)
1196                 return grad_input, None, None
1197             return res, backward
1198 
1199         def embedding(weight,
1200                       indices,
1201                       padding_idx: int,
1202                       scale_grad_by_freq: bool,
1203                       sparse: bool):
1204             weight_size_0 = weight.size()[0]
1205             def backward(grad_output):
1206                 grad_weight = torch.embedding_backward(grad_output, indices, weight_size_0, padding_idx, scale_grad_by_freq, sparse)
1207                 return grad_weight, None, None, None, None
1208 
1209             return torch.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse), backward
1210 
1211         def log_softmax(self, dim: int, dtype: Optional[int]):
1212             result = torch.log_softmax(self, dim, dtype)
1213             def backward(grad_output):
1214                 grad_self = torch._log_softmax_backward_data(grad_output, result, dim, self.dtype)
1215                 return grad_self, None, None
1216 
1217             return result, backward
1218 
1219         def nll_loss(self, target, weight: Optional[Tensor], reduction: int, ignore_index: int):
1220             result, total_weight = torch.nll_loss_forward(self, target, weight, reduction, ignore_index)
1221             def backward(grad):
1222                 return torch.nll_loss_backward(grad, self, target, weight, reduction, ignore_index, total_weight), None, None, None, None
1223             return result, backward
1224 
1225         def softmax(self, dim: int, dtype: Optional[int]):
1226             result = torch.softmax(self, dim, dtype)
1227             def backward(grad_output):
1228                 grad_self = torch._softmax_backward_data(grad_output, result, dim, self.dtype)
1229                 return grad_self, None, None
1230 
1231             return result, backward
1232     )",
1233     R"(
1234         def AD_adaptive_avg_pool3d_backward(grad,
1235                                             self,
1236                                             output_size: List[int]):
1237             if output_size[0] == 1 and output_size[1] == 1 and output_size[2] == 1:
1238                 self_size = self.size()
1239                 grad_self = grad.expand(self.size()) / (self_size[-1] * self_size[-2] * self_size[-3])
1240             else:
1241                 grad_self = torch._adaptive_avg_pool3d_backward(grad, self)
1242 
1243             return grad_self
1244 
1245         def AD_adaptive_avg_pool2d_backward(grad,
1246                                             self,
1247                                             output_size: List[int]):
1248             if output_size[0] == 1 and output_size[1] == 1:
1249                 self_size = self.size()
1250                 grad_self = grad.expand(self.size()) / (self_size[-1] * self_size[-2])
1251             else:
1252                 grad_self = torch._adaptive_avg_pool2d_backward(grad, self)
1253 
1254             return grad_self
1255 
1256         def AD_adaptive_avg_pool1d_backward(grad,
1257                                             input,
1258                                             output_size: List[int]):
1259             output_size_2d = [1, output_size[0]]
1260             grad_input = AD_adaptive_avg_pool2d_backward(grad.unsqueeze(2), input.unsqueeze(2), output_size_2d).squeeze(2)
1261             return grad_input
1262 
1263         def adaptive_avg_pool1d(self,
1264                                 output_size: List[int]):
1265             def backward(grad_output):
1266                 grad_self = AD_adaptive_avg_pool1d_backward(grad_output, self, output_size)
1267                 return grad_self, None
1268 
1269             return torch.adaptive_avg_pool1d(self, output_size), backward
1270 
1271         def adaptive_avg_pool2d(self,
1272                                 output_size: List[int]):
1273             def backward(grad_output):
1274                 # self is used in backward, no need to pass in its size explicitly
1275                 grad_self = AD_adaptive_avg_pool2d_backward(grad_output, self, output_size)
1276                 return grad_self, None
1277             return torch.adaptive_avg_pool2d(self, output_size), backward
1278 
1279         def adaptive_avg_pool3d(self,
1280                                 output_size: List[int]):
1281             def backward(grad_output):
1282                 grad_self = AD_adaptive_avg_pool3d_backward(grad_output, self, output_size)
1283                 return grad_self, None
1284 
1285             return torch.adaptive_avg_pool3d(self, output_size), backward
1286 
1287         def avg_pool2d(self,
1288                        kernel_size: List[int],
1289                        stride: List[int],
1290                        padding: List[int],
1291                        ceil_mode: bool,
1292                        count_include_pad: bool,
1293                        divisor_override: Optional[int]):
1294             def backward(grad_output):
1295                 grad_self = torch.avg_pool2d_backward(grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
1296                 return grad_self, None, None, None, None, None, None
1297 
1298             return torch.avg_pool2d(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override), backward
1299 
1300         def max_pool2d(self,
1301                        kernel_size: List[int],
1302                        stride: List[int],
1303                        padding: List[int],
1304                        dilation: List[int],
1305                        ceil_mode: bool):
1306             output, indices = torch.max_pool2d_with_indices(self, kernel_size, stride, padding, dilation, ceil_mode)
1307             def backward(grad_output):
1308                 grad_self = torch.max_pool2d_with_indices_backward(grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices)
1309                 return grad_self, None, None, None, None, None
1310             return output, backward
1311 
1312         def max_pool2d_with_indices(self,
1313                                     kernel_size: List[int],
1314                                     stride: List[int],
1315                                     padding: List[int],
1316                                     dilation: List[int],
1317                                     ceil_mode: bool):
1318             output, indices = torch.max_pool2d_with_indices(self, kernel_size, stride, padding, dilation, ceil_mode)
1319             def backward(grad_output):
1320                 grad_self = torch.max_pool2d_with_indices_backward(grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices)
1321                 return grad_self, None, None, None, None, None
1322             return output, indices, backward
1323       )",
1324     R"(
1325         def AD_sizes_if_not_equal_multi_1(t1, t2, res):
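                 # For each operand, return its size only if broadcasting changed the
                 # result shape (None otherwise), so the backward can use
                 # _grad_sum_to_size to reduce gradients back to the operand shapes.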
1326             return torch._size_if_not_equal(t1.size(), res.size()), torch._size_if_not_equal(t2.size(), res.size())
1327 
1328         def add_0(self,
1329                   other,
1330                   *,
1331                   alpha: number):
1332             result = torch.add(self, other, alpha=alpha)
1333             self_size, other_size = AD_sizes_if_not_equal_multi_1(self, other, result)
1334             def backward(grad_output):
1335                 grad_other = (grad_output * alpha)._grad_sum_to_size(other_size)
1336                 grad_self = (grad_output)._grad_sum_to_size(self_size)
1337                 return grad_self, grad_other, None
1338             return result, backward
1339 
1340         def add_1(self,
1341                   other: number,
1342                   alpha: number):
1343             def backward(grad_output):
1344                 return grad_output, None, None
1345             return torch.add(self, other, alpha=alpha), backward
1346 
1347         def sub_0(self,
1348                   other,
1349                   *,
1350                   alpha: number):
1351             result = torch.sub(self, other, alpha=alpha)
1352             self_size, other_size = AD_sizes_if_not_equal_multi_1(self, other, result)
1353             def backward(grad_output):
1354                 grad_other = (-grad_output * alpha)._grad_sum_to_size(other_size)
1355                 grad_self = (grad_output)._grad_sum_to_size(self_size)
1356                 return grad_self, grad_other, None
1357             return result, backward
1358 
1359         def sub_1(self,
1360                   other: number,
1361                   alpha: number):
1362             def backward(grad_output):
1363                 return grad_output, None, None
1364             return torch.sub(self, other, alpha=alpha), backward
1365 
1366         def threshold(self,
1367                       threshold: number,
1368                       value: number):
1369             def backward(grad_output):
1370                 mask = (self >= threshold).type_as(self)
1371                 return grad_output * mask, None, None
1372             return torch.threshold(self, threshold, value), backward
1373 
1374         def softplus(self,
1375                       beta: number,
1376                       threshold: number):
1377             result = torch.softplus(self, beta, threshold)
1378             def backward(grad_output):
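                     # softplus(x) = log(1 + exp(beta * x)) / beta, so
                     # z = exp(beta * result) = 1 + exp(beta * self) and
                     # (z - 1) / z = sigmoid(beta * self), the derivative w.r.t. self.
                     # Past the threshold the forward is the identity, so the gradient
                     # passes through unchanged.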
1379                 z = torch.exp(result * beta)
1380                 return torch.where((result * beta) > threshold, grad_output, grad_output * (z - 1.) / z), None, None
1381             return result, backward
1382 
1383         def fmod(self,
1384                  other: number):
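                 # For a scalar divisor, d(fmod(self, other))/dself = 1 almost
                 # everywhere, so the gradient passes straight through
                 # (likewise for remainder below).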
1385             def backward(grad_output):
1386                 return grad_output, None
1387             return torch.fmod(self, other), backward
1388 
1389         def remainder(self,
1390                       other: number):
1391             def backward(grad_output):
1392                 return grad_output, None
1393             return torch.remainder(self, other), backward
1394 
1395         def addmm(self,
1396                   mat1,
1397                   mat2,
1398                   *,
1399                   beta: number,
1400                   alpha: number):
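                 # addmm(self, mat1, mat2) = beta * self + alpha * mat1 @ mat2, so
                 # d/dmat1 = alpha * grad @ mat2^T, d/dmat2 = alpha * mat1^T @ grad,
                 # and the self gradient is beta * grad summed back to self's
                 # (possibly broadcast) shape.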
1401             result = torch.addmm(self, mat1, mat2, beta=beta, alpha=alpha)
1402             self_size = torch._size_if_not_equal(self.size(), result.size())
1403             def backward(grad_output):
1404                 self_grad = (grad_output * beta)._grad_sum_to_size(self_size)
1405                 mat1_grad = grad_output.mm(mat2.t()) * alpha
1406                 mat2_grad = mat1.t().mm(grad_output) * alpha
1407                 return self_grad, mat1_grad, mat2_grad, None, None
1408             return result, backward
1409 
1410         # Comparison operators
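         # These produce boolean outputs and are not differentiable, so every
         # backward below returns None for all inputs.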
1411         def lt(self, other: number):
1412             def backward(grad_output):
1413                 return None, None
1414             return torch.lt(self, other), backward
1415 
1416         def le(self, other: number):
1417             def backward(grad_output):
1418                 return None, None
1419             return torch.le(self, other), backward
1420 
1421         def gt(self, other: number):
1422             def backward(grad_output):
1423                 return None, None
1424             return torch.gt(self, other), backward
1425 
1426         def ge(self, other: number):
1427             def backward(grad_output):
1428                 return None, None
1429             return torch.ge(self, other), backward
1430 
1431         def eq(self, other: number):
1432             def backward(grad_output):
1433                 return None, None
1434             return torch.eq(self, other), backward
1435 
1436         def ne(self, other: number):
1437             def backward(grad_output):
1438                 return None, None
1439             return torch.ne(self, other), backward
1440 
1441         def hardshrink(self, lambd: number):
1442             def backward(grad_output):
1443                 mask = ((self > lambd) | (self < -lambd))
1444                 return grad_output * mask, None
1445             return torch.hardshrink(self, lambd=lambd), backward
1446 
1447         def hardtanh(self, min_val: number, max_val: number):
1448             def backward(grad_output):
1449                 mask = ((self >= min_val) * (self <= max_val))
1450                 return grad_output * mask, None, None
1451             return torch.hardtanh(self, min_val=min_val, max_val=max_val), backward
1452 
1453         def clamp_1(self,
1454                     min: Optional[number],
1455                     max: Optional[number]):
1456             def backward(grad_output):
1457                 if min is not None and max is not None:
1458                     mask = ((self >= float(min)) * (self <= float(max))).type_as(self)
1459                     return grad_output * mask, None, None
1460                 elif min is not None:
1461                     mask = (self >= float(min)).type_as(self)
1462                     return grad_output * mask, None, None
1463                 elif max is not None:
1464                     mask = (self <= float(max)).type_as(self)
1465                     return grad_output * mask, None, None
1466                 else:  # min is None and max is None
1467                     return grad_output, None, None
1468             return torch.clamp(self, min=min, max=max), backward
1469 
1470         def clamp_2(self,
1471                     min: Optional[Tensor],
1472                     max: Optional[Tensor]):
1473             def backward(grad_output):
1474                 if min is not None and max is not None:
1475                     mask = ((self >= min) * (self <= max)).type_as(self)
1476                     return grad_output * mask, None, None
1477                 elif min is not None:
1478                     mask = (self >= min).type_as(self)
1479                     return grad_output * mask, None, None
1480                 elif max is not None:
1481                     mask = (self <= max).type_as(self)
1482                     return grad_output * mask, None, None
1483                 else:  # min is None and max is None
1484                     return grad_output, None, None
1485             return torch.clamp(self, min=min, max=max), backward
1486     )"};
1487 
1488 std::unordered_map<std::string, GradientPair> schema_to_graphs;
1489 
1490 // This map is a workaround to cache compiled gradient pairs. Ideally each graph
1491 // should be compiled only once and saved in the Operator structure.
1492 // That should be done when these formulas are merged into native_functions.yaml.
1493 std::unordered_map<const FunctionSchema*, GradientPair> cached_gradient_pairs;
1494 
1495 // CompilationUnit that holds all these Functions and keeps them alive.
1496 CompilationUnit compilation_unit;
1497 } // anonymous namespace
1498 
1499 static std::pair<std::shared_ptr<Graph>, Value*> extractClosure(
1500     Value* closure) {
1501   TORCH_CHECK(
1502       closure->node()->kind() == prim::TupleConstruct,
1503       "closure must be a literal tuple construct");
1504   Value* fn = closure->node()->inputs().at(0);
1505   Value* context = closure->node()->inputs().at(1);
1506 
1507   TORCH_CHECK(
1508       fn->node()->kind() == prim::Closure,
1509       "closure tuple must contain a prim::Closure");
1510   return std::make_pair(fn->node()->g(attr::Subgraph), context);
1511 }
1512 
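     // The forward graph of an AD formula returns (original results..., backward
     // closure). Recover the original schema's return type by dropping the
     // trailing closure element: a single value if one result remains, otherwise
     // a tuple of the remaining types.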
1513 static Argument originalReturnType(const TupleTypePtr& tup) {
1514   TORCH_CHECK(tup->elements().size() > 1);
1515   if (tup->elements().size() == 2)
1516     return Argument("", tup->elements().at(0));
1517   std::vector<TypePtr> types = tup->elements().vec();
1518   types.pop_back();
1519   return Argument("", TupleType::create(std::move(types)));
1520 }
1521 
1522 // In TorchScript AD formulas, we define {func_0, func_1, ...} as
1523 // overloads of `func`.
1524 // Remove the numeric suffix before adding the schema string to the
1525 // schema_to_graphs map.
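     // For example, `add_0` and `add_1` above both end up keyed under the name
     // `add`, so their schema strings match the corresponding aten::add overloads.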
1526 static std::string overloadedSchemaString(const FunctionSchema& schema) {
1527   const auto& schema_name = schema.name();
1528   auto pos = schema_name.find_last_of('_');
1529   auto schema_name_suffix = schema_name.substr(pos + 1);
1530   std::string schema_string = canonicalSchemaString(schema);
1531   if (!schema_name_suffix.empty() &&
1532       schema_name_suffix.find_first_not_of("0123456789") == std::string::npos) {
1533     schema_string.replace(
1534         schema_string.find(schema_name),
1535         schema_name.length(),
1536         schema_name.substr(0, pos));
1537   }
1538 
1539   return schema_string;
1540 }
1541 
1542 static bool isHelperFunction(const std::string& method_name) {
1543   std::string helper_prefix = "AD_";
1544   return method_name.compare(0, helper_prefix.length(), helper_prefix) == 0;
1545 }
1546 
1547 static void loadModule(const CompilationUnit& module) {
1548   for (const auto& method : module.get_functions()) {
1549     if (isHelperFunction(method->name()))
1550       continue;
1551 
1552     GradientPair pair;
1553     pair.forward = toGraphFunction(*method).graph();
1554 
1555     // look up the backward function
1556     Node* forward_tuple = pair.forward->outputs().at(0)->node();
1557 
1558     if (forward_tuple->kind() != prim::TupleConstruct) {
1559       throw(
1560           ErrorReport(forward_tuple->sourceRange())
1561           << "gradient must return a literal tuple");
1562     }
1563 
1564     Value* context = nullptr;
1565     std::tie(pair.backward, context) =
1566         extractClosure(forward_tuple->inputs().back());
1567 
1568     // checks that num forward graph inputs equals num backward graph outputs
1569     TORCH_CHECK(
1570         pair.forward->inputs().size() ==
1571             unpackOutputs(pair.backward->outputs().vec()).size(),
1572         "The autodiff implementation of ",
1573         method->name(),
1574         " backward() returns an incorrect number of values: ",
1575         unpackOutputs(pair.backward->outputs().vec()).size(),
1576         " instead of ",
1577         pair.forward->inputs().size());
1578 
1579     // do surgery on the forward function to remove the closure tuple and
1580     // replace it with the context variable:
1581     //  backward = (<lambda>, context_tuple)
1582     //  return original, backward
1583     //  -----
1584     //  return original, context_tuple
1585     std::vector<Value*> new_inputs = forward_tuple->inputs().vec();
1586     new_inputs.back() = context;
1587     Value* new_tuple =
1588         pair.forward->appendNode(pair.forward->createTuple(new_inputs))
1589             ->output();
1590     pair.forward->eraseOutput(0);
1591     pair.forward->registerOutput(new_tuple);
1592     forward_tuple->destroy();
1593 
1594     // derive schema from original function's schema:
1595     const FunctionSchema& loaded_schema = method->getSchema();
1596     FunctionSchema actual_schema(
1597         Symbol::aten(loaded_schema.name()),
1598         loaded_schema.overload_name(),
1599         loaded_schema.arguments(),
1600         {originalReturnType(new_tuple->type()->expect<TupleType>())});
1601 
1602     // modify the canonical string for function overloading;
1603     // we prefer not to modify the schema name itself
1604     auto schema_string = overloadedSchemaString(actual_schema);
1605 
1606     schema_to_graphs[schema_string] = std::move(pair);
1607   }
1608 }
1609 
1610 static void loadFunctions() {
1611   for (const std::string& str : functions) {
1612     compilation_unit.define(std::nullopt, str, nativeResolver(), nullptr);
1613   }
1614   loadModule(compilation_unit);
1615 }
1616 
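     // Looks up the symbolic gradient for a schema, lazily compiling the
     // TorchScript formulas on first use (guarded by `lock`). Hits are cached
     // per FunctionSchema pointer to avoid repeated canonical-string lookups.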
1617 std::optional<GradientPair> gradientInfoForSchema(
1618     const FunctionSchema& schema) {
1619   std::lock_guard<std::mutex> guard(lock);
1620   if (schema_to_graphs.empty()) {
1621     loadFunctions();
1622   }
1623   auto cache_it = cached_gradient_pairs.find(&schema);
1624   if (cache_it != cached_gradient_pairs.end()) {
1625     return cache_it->second;
1626   } else {
1627     auto schema_str = canonicalSchemaString(schema);
1628     // For debugging AD changes:
1629     // std::cout << "Looking for " << schema_str << std::endl;
1630     auto sym_script_it = schema_to_graphs.find(schema_str);
1631 
1632     if (sym_script_it != schema_to_graphs.end()) {
1633       cached_gradient_pairs.emplace_hint(
1634           cache_it, &schema, sym_script_it->second);
1635       return sym_script_it->second;
1636     }
1637   }
1638   return std::nullopt;
1639 }
1640 
1641 bool hasGradientInfoForSchema(const FunctionSchema& schema) {
1642   return gradientInfoForSchema(schema).has_value();
1643 }
1644 
1645 } // namespace torch::jit
1646