# mypy: ignore-errors

# Torch
from torch.jit.annotations import BroadcastingList2, BroadcastingList3  # noqa: F401
import torch.nn.functional as F
import torch
import torch.cuda
import torch.jit
import torch.jit._logging
import torch.jit.frontend
from torch.testing._internal.common_nn import module_tests, new_module_tests
from torch.testing._internal.common_utils import is_iterable_of_tensors, noncontiguous_like

import collections
from copy import deepcopy
from typing import Any, Dict, List, Union
import math  # noqa: F401

# Testing utils
from torch import inf

assert torch.get_default_dtype() == torch.float32

L = 20
M = 10
S = 5


def unpack_variables(args):
    if isinstance(args, tuple):
        return tuple(unpack_variables(elem) for elem in args)
    else:
        return args

class dont_convert(tuple):
    pass

non_differentiable = collections.namedtuple('non_differentiable', ['tensor'])

def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwargs=None, dtype=torch.float, device=None):
    if not isinstance(call_args, tuple):
        call_args = (call_args,)

    def map_arg(arg):
        def maybe_non_contig(tensor):
            if not non_contiguous or tensor.numel() < 2:
                return tensor.clone()

            return noncontiguous_like(tensor)

        def conjugate(tensor):
            return tensor.conj()

        if isinstance(arg, (torch.Size, dont_convert)):
            return arg
        elif isinstance(arg, tuple) and len(arg) == 0:
            var = conjugate(torch.randn((), dtype=dtype, device=device))
            var.requires_grad = requires_grad
            return var
        elif isinstance(arg, tuple) and not isinstance(arg[0], torch.Tensor):
            return conjugate(maybe_non_contig(torch.randn(*arg, dtype=dtype, device=device))).requires_grad_(requires_grad)
        # double check casting
        elif isinstance(arg, non_differentiable):
            return conjugate(maybe_non_contig(arg.tensor.to(device=device)))
        elif isinstance(arg, torch.Tensor):
            if arg.is_complex() != dtype.is_complex:
                raise RuntimeError("User provided tensor is real for a test that runs with complex dtype, "
                                   "which is not supported for now")
            # NOTE: We do clone() after detach() here because we need to be able to change size/storage of v afterwards
            v = conjugate(maybe_non_contig(arg)).detach().to(device=device).clone()
            v.requires_grad = requires_grad and (v.is_floating_point() or v.is_complex())
            return v
        elif callable(arg):
            return map_arg(arg(dtype=dtype, device=device))
        else:
            return arg
    args_out = tuple(map_arg(arg) for arg in call_args)
    kwargs_out = {k: map_arg(v) for k, v in call_kwargs.items()} if call_kwargs else {}
    return args_out, kwargs_out
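
# Illustrative sketch (added for clarity, not used by the test suite; the helper
# name _example_create_input_usage is hypothetical): how create_input expands a
# spec of shapes, wrapped tensors, and scalars into concrete call arguments.
def _example_create_input_usage():
    args, kwargs = create_input(((S, S), non_differentiable(torch.ones(S)), 2.0))
    # args[0]: a 5x5 random tensor with requires_grad=True (built from the shape tuple)
    # args[1]: a clone of torch.ones(5) with requires_grad left False (non_differentiable)
    # args[2]: the Python float 2.0, passed through unchanged
    # kwargs:  {} because no call_kwargs were given
    return args, kwargs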

# NB: JIT script tests cover all nn functional interfaces. Script mode does
# not support in-place operations yet, so no in-place operation tests are added.
# All deprecated functions have been removed.
#
# Each entry has the form:
# (
#   method name,
#   input size/constructing fn,
#   args (a tuple represents the shape of a tensor arg),
#   test variant name (used as the test name suffix;
#       'inplace' skips grad tests),                         // optional
#   (True, nonfusible_nodes, fusible_nodes) for autodiff     // optional
#   fn to determine if test should be skipped,               // optional
#   fn mapping output to part that should be gradcheck'ed,   // optional
#   kwargs for function,                                     // optional
# )
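#
# Worked example (annotation only, taken from the list below): the entry
#   ('max_pool2d', (S, S, S, S), (2, 1), '', (True, 'aten::max_pool2d_with_indices'))
# means: call F.max_pool2d on a freshly generated (S, S, S, S) input with
# positional args (2, 1), use no variant suffix in the test name, and check
# autodiff with 'aten::max_pool2d_with_indices' expected as a non-fusible node.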
nn_functional_tests = [
    ('conv1d', (S, S, S), ((S, S, S),)),
    ('conv2d', (S, S, S, S), ((S, S, S, S),)),
    ('conv3d', (S, S, S, S, S), ((S, S, S, S, S),)),
    ('conv_transpose1d', (S, S, S), ((S, S, S),)),
    ('conv_transpose2d', (S, S, S, S), ((S, S, S, S),)),
    ('conv_transpose3d', (S, S, S, S, S), ((S, S, S, S, S),)),
    ('conv_tbc', (S, S, S), ((S, S, S), (S,), 2)),
    ('avg_pool1d', (S, S, S), (3,)),
    ('avg_pool2d', (S, S, S, S), (3,), '', (True,)),
    ('avg_pool3d', (S, S, S, S, S), (3,)),
    ('fractional_max_pool2d', (S, S, S, S), (3, [2, 3],)),
    ('max_pool1d', (S, S, S), (2, 1)),
    ('max_pool1d', (S, S, S), (2, 1, 1, 1, False, True), 'with_indices'),
    ('max_pool2d', (S, S, S, S), (2, 1), '', (True, 'aten::max_pool2d_with_indices')),
    ('max_pool2d', (S, S, S, S), (2, 1, 1, 1, False, True), 'with_indices', (True, 'aten::max_pool2d_with_indices')),
    ('max_pool3d', (S, S, S, S, S), (2, 1)),
    ('max_unpool1d', torch.tensor([[[2., 4]]]), (torch.tensor([[[1, 3]]]), 2, 2, 0)),
    ('max_unpool2d', torch.tensor([[[[2., 4]]]]), (torch.tensor([[[[1, 3]]]]), 2, 2, 0)),
    ('max_unpool3d', torch.tensor([[[[[2., 4]]]]]), (torch.tensor([[[[[1, 3]]]]]), 2, 2, 0)),
    ('lp_pool1d', (S, S, S), (2., 3, 2,)),
    ('lp_pool2d', (S, S, S, S), (2., 3, 2,)),
    ('lp_pool3d', (S, S, S, S, S), (2., 3, 2,)),
    ('adaptive_max_pool1d', (S, S, S), (5,)),
    ('adaptive_max_pool2d', (S, S, S, S), ([5, 7],)),
    ('adaptive_max_pool3d', (S, S, S, S, S), ([3, 2, 2],)),
    ('adaptive_avg_pool1d', (S, S, S), (5,), '', (True,)),
    ('adaptive_avg_pool2d', (S, S, S, S), ([5, 7],), '', (True,)),
    ('adaptive_avg_pool3d', (S, S, S, S, S), ([3, 2, 2],), '', (True,)),
    ('dropout', (S, S, S), (0.5,), '', (True, 'aten::native_dropout')),
    ('alpha_dropout', (S, S, S), (0.5,)),
    ('dropout2d', (S, S, S), (0.5,)),
    ('dropout2d', (S, S, S, S), (0.5,), 'batched'),
    ('dropout3d', (S, S, S, S), (0.5,)),
    ('dropout3d', (S, S, S, S, S), (0.5,), 'batched'),
    ('feature_alpha_dropout', (S, S, S), (0.5,)),
    ('threshold', (S, S, S), (0.1, 2.), '', (True,)),
    ('threshold', (S, S, S), (0.1, 2., True), 'inplace'),
    ('relu', (S, S, S), (), '', (True,)),
    ('relu', (S, S, S), (), 'inplace'),
    ('glu', (S - 1, S - 1, S - 1), (),),
    ('hardtanh', (S, S, S), (-0.5, 0.5), '', (True,)),
    ('hardtanh', (S, S, S), (-0.5, 0.5, True), 'inplace'),
    ('relu6', (S, S, S), (), '', (True,)),
    ('relu6', (S, S, S), (True,), 'inplace'),
    ('elu', (S, S, S), (0.9,),),
    ('elu', (S, S, S), (0.9, True), 'inplace'),
    ('selu', (S, S, S), (),),
    ('selu', (S, S, S), (True,), 'inplace'),
    ('celu', (S, S, S), (0.9,),),
    ('celu', (S, S, S), (0.9, True), 'inplace'),
    ('leaky_relu', (S, S, S), (0.02,), '', (True,)),
    ('leaky_relu', (S, S, S), (0.02,), 'inplace'),
    ('rrelu', (S, S), (0.1, 0.3, False),),
    ('rrelu', (S, S), (0.1, 0.3, False, True), 'inplace'),
    ('hardshrink', (S, S, S), (0.4,), '', (True,)),
    ('tanhshrink', (S, S, S), (),),
    ('softsign', (S, S, S), (),),
    ('softplus', (S, S, S), (), '', (True,)),
    ('softmin', (S, S, S), (0,),),
    ('softmax', (S, S, S), (0,), '', (True,)),
    ('softmax', (S, S, S), (0, 3, torch.double), 'with_all_args', (True,)),
    ('tanh', (S, S, S), (), '', (True,)),
    ('sigmoid', (S, S, S), (), '', (True,)),
    ('silu', (S, S, S), (), '', (True,)),
    ('log_softmax', (S, S, S), (0,), '', (True,)),
    ('linear', (S, S), ((M, S),), '', (True, ['aten::linear'])),
    ('linear', (S, S), ((M, S), (M,)), 'addmm', (True, ['aten::linear'])),
    ('bilinear', (S, S, S), ((S, S, M), torch.zeros(M, S, M),),),
    ('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ), '', (True,)),
    ('embedding_bag', torch.tensor([1, 2, 4, 2]), (torch.rand(5, 3), torch.tensor([0, 4]),),),
    ('batch_norm', (S, S),
        (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), None, None, True, ),
        'training', (True, 'aten::_batch_norm_impl_index')),
    ('batch_norm', (0, S, S, S),
        (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
         non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
        'size_zero', (True, 'aten::_batch_norm_impl_index')),
    ('batch_norm', (0, S, S, S),
        (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
         non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
        'size_zero_inference', (True, 'aten::_batch_norm_impl_index')),
    ('batch_norm', (S, S),
        (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
         non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
        'with_weight_and_bias_training', (True, 'aten::_batch_norm_impl_index')),
    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
                            None, non_differentiable(torch.ones(S)), True, ),
        'with_only_bias_training', (True, 'aten::_batch_norm_impl_index')),
    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
                            non_differentiable(torch.randn(S)), None, True, ),
        'with_only_weight_training', (True, 'aten::_batch_norm_impl_index')),
    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
                            None, None, False, ),
        'inference', (True, 'aten::_batch_norm_impl_index')),
    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
                            non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), False, ),
        'with_weight_and_bias_inference', (True, 'aten::_batch_norm_impl_index')),
    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
                            None, non_differentiable(torch.ones(S)), False, ),
        'with_only_bias_inference', (True, 'aten::_batch_norm_impl_index')),
    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
                            non_differentiable(torch.randn(S)), None, False, ),
        'with_only_weight_inference', (True, 'aten::_batch_norm_impl_index')),
    ('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
    ('layer_norm', (S, S, S, S), ([5],), '',
     (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
    ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),), 'with_only_weight',
     (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
    ('layer_norm', (S, S, S, S), ([5], None, non_differentiable(torch.rand(S)),), 'with_only_bias',
     (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
    ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),
                                  non_differentiable(torch.rand(S))), 'with_weight_and_bias',
     (False, ['aten::contiguous', 'aten::_batch_norm_impl_index', 'aten::addcmul'])),
    ('group_norm', (S, S, S), (1, torch.rand(5),),),
    ('local_response_norm', (S, S, S), (2, ),),
    ('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),), '',),
    ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2),),),
    ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2), True, True), 'full'),
    ('kl_div', F.log_softmax(torch.randn(S, 10), 1), (F.softmax(torch.randn(S, 10), 1),),),
    ('cross_entropy', (3, S), (torch.randint(S, (3,), dtype=torch.int64),),),
    ('binary_cross_entropy_with_logits', (3,), (torch.empty(3).random_(2), ),),
    ('smooth_l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
    ('huber_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
    ('l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
    ('mse_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
    ('smooth_l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
    ('huber_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
    ('l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
    ('mse_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
    ('margin_ranking_loss', (S,), ((S,), (S,)),),
    ('hinge_embedding_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
    ('soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
    ('multilabel_soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
    ('cosine_embedding_loss', (S, S), ((S, S), non_differentiable(torch.rand(S,))),),
    ('pixel_shuffle', (1, 9, 4, 4), (3,),),
    ('pixel_unshuffle', (1, 1, 12, 12), (3,),),
    ('affine_grid', (S, 2, 3), (torch.Size([S, 1, 7, 7]),),),
    ('pad', (3, 3, 4, 2), ([1, 1],),),
    ('pairwise_distance', (S, S), ((S, S),),),
    ('pdist', (S, S), (),),
    ('cosine_similarity', (S, S), ((S, S),),),
    ('triplet_margin_loss', (S, S), ((S, S), (S, S)),),
    ('normalize', (S, S, S), (),),
    ('unfold', (S, S, S, S), ([2, 3],),),
    ('fold', (1, 3 * 2 * 2, 12), ([4, 5], [2, 2]),),
    ('grid_sample', (S, S, S, S), (non_differentiable(torch.rand(S, S, S, 2)),),),
    ('gumbel_softmax', (S, S), (2.,), '', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])),
    ('gumbel_softmax', (S, S), (2., True,), 'hard', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])),
    ('multilabel_margin_loss', torch.tensor([[0.2, -0.2, 0.07]]), (torch.tensor([[0, 0, 1]]),),),
    ('multi_margin_loss', (S, S), (non_differentiable(torch.randint(S, (S, ), dtype=torch.int64)),
                                   1, 1., non_differentiable(torch.randn(S))),),
    ('binary_cross_entropy', torch.randn(3, 2).sigmoid(), (non_differentiable(torch.rand(3, 2)),
                                                           non_differentiable(torch.randn(3, 2))),),
    ('binary_cross_entropy', torch.randn(3, 2).sigmoid(),
        (non_differentiable(torch.rand(3, 2)),
         non_differentiable(torch.randn(3, 2)), None, None, 'mean'), 'size_average'),
    ('ctc_loss', torch.rand(S, S, S).log_softmax(2).detach().requires_grad_(),
     (torch.randint(1, S, (S, S), dtype=torch.long), torch.full((S,), S, dtype=torch.long),
      torch.randint(1, S, (S,), dtype=torch.long))),
    ('upsample', torch.randn(S, S, M, M), (None, 2.), 'with_scale'),
    ('upsample', torch.randn(S, S, M, M), (4,), 'with_size'),
    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'nearest_4d'),
    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'nearest_4d_with_scale'),
    ('interpolate', torch.randn(S, S, M, M), (4,), 'nearest_4d_with_size'),
    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'area_4d'),
    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'area_4d_with_scale'),
    ('interpolate', torch.randn(S, S, M, M), (4,), 'area_4d_with_size'),
    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bilinear_4d'),
    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bilinear_4d_with_scale'),
    ('interpolate', torch.randn(S, S, M, M), (4,), 'bilinear_4d_with_size'),
    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bicubic_4d'),
    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bicubic_4d_with_scale'),
    ('interpolate', torch.randn(S, S, M, M), (4,), 'bicubic_4d_with_size'),
    ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'nearest_3d'),
    ('interpolate', torch.randn(S, M, M), (None, 2.), 'nearest_3d_with_scale'),
    ('interpolate', torch.randn(S, M, M), (4,), 'nearest_3d_with_size'),
    ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'area_3d'),
    ('interpolate', torch.randn(S, M, M), (None, 2.), 'area_3d_with_scale'),
    ('interpolate', torch.randn(S, M, M), (4,), 'area_3d_with_size'),
    ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'linear_3d'),
    ('interpolate', torch.randn(S, M, M), (None, 2.), 'linear_3d_with_scale'),
    ('interpolate', torch.randn(S, M, M), (4,), 'linear_3d_with_size'),
    ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'nearest_5d_with_scale'),
    ('interpolate', torch.randn(S, M, M, M, M), (4,), 'nearest_5d_with_size'),
    ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'area_5d'),
    ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'area_5d_with_scale'),
    ('interpolate', torch.randn(S, M, M, M, M), (4,), 'area_5d_with_size'),
    ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'trilinear_5d'),
    ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'trilinear_5d_with_scale'),
    ('interpolate', torch.randn(S, M, M, M, M), (4,), 'trilinear_5d_with_size'),
    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2, None, 'nearest', None, False),
     'nearest_4d_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, S, M, M), (4, None, 'nearest', None, False),
     'nearest_4d_with_size_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, S, M, M), (None, 2., 'bilinear', None, False),
     'bilinear_4d_with_scale_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, S, M, M), (4, None, 'bilinear', None, False),
     'bilinear_4d_with_size_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, S, M, M), (None, 2., 'bicubic', None, False),
     'bicubic_4d_with_scale_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, S, M, M), (4, None, 'bicubic', None, False),
     'bicubic_4d_with_size_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, M, M), (None, 2., 'nearest', None, False),
     'nearest_3d_with_scale_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, M, M), (4, None, 'nearest', None, False),
     'nearest_3d_with_size_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, M, M), (None, 2., 'linear', None, False),
     'linear_3d_with_scale_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, M, M), (4, None, 'linear', None, False),
     'linear_3d_with_size_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'nearest', None, False),
     'nearest_5d_with_scale_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, M, M, M, M), (4, None, 'nearest', None, False),
     'nearest_5d_with_size_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'trilinear', None, False),
     'trilinear_5d_with_scale_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, M, M, M, M), (4, None, 'trilinear', None, False),
     'trilinear_5d_with_size_not_recompute_scale_factor'),
]

script_template = '''
def the_method({}):
    return {}
'''

def value_to_literal(value):
    if isinstance(value, str):
        # Quotes string and escapes special characters
        return ascii(value)
    if isinstance(value, torch.Tensor):
        return 'torch.' + str(value)
    else:
        return str(value)
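
# Illustrative behavior (hypothetical inputs): value_to_literal('reflect') returns
# "'reflect'", value_to_literal(1e-05) returns '1e-05', and a tensor t becomes
# 'torch.' + str(t) (e.g. "torch.tensor([1., 2.])"), so the result can be
# re-evaluated inside the generated script source.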

def get_call(method_name, func_type, args, kwargs):
    kwargs_str = ', '.join([k + '=' + value_to_literal(v) for k, v in kwargs.items()])
    self_arg = args[0]
    if func_type == 'method':
        args = args[1:]

    argument_str = ', '.join(args)
    argument_str += ', ' if len(args) and len(kwargs) else ''
    argument_str += kwargs_str

    if func_type == 'functional' or func_type == 'function':
        call = f'torch.{method_name}({argument_str})'
    elif func_type == 'method':
        call = f'{self_arg}.{method_name}({argument_str})'
    elif func_type == 'nn_functional':
        call = f'torch.nn.functional.{method_name}({argument_str})'
    else:
        raise TypeError('Unsupported function type')

    return call
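
# Illustrative sketch (hypothetical arguments): get_call assembles the call
# expression that gets spliced into script_template, e.g.
#   get_call('relu', 'nn_functional', ['i0'], {'inplace': False})
# returns "torch.nn.functional.relu(i0, inplace=False)", while func_type='method'
# uses the first actual as the receiver, e.g. get_call('relu', 'method', ['i0'], {})
# returns "i0.relu()".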

def get_constant(x):
    if x == inf:
        return 'math.inf'
    if x == -inf:
        return '-math.inf'
    return x

def get_script_args(args):
    formals: List[str] = []
    tensors: List[Union[torch.Tensor, List[torch.Tensor]]] = []
    actuals: List[str] = []
    for arg in args:
        if isinstance(arg, torch.Tensor):
            name = f'i{len(formals)}'
            formals.append(name)
            actuals.append(name)
            tensors.append(arg)
        elif is_iterable_of_tensors(arg):
            name = f'i{len(formals)}'
            formals.append(name + ': List[torch.Tensor]')
            actuals.append(name)
            tensors.append(list(arg))
        elif isinstance(arg, str):
            actuals.append(f"'{arg}'")
        else:
            actuals.append(str(get_constant(arg)))
    return (formals, tensors, actuals)
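
# Illustrative sketch (added for clarity; the helper name is hypothetical): how
# get_script_args splits example arguments into formal parameter names, tensors
# to pass at call time, and the literal text used inside the generated script.
def _example_get_script_args_usage():
    formals, tensors, actuals = get_script_args((torch.ones(2), 'mean', 3))
    # formals == ['i0']                      (only tensors become parameters)
    # tensors == [torch.ones(2)]             (passed to the compiled method)
    # actuals == ['i0', "'mean'", '3']       (spliced into the generated source)
    return formals, tensors, actuals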

# creates a script function from (name, func_type, example args, kwargs) and
# returns the compiled function together with the tensors to call it with
def gen_script_fn_and_args(method_name, func_type, *args, **kwargs):
    formals, tensors, actuals = get_script_args(args)
    call = get_call(method_name, func_type, actuals, kwargs)
    script = script_template.format(', '.join(formals), call)
    CU = torch.jit.CompilationUnit(script)
    return CU.the_method, tensors
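
# Illustrative sketch (hypothetical helper name): compiling a single functional
# call through gen_script_fn_and_args and running it on the generated tensors.
def _example_gen_script_fn_and_args_usage():
    fn, tensors = gen_script_fn_and_args('relu', 'nn_functional', torch.randn(3))
    # fn is CU.the_method, compiled from:
    #   def the_method(i0):
    #       return torch.nn.functional.relu(i0)
    return fn(*tensors)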

# creates a script function from (name, func_type) and returns a function that
# takes (args, kwargs) and runs the compiled function on them
def create_script_fn(self, method_name, func_type):
    # the returned script_fn compiles the call, checks export/import round-tripping,
    # runs it, and records the graph specialized for the inputs on script_fn.last_graph
    def script_fn(*args, **kwargs):
        fn, tensors = gen_script_fn_and_args(method_name, func_type, *args, **kwargs)
        self.assertExportImport(fn.graph, tensors)
        output = fn(*tensors)
        # skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087
        script_fn.last_graph = fn.graph_for(*tensors)  # type: ignore[attr-defined]
        return output
    return script_fn

class SplitInputs:
    all_tensors: List[Any]
    tensor_args: List[Any]
    nontensor_args: List[Any]
    arg_types: List[str]
    tensor_kwargs: Dict[str, Any]
    kwarg_order: List[str]
    nontensor_kwargs: Dict[str, Any]
    kwarg_types: Dict[str, Any]

    @staticmethod
    def _is_tensor_input(arg):
        return isinstance(arg, torch.Tensor) or is_iterable_of_tensors(arg)

    def __init__(self, args, kwargs):
        self.arg_types = ['t' if self._is_tensor_input(arg) else 's' for arg in args]
        self.kwarg_types = {k: 't' if self._is_tensor_input(v) else 's' for k, v in kwargs.items()}
        self.tensor_args = [arg for arg in args if self._is_tensor_input(arg)]
        self.nontensor_args = [arg for arg in args if not self._is_tensor_input(arg)]
        self.tensor_kwargs = {k: v for k, v in kwargs.items() if self._is_tensor_input(v)}
        self.nontensor_kwargs = {k: v for k, v in kwargs.items() if not self._is_tensor_input(v)}
        self.all_tensors = [*self.tensor_args, *[v for k, v in self.tensor_kwargs.items()]]
        self.kwarg_order = [k for k, v in kwargs.items()]

    def nontensors_match(self, other: 'SplitInputs'):
        if self.arg_types != other.arg_types:
            return False
        if self.kwarg_types != other.kwarg_types:
            return False
        if self.kwarg_order != other.kwarg_order:
            return False
        if self.nontensor_args != other.nontensor_args:
            return False
        if self.nontensor_kwargs != other.nontensor_kwargs:
            return False
        return True
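
# Illustrative sketch (hypothetical helper name): how SplitInputs partitions a
# call into tensor and non-tensor pieces so only tensors are fed to the tracer.
def _example_split_inputs_usage():
    inputs = SplitInputs((torch.ones(2), 0.5), {'weight': torch.ones(2), 'dim': 0})
    # inputs.arg_types      == ['t', 's']
    # inputs.kwarg_types    == {'weight': 't', 'dim': 's'}
    # inputs.all_tensors    == [the positional tensor, the 'weight' tensor]
    # inputs.nontensor_args == [0.5]; inputs.nontensor_kwargs == {'dim': 0}
    return inputs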

# makes a new function where all non-tensor arguments in 'args' and 'kwargs' have
# been partially applied and only the tensor arguments remain.
# used to trace functions when some arguments are not tensors
def partial_apply_nontensors(fn, args, kwargs):
    inputs = SplitInputs(args, kwargs)

    def new_fn(*tensors_):
        tensors = iter(tensors_)
        full_args = [args[i] if s == 's' else next(tensors) for i, s in enumerate(inputs.arg_types)]
        full_kwargs = {k: kwargs[k] if s == 's' else next(tensors) for k, s in inputs.kwarg_types.items()}
        return fn(*full_args, **full_kwargs)

    return new_fn, inputs
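
# Illustrative sketch (hypothetical helper name): the returned new_fn takes only
# tensors and re-inserts the captured non-tensor arguments in their original slots.
def _example_partial_apply_nontensors_usage():
    fn, inputs = partial_apply_nontensors(F.dropout, (torch.ones(3), 0.0), {})
    # inputs.all_tensors == [the ones(3) tensor]; the probability 0.0 stays baked into fn
    return fn(*inputs.all_tensors)  # equivalent to F.dropout(torch.ones(3), 0.0)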

# creates a traced function from the input fn
def create_traced_fn(self, fn, cache_traced_fn=False):
    def traced_fn(*inputs, **kwargs):
        # `check_trace` is set to False because check_trace is run with @no_grad
        # Also, `check_against_reference` already does all the checks
        # against the Python function
        fn_tensors, split_inputs = partial_apply_nontensors(fn, inputs, kwargs)
        if not cache_traced_fn or not hasattr(traced_fn, 'traced'):
            traced = torch.jit.trace(fn_tensors, split_inputs.all_tensors, check_trace=False)
            self.assertExportImport(traced.graph, split_inputs.all_tensors)
            output = traced(*split_inputs.all_tensors)
            if cache_traced_fn:
                traced_fn.traced = traced
                traced_fn.split_inputs = split_inputs
        else:
            # Guard to check that nontensor inputs are the same as during tracing
            self.assertTrue(traced_fn.split_inputs.nontensors_match(split_inputs))
            output = traced_fn.traced(*split_inputs.all_tensors)
            traced = traced_fn.traced
        # skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087
        traced_fn.last_graph = traced.graph_for(*split_inputs.all_tensors)  # type: ignore[attr-defined]
        traced_fn.graph = traced.graph  # type: ignore[attr-defined]
        return output
    return traced_fn
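
# Usage sketch (hedged; `test_case` stands in for a JitTestCase instance providing
# assertExportImport, which this helper expects as `self`):
#   traced = create_traced_fn(test_case, F.relu)
#   out = traced(torch.randn(3, 3))   # traces on the first call, runs the trace,
#                                     # and stores traced.last_graph / traced.graph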

# known to be failing in script
EXCLUDE_SCRIPT = {
    'test_norm_fro_default',
    'test_norm_fro_cpu',
    'test_norm_nuc',
    'test_norm_fro',
    'test_norm_nuc_batched',

    # aten op has additional cudnn argument
    'test_nn_unfold',

    # flaky test - TODO fix
    'test_nn_ctc_loss',

    # unknown builtin op
    'test_nn_fold',

    # jit doesn't support sparse tensors.
    'test_to_sparse',
    'test_to_sparse_dim',
}

# generates a script function and set of example inputs
# from a specified test in the format of nn_functional_tests
def get_nn_functional_compiled_fn_and_inputs(name, self_size, args, variant_name='', *extra_args):
    test_name = 'test_nn_' + name

    if variant_name != '':
        test_name = test_name + '_' + variant_name

    no_grad = variant_name == 'inplace'

    self_variable = create_input((self_size,))[0][0]
    kwargs = None

    # need to record this because methods can change the size (e.g. unsqueeze)
    args_variable, kwargs_variable = create_input(args)

    self_tensor = deepcopy(self_variable.data)
    args_tensor = deepcopy(unpack_variables(args_variable))

    f_args_variable = (self_variable,) + args_variable
    f_args_tensor = (self_tensor,) + args_tensor
    with torch._jit_internal._disable_emit_hooks():
        script_fn, inputs = gen_script_fn_and_args(name, "nn_functional", *f_args_variable)
    return script_fn, inputs
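
# Illustrative sketch (hypothetical helper name): compiling one entry of
# nn_functional_tests, e.g. ('relu', (S, S, S), (), '', (True,)), into a script
# function plus the example inputs to call it with.
def _example_nn_functional_compiled_fn_usage():
    script_fn, inputs = get_nn_functional_compiled_fn_and_inputs('relu', (S, S, S), ())
    return script_fn(*inputs)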


# additional module tests
# TODO: delete this list once we make all nn_tests work
additional_module_tests = [
    {
        'module_name': 'Bilinear',
        'constructor_args': (S, S, M),
        'input_size': (S, S),
        'extra_args': ((S, S),)
    },
    {
        'module_name': 'RNNCell',
        'constructor_args': (S, S),
        'input_size': (S, S),
    },
    {
        'module_name': 'LSTMCell',
        'constructor_args': (S, S),
        'input_size': (S, S),
    },
    {
        'module_name': 'GRUCell',
        'constructor_args': (S, S),
        'input_size': (S, S),
    },
    {
        'module_name': 'MultiheadAttention',
        'constructor_args': (128, 8),
        'input_size': (10, 8, 128),
        'extra_args': (torch.randn(10, 8, 128), torch.randn(10, 8, 128)),
        'slowTest': True
    },
    {
        'module_name': 'Transformer',
        'constructor_args': (1, 1, 1, 1, 2),
        'input_size': (3, 1, 1),
        'extra_args': (torch.randn(1, 1, 1),),
        'slowTest': True
    }
]

EXCLUDE_SCRIPT_MODULES = {
    'test_nn_AdaptiveAvgPool2d_tuple_none',
    'test_nn_AdaptiveAvgPool3d_tuple_none',
    'test_nn_AdaptiveMaxPool2d_tuple_none',
    'test_nn_AdaptiveMaxPool3d_tuple_none',

    # Doesn't use future division, so this is not supported
    'test_nn_CrossMapLRN2d',
    # Derivative for aten::_scaled_dot_product_flash_attention_backward is not implemented
    'test_nn_TransformerDecoderLayer_gelu_activation',
    'test_nn_TransformerDecoderLayer_relu_activation',
    'test_nn_TransformerEncoderLayer_gelu_activation',
    'test_nn_TransformerEncoderLayer_relu_activation',
    'test_nn_Transformer_multilayer_coder',
}

script_method_template = '''
def forward({}):
    return {}
'''

def create_script_module(self, nn_module, constructor_args, *args, **kwargs):
    def script_module(*args, **kwargs):
        formals, tensors, actuals = get_script_args(args)

        method_args = ', '.join(['self'] + actuals)
        call_args_str = ', '.join(actuals)
        call = f"self.submodule({call_args_str})"
        script = script_method_template.format(method_args, call)

        submodule_constants = []
        if kwargs.get('is_constant'):
            submodule_constants = ['submodule']

        # Create module to use the script method
        class TheModule(torch.jit.ScriptModule):
            __constants__ = submodule_constants

            def __init__(self) -> None:
                super().__init__()
                self.submodule = nn_module(*constructor_args)

        def make_module(script):
            module = TheModule()
            # check __repr__
            str(module)
            module.define(script)
            return module

        module = make_module(script)
        if self:
            self.assertExportImportModule(module, tensors)
            module(*args)
        # skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087
        create_script_module.last_graph = module.graph  # type: ignore[attr-defined]
        return module
    return script_module
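
# Usage sketch (hedged; the variable names are illustrative): create_script_module
# returns a factory, and calling the factory with example inputs builds a
# ScriptModule whose generated forward simply dispatches to the wrapped nn module.
#   make_mod = create_script_module(None, torch.nn.Linear, (S, S))
#   mod = make_mod(torch.randn(S, S))  # with self=None the export/import check and
#                                      # the eager call are skipped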

def check_alias_annotation(method_name, args, kwargs, *, aten_name, func_type='method'):
    formals, tensors, actuals = get_script_args(args)
    call = get_call(method_name, func_type, actuals, kwargs)
    script = script_template.format(', '.join(formals), call)
    CU = torch.jit.CompilationUnit(script)
    # to clean up IR
    torch._C._jit_pass_inline(CU.the_method.graph)
    torch._C._jit_pass_constant_propagation(CU.the_method.graph)
    torch._C._jit_check_alias_annotation(CU.the_method.graph, tuple(tensors), aten_name)

def get_nn_module_name_from_kwargs(**kwargs):
    if 'module_name' in kwargs:
        return kwargs['module_name']
    elif 'fullname' in kwargs:
        return kwargs['fullname']
    elif 'constructor' in kwargs:
        return kwargs['constructor'].__name__

def get_nn_mod_test_name(**kwargs):
    if 'fullname' in kwargs:
        test_name = kwargs['fullname']
    else:
        test_name = get_nn_module_name_from_kwargs(**kwargs)
        if 'desc' in kwargs:
            test_name = f"{test_name}_{kwargs['desc']}"
    return f'test_nn_{test_name}'
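
# Illustrative behavior (hypothetical inputs): get_nn_mod_test_name(module_name='Bilinear')
# returns 'test_nn_Bilinear'; adding desc='two_dim' (a made-up descriptor) would yield
# 'test_nn_Bilinear_two_dim', and a 'fullname' entry is used verbatim instead.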

def get_nn_module_class_from_kwargs(**kwargs):
    name = get_nn_module_name_from_kwargs(**kwargs)
    index = name.find("_")
    if index == -1:
        return name
    else:
        return name[:index]

def try_get_nn_module_compiled_mod_and_inputs(*args, **kwargs):
    name = get_nn_module_name_from_kwargs(**kwargs)

    if 'desc' in kwargs and 'eval' in kwargs['desc']:
        # eval() is not supported, so skip these tests
        return

    test_name = get_nn_mod_test_name(**kwargs)

    if test_name in EXCLUDE_SCRIPT_MODULES:
        return
    if 'constructor' in kwargs:
        nn_module = kwargs['constructor']
    else:
        nn_module = getattr(torch.nn, name)

    if "FunctionalModule" in str(nn_module):
        return

    if 'constructor_args_fn' in kwargs:
        constructor_args = kwargs['constructor_args_fn']()
    else:
        constructor_args = kwargs.get('constructor_args', ())

    # Set up inputs from tuple of sizes or constructor fn
    input_dtype = torch.double
    if 'input_fn' in kwargs:
        input = kwargs['input_fn']()
        if isinstance(input, torch.Tensor):
            input = (input,)

        if all(tensor.is_complex() for tensor in input):
            input_dtype = torch.cdouble
    else:
        input = (kwargs['input_size'],)

    # Extra parameters to forward()
    if 'extra_args' in kwargs:
        input = input + kwargs['extra_args']

    if 'target_size' in kwargs:
        input = input + (kwargs['target_size'],)
    elif 'target_fn' in kwargs:
        if torch.is_tensor(input):
            input = (input,)
        input = input + (kwargs['target_fn'](),)

    args_variable, kwargs_variable = create_input(input, dtype=input_dtype)
    f_args_variable = deepcopy(unpack_variables(args_variable))
    out_var = deepcopy(f_args_variable)

    mod = create_script_module(None, nn_module, constructor_args, *f_args_variable)(*f_args_variable)

    return mod, out_var
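
# Illustrative sketch (hypothetical helper name): compiling one entry in the style
# of additional_module_tests into a ScriptModule plus example forward() inputs.
def _example_try_get_nn_module_compiled_mod_and_inputs():
    entry = {'module_name': 'RNNCell', 'constructor_args': (S, S), 'input_size': (S, S)}
    # Returns (script_module, example_inputs), or None when the test is excluded
    # or not supported in script mode.
    return try_get_nn_module_compiled_mod_and_inputs(**entry)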


def get_all_nn_module_tests():
    return module_tests + new_module_tests + additional_module_tests
723