# mypy: ignore-errors

# Torch
from torch.jit.annotations import BroadcastingList2, BroadcastingList3  # noqa: F401
import torch.nn.functional as F
import torch
import torch.cuda
import torch.jit
import torch.jit._logging
import torch.jit.frontend
from torch.testing._internal.common_nn import module_tests, new_module_tests
from torch.testing._internal.common_utils import is_iterable_of_tensors, noncontiguous_like

import collections
from copy import deepcopy
from typing import Any, Dict, List, Union
import math  # noqa: F401

# Testing utils
from torch import inf

assert torch.get_default_dtype() == torch.float32

L = 20
M = 10
S = 5


def unpack_variables(args):
    if isinstance(args, tuple):
        return tuple(unpack_variables(elem) for elem in args)
    else:
        return args

class dont_convert(tuple):
    pass

non_differentiable = collections.namedtuple('non_differentiable', ['tensor'])

def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwargs=None, dtype=torch.float, device=None):
    if not isinstance(call_args, tuple):
        call_args = (call_args,)

    def map_arg(arg):
        def maybe_non_contig(tensor):
            if not non_contiguous or tensor.numel() < 2:
                return tensor.clone()

            return noncontiguous_like(tensor)

        def conjugate(tensor):
            return tensor.conj()

        if isinstance(arg, (torch.Size, dont_convert)):
            return arg
        elif isinstance(arg, tuple) and len(arg) == 0:
            var = conjugate(torch.randn((), dtype=dtype, device=device))
            var.requires_grad = requires_grad
            return var
        elif isinstance(arg, tuple) and not isinstance(arg[0], torch.Tensor):
            return conjugate(maybe_non_contig(torch.randn(*arg, dtype=dtype, device=device))).requires_grad_(requires_grad)
        # double check casting
        elif isinstance(arg, non_differentiable):
            if isinstance(arg.tensor, torch.Tensor):
                return conjugate(maybe_non_contig(arg.tensor.to(device=device)))
            return conjugate(maybe_non_contig(arg.tensor.to(device=device)))
        elif isinstance(arg, torch.Tensor):
            if arg.is_complex() != dtype.is_complex:
                raise RuntimeError("User provided tensor is real for a test that runs with complex dtype, "
                                   "which is not supported for now")
            # NOTE: We do clone() after detach() here because we need to be able to change size/storage of v afterwards
            v = conjugate(maybe_non_contig(arg)).detach().to(device=device).clone()
            v.requires_grad = requires_grad and (v.is_floating_point() or v.is_complex())
            return v
        elif callable(arg):
            return map_arg(arg(dtype=dtype, device=device))
        else:
            return arg
    args_out = tuple(map_arg(arg) for arg in call_args)
    kwargs_out = {k: map_arg(v) for k, v in call_kwargs.items()} if call_kwargs else {}
    return args_out, kwargs_out
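
# Illustrative sketch (an added, documentation-only helper that nothing in the
# harness calls): how `create_input` turns the shape specs used throughout this
# file into concrete tensors. The name `_example_create_input_usage` is not part
# of the original module.
def _example_create_input_usage():
    # A bare shape tuple becomes a requires-grad float tensor of that shape.
    (x,), _ = create_input(((S, S, S),))
    assert x.shape == (S, S, S) and x.requires_grad
    # Wrapping a tensor in `non_differentiable` keeps requires_grad=False.
    (w,), _ = create_input((non_differentiable(torch.ones(S)),))
    assert not w.requires_grad
    return x, w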

# NB: JIT script tests for all nn functional interfaces, script mode does
# not support in_place operations yet, so no inplace operation tests added.
# removed all the deprecated functions
#
# (
#   method name,
#   input size/constructing fn,
#   args (tuple represents shape of a tensor arg),
#   test variant name (will be used as the test name suffix,
#       'inplace' skips grad tests),                            // optional
#   (True, nonfusible_nodes, fusible_nodes) for autodiff,       // optional
#   fn to determine if test should be skipped,                  // optional
#   fn mapping output to part that should be gradcheck'ed,      // optional
#   kwargs for function,                                        // optional
# )
nn_functional_tests = [
    ('conv1d', (S, S, S), ((S, S, S),)),
    ('conv2d', (S, S, S, S), ((S, S, S, S),)),
    ('conv3d', (S, S, S, S, S), ((S, S, S, S, S),)),
    ('conv_transpose1d', (S, S, S), ((S, S, S),)),
    ('conv_transpose2d', (S, S, S, S), ((S, S, S, S),)),
    ('conv_transpose3d', (S, S, S, S, S), ((S, S, S, S, S),)),
    ('conv_tbc', (S, S, S), ((S, S, S), (S,), 2)),
    ('avg_pool1d', (S, S, S), (3,)),
    ('avg_pool2d', (S, S, S, S), (3,), '', (True,)),
    ('avg_pool3d', (S, S, S, S, S), (3,)),
    ('fractional_max_pool2d', (S, S, S, S), (3, [2, 3],)),
    ('max_pool1d', (S, S, S), (2, 1)),
    ('max_pool1d', (S, S, S), (2, 1, 1, 1, False, True), 'with_indices'),
    ('max_pool2d', (S, S, S, S), (2, 1), '', (True, 'aten::max_pool2d_with_indices')),
    ('max_pool2d', (S, S, S, S), (2, 1, 1, 1, False, True), 'with_indices', (True, 'aten::max_pool2d_with_indices')),
    ('max_pool3d', (S, S, S, S, S), (2, 1)),
    ('max_unpool1d', torch.tensor([[[2., 4]]]), (torch.tensor([[[1, 3]]]), 2, 2, 0)),
    ('max_unpool2d', torch.tensor([[[[2., 4]]]]), (torch.tensor([[[[1, 3]]]]), 2, 2, 0)),
    ('max_unpool3d', torch.tensor([[[[[2., 4]]]]]), (torch.tensor([[[[[1, 3]]]]]), 2, 2, 0)),
    ('lp_pool1d', (S, S, S), (2., 3, 2,)),
    ('lp_pool2d', (S, S, S, S), (2., 3, 2,)),
    ('lp_pool3d', (S, S, S, S, S), (2., 3, 2,)),
    ('adaptive_max_pool1d', (S, S, S), (5,)),
    ('adaptive_max_pool2d', (S, S, S, S), ([5, 7],)),
    ('adaptive_max_pool3d', (S, S, S, S, S), ([3, 2, 2],)),
    ('adaptive_avg_pool1d', (S, S, S), (5,), '', (True,)),
    ('adaptive_avg_pool2d', (S, S, S, S), ([5, 7],), '', (True,)),
    ('adaptive_avg_pool3d', (S, S, S, S, S), ([3, 2, 2],), '', (True,)),
    ('dropout', (S, S, S), (0.5,), '', (True, 'aten::native_dropout')),
    ('alpha_dropout', (S, S, S), (0.5,)),
    ('dropout2d', (S, S, S), (0.5,)),
    ('dropout2d', (S, S, S, S), (0.5,), 'batched'),
    ('dropout3d', (S, S, S, S), (0.5,)),
    ('dropout3d', (S, S, S, S, S), (0.5,), 'batched'),
    ('feature_alpha_dropout', (S, S, S), (0.5,)),
    ('threshold', (S, S, S), (0.1, 2.), '', (True,)),
    ('threshold', (S, S, S), (0.1, 2., True), 'inplace'),
    ('relu', (S, S, S), (), '', (True,)),
    ('relu', (S, S, S), (), 'inplace'),
    ('glu', (S - 1, S - 1, S - 1), (),),
    ('hardtanh', (S, S, S), (-0.5, 0.5), '', (True,)),
    ('hardtanh', (S, S, S), (-0.5, 0.5, True), 'inplace'),
    ('relu6', (S, S, S), (), '', (True,)),
    ('relu6', (S, S, S), (True), 'inplace'),
    ('elu', (S, S, S), (0.9,),),
    ('elu', (S, S, S), (0.9, True), 'inplace'),
    ('selu', (S, S, S), (),),
    ('selu', (S, S, S), (True), 'inplace'),
    ('celu', (S, S, S), (0.9,),),
    ('celu', (S, S, S), (0.9, True), 'inplace'),
    ('leaky_relu', (S, S, S), (0.02,), '', (True,)),
    ('leaky_relu', (S, S, S), (0.02,), 'inplace'),
    ('rrelu', (S, S), (0.1, 0.3, False),),
    ('rrelu', (S, S), (0.1, 0.3, False, True), 'inplace'),
    ('hardshrink', (S, S, S), (0.4,), '', (True,)),
    ('tanhshrink', (S, S, S), (),),
    ('softsign', (S, S, S), (),),
    ('softplus', (S, S, S), (), '', (True,)),
    ('softmin', (S, S, S), (0,),),
    ('softmax', (S, S, S), (0,), '', (True,)),
    ('softmax', (S, S, S), (0, 3, torch.double), 'with_all_args', (True,)),
    ('tanh', (S, S, S), (), '', (True,)),
    ('sigmoid', (S, S, S), (), '', (True,)),
    ('silu', (S, S, S), (), '', (True,)),
    ('log_softmax', (S, S, S), (0,), '', (True,)),
    ('linear', (S, S), ((M, S),), '', (True, ['aten::linear'])),
    ('linear', (S, S), ((M, S), (M,)), 'addmm', (True, ['aten::linear'])),
    ('bilinear', (S, S, S), ((S, S, M), torch.zeros(M, S, M),),),
    ('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ), '', (True,)),
    ('embedding_bag', torch.tensor([1, 2, 4, 2]), (torch.rand(5, 3), torch.tensor([0, 4]),),),
    ('batch_norm', (S, S),
        (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), None, None, True, ),
        'training', (True, 'aten::_batch_norm_impl_index')),
    ('batch_norm', (0, S, S, S),
        (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
         non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
        'size_zero', (True, 'aten::_batch_norm_impl_index')),
    ('batch_norm', (0, S, S, S),
        (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
         non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
        'size_zero_inference', (True, 'aten::_batch_norm_impl_index')),
    ('batch_norm', (S, S),
        (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
         non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), True, ),
        'with_weight_and_bias_training', (True, 'aten::_batch_norm_impl_index')),
    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
                            None, non_differentiable(torch.ones(S)), True, ),
        'with_only_bias_training', (True, 'aten::_batch_norm_impl_index')),
    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
                            non_differentiable(torch.randn(S)), None, True, ),
        'with_only_weight_training', (True, 'aten::_batch_norm_impl_index')),
    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
                            None, None, False, ),
        'inference', (True, 'aten::_batch_norm_impl_index')),
    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
                            non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), False, ),
        'with_weight_and_bias_inference', (True, 'aten::_batch_norm_impl_index')),
    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
                            None, non_differentiable(torch.ones(S)), False, ),
        'with_only_bias_inference', (True, 'aten::_batch_norm_impl_index')),
    ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)),
                            non_differentiable(torch.randn(S)), None, False, ),
        'with_only_weight_inference', (True, 'aten::_batch_norm_impl_index')),
    ('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
    ('layer_norm', (S, S, S, S), ([5],), '',
        (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
    ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),), 'with_only_weight',
        (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
    ('layer_norm', (S, S, S, S), ([5], None, non_differentiable(torch.rand(S)),), 'with_only_bias',
        (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
    ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),
                                  non_differentiable(torch.rand(S))), 'with_weight_and_bias',
        (False, ['aten::contiguous', 'aten::_batch_norm_impl_index', 'aten::addcmul'])),
    ('group_norm', (S, S, S), (1, torch.rand(5),),),
    ('local_response_norm', (S, S, S), (2, ),),
    ('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),), '',),
    ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2),),),
    ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2), True, True), 'full'),
    ('kl_div', F.log_softmax(torch.randn(S, 10), 1), (F.softmax(torch.randn(S, 10), 1),),),
    ('cross_entropy', (3, S), (torch.randint(S, (3,), dtype=torch.int64),),),
    ('binary_cross_entropy_with_logits', (3,), (torch.empty(3).random_(2), ),),
    ('smooth_l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
    ('huber_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
    ('l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
    ('mse_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
    ('smooth_l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
    ('huber_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
    ('l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
    ('mse_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
    ('margin_ranking_loss', (S,), ((S,), (S,)),),
    ('hinge_embedding_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
    ('soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
    ('multilabel_soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
    ('cosine_embedding_loss', (S, S), ((S, S), non_differentiable(torch.rand(S,))),),
    ('pixel_shuffle', (1, 9, 4, 4), (3,),),
    ('pixel_unshuffle', (1, 1, 12, 12), (3,),),
    ('affine_grid', (S, 2, 3), (torch.Size([S, 1, 7, 7]),),),
    ('pad', (3, 3, 4, 2), ([1, 1],),),
    ('pairwise_distance', (S, S), ((S, S),),),
    ('pdist', (S, S), (),),
    ('cosine_similarity', (S, S), ((S, S),),),
    ('triplet_margin_loss', (S, S), ((S, S), (S, S)),),
    ('normalize', (S, S, S), (),),
    ('unfold', (S, S, S, S), ([2, 3]),),
    ('fold', (1, 3 * 2 * 2, 12), ([4, 5], [2, 2]),),
    ('grid_sample', (S, S, S, S), (non_differentiable(torch.rand(S, S, S, 2)),),),
    ('gumbel_softmax', (S, S), (2.,), '', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])),
    ('gumbel_softmax', (S, S), (2., True,), 'hard', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])),
    ('multilabel_margin_loss', torch.tensor([[0.2, -0.2, 0.07]]), (torch.tensor([[0, 0, 1]]),),),
    ('multi_margin_loss', (S, S), (non_differentiable(torch.randint(S, (S, ), dtype=torch.int64)),
                                   1, 1., non_differentiable(torch.randn(S))),),
    ('binary_cross_entropy', torch.randn(3, 2).sigmoid(), (non_differentiable(torch.rand(3, 2)),
                                                           non_differentiable(torch.randn(3, 2))),),
    ('binary_cross_entropy', torch.randn(3, 2).sigmoid(),
        (non_differentiable(torch.rand(3, 2)),
         non_differentiable(torch.randn(3, 2)), None, None, 'mean'), 'size_average'),
    ('ctc_loss', torch.rand(S, S, S).log_softmax(2).detach().requires_grad_(),
        (torch.randint(1, S, (S, S), dtype=torch.long), torch.full((S,), S, dtype=torch.long),
         torch.randint(1, S, (S,), dtype=torch.long))),
    ('upsample', torch.randn(S, S, M, M), (None, 2.), 'with_scale'),
    ('upsample', torch.randn(S, S, M, M), (4,), 'with_size'),
    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'nearest_4d'),
    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'nearest_4d_with_scale'),
    ('interpolate', torch.randn(S, S, M, M), (4,), 'nearest_4d_with_size'),
    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'area_4d'),
    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'area_4d_with_scale'),
    ('interpolate', torch.randn(S, S, M, M), (4,), 'area_4d_with_size'),
    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bilinear_4d'),
    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bilinear_4d_with_scale'),
    ('interpolate', torch.randn(S, S, M, M), (4,), 'bilinear_4d_with_size'),
    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bicubic_4d'),
    ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bicubic_4d_with_scale'),
    ('interpolate', torch.randn(S, S, M, M), (4,), 'bicubic_4d_with_size'),
    ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'nearest_3d'),
    ('interpolate', torch.randn(S, M, M), (None, 2.), 'nearest_3d_with_scale'),
    ('interpolate', torch.randn(S, M, M), (4,), 'nearest_3d_with_size'),
    ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'area_3d'),
    ('interpolate', torch.randn(S, M, M), (None, 2.), 'area_3d_with_scale'),
    ('interpolate', torch.randn(S, M, M), (4,), 'area_3d_with_size'),
    ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'linear_3d'),
    ('interpolate', torch.randn(S, M, M), (None, 2.), 'linear_3d_with_scale'),
    ('interpolate', torch.randn(S, M, M), (4,), 'linear_3d_with_size'),
    ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'nearest_5d_with_scale'),
    ('interpolate', torch.randn(S, M, M, M, M), (4,), 'nearest_5d_with_size'),
    ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'area_5d'),
    ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'area_5d_with_scale'),
    ('interpolate', torch.randn(S, M, M, M, M), (4,), 'area_5d_with_size'),
    ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'trilinear_5d'),
    ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'trilinear_5d_with_scale'),
    ('interpolate', torch.randn(S, M, M, M, M), (4,), 'trilinear_5d_with_size'),
    ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2, None, 'nearest', None, False),
        'nearest_4d_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, S, M, M), (4, None, 'nearest', None, False),
        'nearest_4d_with_size_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, S, M, M), (None, 2., 'bilinear', None, False),
        'bilinear_4d_with_scale_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, S, M, M), (4, None, 'bilinear', None, False),
        'bilinear_4d_with_size_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, S, M, M), (None, 2., 'bicubic', None, False),
        'bicubic_4d_with_scale_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, S, M, M), (4, None, 'bicubic', None, False),
        'bicubic_4d_with_size_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, M, M), (None, 2., 'nearest', None, False),
        'nearest_3d_with_scale_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, M, M), (4, None, 'nearest', None, False),
        'nearest_3d_with_size_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, M, M), (None, 2., 'linear', None, False),
        'linear_3d_with_scale_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, M, M), (4, None, 'linear', None, False),
        'linear_3d_with_size_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'nearest', None, False),
        'nearest_5d_with_scale_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, M, M, M, M), (4, None, 'nearest', None, False),
        'nearest_5d_with_size_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'trilinear', None, False),
        'trilinear_5d_with_scale_not_recompute_scale_factor'),
    ('interpolate', torch.randn(S, M, M, M, M), (4, None, 'trilinear', None, False),
        'trilinear_5d_with_size_not_recompute_scale_factor'),
]
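
# Illustrative sketch (an added, documentation-only helper that nothing in the
# harness calls): how a single entry above is consumed. For
# ('relu', (S, S, S), ()), the self size becomes a requires-grad input tensor
# and the call is scripted as `torch.nn.functional.relu(i0)` by
# `get_nn_functional_compiled_fn_and_inputs`, defined further down in this file.
def _example_nn_functional_entry():
    script_fn, inputs = get_nn_functional_compiled_fn_and_inputs('relu', (S, S, S), ())
    return script_fn(*inputs)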

script_template = '''
def the_method({}):
    return {}
'''

def value_to_literal(value):
    if isinstance(value, str):
        # Quotes string and escapes special characters
        return ascii(value)
    if isinstance(value, torch.Tensor):
        return 'torch.' + str(value)
    else:
        return str(value)

def get_call(method_name, func_type, args, kwargs):
    kwargs_str = ', '.join([k + '=' + value_to_literal(v) for k, v in kwargs.items()])
    self_arg = args[0]
    if func_type == 'method':
        args = args[1:]

    argument_str = ', '.join(args)
    argument_str += ', ' if len(args) and len(kwargs) else ''
    argument_str += kwargs_str

    if func_type == 'functional' or func_type == 'function':
        call = f'torch.{method_name}({argument_str})'
    elif func_type == 'method':
        call = f'{self_arg}.{method_name}({argument_str})'
    elif func_type == 'nn_functional':
        call = f'torch.nn.functional.{method_name}({argument_str})'
    else:
        raise TypeError('Unsupported function type')

    return call
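
# Illustrative sketch (an added, documentation-only helper): the call strings
# `get_call` assembles for each supported func_type. The argument names
# 'i0'/'i1' mirror what `get_script_args` emits for tensor inputs.
def _example_get_call_strings():
    assert get_call('add', 'method', ['i0', 'i1'], {}) == 'i0.add(i1)'
    assert get_call('relu', 'nn_functional', ['i0'], {}) == 'torch.nn.functional.relu(i0)'
    assert get_call('clamp', 'functional', ['i0'], {'min': 0}) == 'torch.clamp(i0, min=0)'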

def get_constant(x):
    if x == inf:
        return 'math.inf'
    if x == -inf:
        return '-math.inf'
    return x

def get_script_args(args):
    formals: List[str] = []
    tensors: List[Union[torch.Tensor, List[torch.Tensor]]] = []
    actuals: List[str] = []
    for arg in args:
        if isinstance(arg, torch.Tensor):
            name = f'i{len(formals)}'
            formals.append(name)
            actuals.append(name)
            tensors.append(arg)
        elif is_iterable_of_tensors(arg):
            name = f'i{len(formals)}'
            formals.append(name + ': List[torch.Tensor]')
            actuals.append(name)
            tensors.append(list(arg))
        elif isinstance(arg, str):
            actuals.append(f"'{arg}'")
        else:
            actuals.append(str(get_constant(arg)))
    return (formals, tensors, actuals)

# creates a script function from (name, func_type, output_process_fn)
# and returns the compiled function and example inputs
def gen_script_fn_and_args(method_name, func_type, *args, **kwargs):
    formals, tensors, actuals = get_script_args(args)
    call = get_call(method_name, func_type, actuals, kwargs)
    script = script_template.format(', '.join(formals), call)
    CU = torch.jit.CompilationUnit(script)
    return CU.the_method, tensors

# creates a script function from (name, func_type),
# returns a function that takes in (args, kwargs) and runs the compiled function
def create_script_fn(self, method_name, func_type):
    # function returns tuple containing original output and
    # filtered output to be used in checking gradients
    def script_fn(*args, **kwargs):
        fn, tensors = gen_script_fn_and_args(method_name, func_type, *args, **kwargs)
        self.assertExportImport(fn.graph, tensors)
        output = fn(*tensors)
        # skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087
        script_fn.last_graph = fn.graph_for(*tensors)  # type: ignore[attr-defined]
        return output
    return script_fn

class SplitInputs:
    all_tensors: List[Any]
    tensor_args: List[Any]
    nontensor_args: List[Any]
    arg_types: List[str]
    tensor_kwargs: Dict[str, Any]
    kwarg_order: List[str]
    nontensor_kwargs: Dict[str, Any]
    kwarg_types: Dict[str, Any]

    @staticmethod
    def _is_tensor_input(arg):
        return isinstance(arg, torch.Tensor) or is_iterable_of_tensors(arg)

    def __init__(self, args, kwargs):
        self.arg_types = ['t' if self._is_tensor_input(arg) else 's' for arg in args]
        self.kwarg_types = {k: 't' if self._is_tensor_input(v) else 's' for k, v in kwargs.items()}
        self.tensor_args = [arg for arg in args if self._is_tensor_input(arg)]
        self.nontensor_args = [arg for arg in args if not self._is_tensor_input(arg)]
        self.tensor_kwargs = {k: v for k, v in kwargs.items() if self._is_tensor_input(v)}
        self.nontensor_kwargs = {k: v for k, v in kwargs.items() if not self._is_tensor_input(v)}
        self.all_tensors = [*self.tensor_args, *[v for k, v in self.tensor_kwargs.items()]]
        self.kwarg_order = [k for k, v in kwargs.items()]

    def nontensors_match(self, other: 'SplitInputs'):
        if self.arg_types != other.arg_types:
            return False
        if self.kwarg_types != other.kwarg_types:
            return False
        if self.kwarg_order != other.kwarg_order:
            return False
        if self.nontensor_args != other.nontensor_args:
            return False
        if self.nontensor_kwargs != other.nontensor_kwargs:
            return False
        return True

# make a new function where all non-tensor arguments in 'args' have been partially
# applied, and all tensor arguments remain.
# used to trace functions when some arguments are not tensors
def partial_apply_nontensors(fn, args, kwargs):
    inputs = SplitInputs(args, kwargs)

    def new_fn(*tensors_):
        tensors = iter(tensors_)
        full_args = [args[i] if s == 's' else next(tensors) for i, s in enumerate(inputs.arg_types)]
        full_kwargs = {k: kwargs[k] if s == 's' else next(tensors) for k, s in inputs.kwarg_types.items()}
        return fn(*full_args, **full_kwargs)

    return new_fn, inputs
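
# Illustrative sketch (an added, documentation-only helper): `partial_apply_nontensors`
# bakes the non-tensor arguments into a new callable so only tensors are fed to the
# tracer. Here the kernel size 2 stays captured while the input flows through
# `all_tensors`.
def _example_partial_apply_nontensors():
    x = torch.randn(1, 1, 4, 4)
    new_fn, inputs = partial_apply_nontensors(F.avg_pool2d, (x, 2), {})
    assert inputs.arg_types == ['t', 's']
    # Equivalent to F.avg_pool2d(x, 2)
    return new_fn(*inputs.all_tensors)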

# create a trace function from input fn
def create_traced_fn(self, fn, cache_traced_fn=False):
    def traced_fn(*inputs, **kwargs):
        # `check_trace` is set to False because check_trace is run with @no_grad
        # Also, `check_against_reference` already does all the checks
        # against python function
        fn_tensors, split_inputs = partial_apply_nontensors(fn, inputs, kwargs)
        if not cache_traced_fn or not hasattr(traced_fn, 'traced'):
            traced = torch.jit.trace(fn_tensors, split_inputs.all_tensors, check_trace=False)
            self.assertExportImport(traced.graph, split_inputs.all_tensors)
            output = traced(*split_inputs.all_tensors)
            if cache_traced_fn:
                traced_fn.traced = traced
                traced_fn.split_inputs = split_inputs
        else:
            # Guard to check that nontensor inputs are the same as during tracing
            self.assertTrue(traced_fn.split_inputs.nontensors_match(split_inputs))
            output = traced_fn.traced(*split_inputs.all_tensors)
            traced = traced_fn.traced
        # skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087
        traced_fn.last_graph = traced.graph_for(*split_inputs.all_tensors)  # type: ignore[attr-defined]
        traced_fn.graph = traced.graph  # type: ignore[attr-defined]
        return output
    return traced_fn

# known to be failing in script
EXCLUDE_SCRIPT = {
    'test_norm_fro_default',
    'test_norm_fro_cpu',
    'test_norm_nuc',
    'test_norm_fro',
    'test_norm_nuc_batched',

    # aten op has additional cudnn argument
    'test_nn_unfold',

    # flaky test - TODO fix
    'test_nn_ctc_loss',

    # unknown builtin op
    'test_nn_fold',

    # jit doesn't support sparse tensors.
    'test_to_sparse',
    'test_to_sparse_dim',
}

# generates a script function and set of example inputs
# from a specified test in the format of nn_functional_tests
def get_nn_functional_compiled_fn_and_inputs(name, self_size, args, variant_name='', *extra_args):
    test_name = 'test_nn_' + name

    if variant_name != '':
        test_name = test_name + '_' + variant_name

    no_grad = variant_name == 'inplace'

    self_variable = create_input((self_size,))[0][0]
    kwargs = None

    # need to record this because methods can change the size (e.g. unsqueeze)
    args_variable, kwargs_variable = create_input(args)

    self_tensor = deepcopy(self_variable.data)
    args_tensor = deepcopy(unpack_variables(args_variable))

    f_args_variable = (self_variable,) + args_variable
    f_args_tensor = (self_tensor,) + args_tensor
    with torch._jit_internal._disable_emit_hooks():
        script_fn, inputs = gen_script_fn_and_args(name, "nn_functional", *f_args_variable)
    return script_fn, inputs


# additional modules test
# TODO: delete this list once we make all nn_tests work
additional_module_tests = [
    {
        'module_name': 'Bilinear',
        'constructor_args': (S, S, M),
        'input_size': (S, S),
        'extra_args': ((S, S),)
    },
    {
        'module_name': 'RNNCell',
        'constructor_args': (S, S),
        'input_size': (S, S),
    },
    {
        'module_name': 'LSTMCell',
        'constructor_args': (S, S),
        'input_size': (S, S),
    },
    {
        'module_name': 'GRUCell',
        'constructor_args': (S, S),
        'input_size': (S, S),
    },
    {
        'module_name': 'MultiheadAttention',
        'constructor_args': (128, 8),
        'input_size': (10, 8, 128),
        'extra_args': (torch.randn(10, 8, 128), torch.randn(10, 8, 128)),
        'slowTest': True
    },
    {
        'module_name': 'Transformer',
        'constructor_args': (1, 1, 1, 1, 2),
        'input_size': (3, 1, 1),
        'extra_args': (torch.randn(1, 1, 1),),
        'slowTest': True
    }
]

EXCLUDE_SCRIPT_MODULES = {
    'test_nn_AdaptiveAvgPool2d_tuple_none',
    'test_nn_AdaptiveAvgPool3d_tuple_none',
    'test_nn_AdaptiveMaxPool2d_tuple_none',
    'test_nn_AdaptiveMaxPool3d_tuple_none',

    # Doesn't use future division, so this is not supported
    'test_nn_CrossMapLRN2d',
    # Derivative for aten::_scaled_dot_product_flash_attention_backward is not implemented
    'test_nn_TransformerDecoderLayer_gelu_activation',
    'test_nn_TransformerDecoderLayer_relu_activation',
    'test_nn_TransformerEncoderLayer_gelu_activation',
    'test_nn_TransformerEncoderLayer_relu_activation',
    'test_nn_Transformer_multilayer_coder',
}

script_method_template = '''
def forward({}):
    return {}
'''

def create_script_module(self, nn_module, constructor_args, *args, **kwargs):
    def script_module(*args, **kwargs):
        formals, tensors, actuals = get_script_args(args)

        method_args = ', '.join(['self'] + actuals)
        call_args_str = ', '.join(actuals)
        call = f"self.submodule({call_args_str})"
        script = script_method_template.format(method_args, call)

        submodule_constants = []
        if kwargs.get('is_constant'):
            submodule_constants = ['submodule']

        # Create module to use the script method
        class TheModule(torch.jit.ScriptModule):
            __constants__ = submodule_constants

            def __init__(self) -> None:
                super().__init__()
                self.submodule = nn_module(*constructor_args)

        def make_module(script):
            module = TheModule()
            # check __repr__
            str(module)
            module.define(script)
            return module

        module = make_module(script)
        if self:
            self.assertExportImportModule(module, tensors)
            module(*args)
        # skip type annotate function attributes for now, see: https://github.com/python/mypy/issues/2087
        create_script_module.last_graph = module.graph  # type: ignore[attr-defined]
        return module
    return script_module

def check_alias_annotation(method_name, args, kwargs, *, aten_name, func_type='method'):
    formals, tensors, actuals = get_script_args(args)
    call = get_call(method_name, func_type, actuals, kwargs)
    script = script_template.format(', '.join(formals), call)
    CU = torch.jit.CompilationUnit(script)
    # to clean up IR
    torch._C._jit_pass_inline(CU.the_method.graph)
    torch._C._jit_pass_constant_propagation(CU.the_method.graph)
    torch._C._jit_check_alias_annotation(CU.the_method.graph, tuple(tensors), aten_name)

def get_nn_module_name_from_kwargs(**kwargs):
    if 'module_name' in kwargs:
        return kwargs['module_name']
    elif 'fullname' in kwargs:
        return kwargs['fullname']
    elif 'constructor' in kwargs:
        return kwargs['constructor'].__name__

def get_nn_mod_test_name(**kwargs):
    if 'fullname' in kwargs:
        test_name = kwargs['fullname']
    else:
        test_name = get_nn_module_name_from_kwargs(**kwargs)
        if 'desc' in kwargs:
            test_name = f"{test_name}_{kwargs['desc']}"
    return f'test_nn_{test_name}'

def get_nn_module_class_from_kwargs(**kwargs):
    name = get_nn_module_name_from_kwargs(**kwargs)
    index = name.find("_")
    if index == -1:
        return name
    else:
        return name[0:name.find("_")]

def try_get_nn_module_compiled_mod_and_inputs(*args, **kwargs):
    name = get_nn_module_name_from_kwargs(**kwargs)

    if 'desc' in kwargs and 'eval' in kwargs['desc']:
        # eval() is not supported, so skip these tests
        return

    test_name = name
    if 'desc' in kwargs:
        test_name = f"{test_name}_{kwargs['desc']}"
    test_name = get_nn_mod_test_name(**kwargs)

    if test_name in EXCLUDE_SCRIPT_MODULES:
        return
    if 'constructor' in kwargs:
        nn_module = kwargs['constructor']
    else:
        nn_module = getattr(torch.nn, name)

    if "FunctionalModule" in str(nn_module):
        return

    if 'constructor_args_fn' in kwargs:
        constructor_args = kwargs['constructor_args_fn']()
    else:
        constructor_args = kwargs.get('constructor_args', ())

    # Set up inputs from tuple of sizes or constructor fn
    input_dtype = torch.double
    if 'input_fn' in kwargs:
        input = kwargs['input_fn']()
        if isinstance(input, torch.Tensor):
            input = (input,)

        if all(tensor.is_complex() for tensor in input):
            input_dtype = torch.cdouble
    else:
        input = (kwargs['input_size'],)

    # Extra parameters to forward()
    if 'extra_args' in kwargs:
        input = input + kwargs['extra_args']

    if 'target_size' in kwargs:
        input = input + (kwargs['target_size'],)
    elif 'target_fn' in kwargs:
        if torch.is_tensor(input):
            input = (input,)
        input = input + (kwargs['target_fn'](),)

    args_variable, kwargs_variable = create_input(input, dtype=input_dtype)
    f_args_variable = deepcopy(unpack_variables(args_variable))
    out_var = deepcopy(f_args_variable)

    args, mod = f_args_variable, create_script_module(None, nn_module, constructor_args, *f_args_variable)(*f_args_variable)

    return mod, out_var


def get_all_nn_module_tests():
    return module_tests + new_module_tests + additional_module_tests
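

# Illustrative sketch (an added, documentation-only helper that the test suite
# does not reference): end-to-end use of the scripting helpers above.
# `gen_script_fn_and_args` formats `def the_method(i0): return i0.clamp(min=0.0)`
# via `script_template`, compiles it in a torch.jit.CompilationUnit, and returns
# the compiled function together with the captured tensor inputs.
def _example_gen_script_fn_and_args():
    fn, tensors = gen_script_fn_and_args('clamp', 'method', torch.randn(3), min=0.0)
    return fn(*tensors)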