1# Owner(s): ["oncall: jit"] 2 3import torch 4 5# This is how we include tests located in test/jit/... 6# They are included here so that they are invoked when you call `test_jit.py`, 7# do not run these test files directly. 8from jit.test_tracer import TestTracer, TestMixTracingScripting # noqa: F401 9from jit.test_recursive_script import TestRecursiveScript # noqa: F401 10from jit.test_type_sharing import TestTypeSharing # noqa: F401 11from jit.test_logging import TestLogging # noqa: F401 12from jit.test_backends import TestBackends, TestBackendsWithCompiler # noqa: F401 13from jit.test_backend_nnapi import TestNnapiBackend # noqa: F401 14from jit.test_list_dict import TestList, TestDict, TestNamedTuple, TestScriptDict, TestScriptList # noqa: F401 15from jit.test_async import TestAsync # noqa: F401 16from jit.test_await import TestAwait # noqa: F401 17from jit.test_data_parallel import TestDataParallel # noqa: F401 18from jit.test_models import TestModels # noqa: F401 19from jit.test_modules import TestModules # noqa: F401 20from jit.test_autodiff import TestAutodiffJit # noqa: F401 21from jit.test_autodiff_subgraph_slicing import TestAutodiffSubgraphSlicing # noqa: F401 22from jit.test_custom_operators import TestCustomOperators # noqa: F401 23from jit.test_graph_rewrite_passes import TestGraphRewritePasses # noqa: F401 24from jit.test_class_type import TestClassType # noqa: F401 25from jit.test_builtins import TestBuiltins, TestTensorBuiltins # noqa: F401 26from jit.test_ignore_context_manager import TestIgnoreContextManager # noqa: F401 27from jit.test_symbolic_shape_analysis import TestSymbolicShapeAnalysis # noqa: F401 28from jit.test_op_decompositions import TestOpDecompositions # noqa: F401 29from jit.test_unsupported_ops import TestUnsupportedOps # noqa: F401 30from jit.test_freezing import TestFreezing, TestFrozenOptimizations, TestMKLDNNReinplacing # noqa: F401 31from jit.test_peephole import TestPeephole # noqa: F401 32from jit.test_alias_analysis 
import TestAliasAnalysis # noqa: F401 33from jit.test_save_load import TestSaveLoad, TestSaveLoadFlatbuffer # noqa: F401 34from jit.test_save_load_for_op_version import TestSaveLoadForOpVersion # noqa: F401 35from jit.test_module_containers import TestModuleContainers # noqa: F401 36from jit.test_python_bindings import TestPythonBindings # noqa: F401 37from jit.test_python_ir import TestPythonIr # noqa: F401 38from jit.test_functional_blocks import TestFunctionalBlocks # noqa: F401 39from jit.test_remove_mutation import TestRemoveMutation # noqa: F401 40from jit.test_torchbind import TestTorchbind # noqa: F401 41from jit.test_module_interface import TestModuleInterface # noqa: F401 42from jit.test_with import TestWith # noqa: F401 43from jit.test_enum import TestEnum # noqa: F401 44from jit.test_string_formatting import TestStringFormatting # noqa: F401 45from jit.test_profiler import TestProfiler # noqa: F401 46from jit.test_slice import TestSlice # noqa: F401 47from jit.test_ignorable_args import TestIgnorableArgs # noqa: F401 48from jit.test_hooks import TestHooks # noqa: F401 49from jit.test_warn import TestWarn # noqa: F401 50from jit.test_isinstance import TestIsinstance # noqa: F401 51from jit.test_cuda import TestCUDA # noqa: F401 52from jit.test_python_builtins import TestPythonBuiltinOP # noqa: F401 53from jit.test_typing import TestTyping # noqa: F401 54from jit.test_hash import TestHash # noqa: F401 55from jit.test_complex import TestComplex # noqa: F401 56from jit.test_jit_utils import TestJitUtils # noqa: F401 57from jit.test_scriptmod_ann import TestScriptModuleInstanceAttributeTypeAnnotation # noqa: F401 58from jit.test_types import TestTypesAndAnnotation # noqa: F401 59from jit.test_misc import TestMisc # noqa: F401 60from jit.test_upgraders import TestUpgraders # noqa: F401 61from jit.test_pdt import TestPDT # noqa: F401 62from jit.test_tensor_creation_ops import TestTensorCreationOps # noqa: F401 63from jit.test_module_apis import TestModuleAPIs 
# noqa: F401 64from jit.test_script_profile import TestScriptProfile # noqa: F401 65from jit.test_convert_activation import TestFunctionalToInplaceActivation, TestInplaceToFunctionalActivation # noqa: F401 66from jit.test_parametrization import TestParametrization # noqa: F401 67from jit.test_attr import TestGetDefaultAttr # noqa: F401 68from jit.test_aten_pow import TestAtenPow # noqa: F401 69from jit.test_optimize_for_mobile_preserve_debug_info import TestOptimizeForMobilePreserveDebugInfo # noqa: F401 70from jit.test_union import TestUnion # noqa: F401 71from jit.test_batch_mm import TestBatchMM # noqa: F401 72from jit.test_dtype_analysis import TestDtypeAnalysis, TestDtypeCustomRulesCPU # noqa: F401 73from jit.test_device_analysis import TestDeviceAnalysis # noqa: F401 74from jit.test_dce import TestDCE # noqa: F401 75from jit.test_sparse import TestSparse # noqa: F401 76from jit.test_tensor_methods import TestTensorMethods # noqa: F401 77from jit.test_dataclasses import TestDataclasses # noqa: F401 78from jit.test_generator import TestGenerator # noqa: F401 79 80# Torch 81from torch import Tensor 82from torch._C import TensorType, BoolType, parse_ir, _propagate_shapes 83from torch.autograd import Variable 84from torch.jit.annotations import BroadcastingList2, BroadcastingList3, Any # noqa: F401 85from torch.nn.utils.rnn import PackedSequence 86from torch.testing import FileCheck, make_tensor 87import torch.autograd.profiler 88import torch.cuda 89import torch.jit 90import torch.jit._logging 91import torch.jit.frontend 92import torch.nn as nn 93import torch.nn.functional as F 94 95# Testing utils 96from torch.testing._internal import jit_utils 97from torch.testing._internal.common_jit import check_against_reference 98from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \ 99 suppress_warnings, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, TestCase, \ 100 freeze_rng_state, slowTest, TemporaryFileName, \ 101 
enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \ 102 skipIfCrossRef, skipIfTorchDynamo 103from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, disable_autodiff_subgraph_inlining, \ 104 _trace, do_input_map, get_execution_plan, make_global, \ 105 execWrapper, _inline_everything, _tmp_donotuse_dont_inline_everything, \ 106 RUN_CUDA 107from torch.testing._internal.jit_metaprogramming_utils import ( 108 get_script_args, 109 create_input, unpack_variables, 110 additional_module_tests, EXCLUDE_SCRIPT_MODULES, 111 get_nn_module_name_from_kwargs, get_nn_mod_test_name, script_method_template) 112 113from torch.testing._internal.common_nn import module_tests, new_module_tests, criterion_tests 114 115# For testing truediv in python 2 116from torch.testing._internal.test_module.future_div import div_int_future, div_float_future 117from torch.testing._internal.test_module.no_future_div import div_int_nofuture, div_float_nofuture 118 119# Standard library 120from collections import defaultdict, namedtuple, OrderedDict 121from copy import deepcopy 122from itertools import product 123from textwrap import dedent 124from typing import List, Dict, NamedTuple, Optional, Tuple, Union 125import copy 126import functools 127import inspect 128import io 129import itertools 130import math 131import numpy as np 132import os 133import pickle 134import pickletools 135import random 136import re 137import shutil 138import string 139import sys 140import tempfile 141import types 142import typing 143import unittest 144import warnings 145import zipfile 146import tracemalloc 147 148 149def canonical(graph): 150 return torch._C._jit_pass_canonicalize(graph).str(False) 151 152def LSTMCellF(input, hx, cx, *params): 153 return LSTMCell(input, (hx, cx), *params) 154 155def doAutodiffCheck(testname): 156 # TODO: setting false on test itself is not working 157 if "test_t_" in testname or testname == "test_t": 158 return False 159 160 if 
GRAPH_EXECUTOR == ProfilingMode.SIMPLE: 161 return False 162 163 if GRAPH_EXECUTOR == ProfilingMode.LEGACY: 164 return True 165 166 167 # these tests are disabled because BailOut nodes 168 # inserted by ProfilingExecutor interfere with 169 # subgraph slicing of Differentiable Graphs 170 test_exceptions = ( 171 # functional 172 'test_nn_dropout', 173 'test_nn_log_softmax', 174 'test_nn_relu', 175 'test_nn_softmax', 176 'test_nn_threshold', 177 'test_nn_lp_pool2d', 178 'test_nn_lp_pool1d', 179 'test_nn_gumbel_softmax_hard', 180 'test_nn_gumbel_softmax', 181 'test_nn_multilabel_soft_margin_loss', 182 'test_nn_batch_norm', 183 'test_nn_max_pool2d_with_indices', 184 # AutogradJitGenerated 185 'test___rdiv___constant', 186 'test___rdiv___scalar_constant', 187 'test_split', 188 'test_split_dim', 189 'test_split_dim_neg0', 190 'test_split_size_list', 191 'test_split_size_list_dim', 192 'test_split_size_list_dim_neg0', 193 'test_split_with_sizes', 194 'test_split_with_sizes_dim', 195 'test_split_with_sizes_dim_neg0', 196 'test_split_with_sizes_size_0', 197 'test_nn_max_pool2d_with_indices', 198 ) 199 200 return testname not in test_exceptions 201 202 203# TODO: enable TE in PE when all tests are fixed 204torch._C._jit_set_texpr_fuser_enabled(GRAPH_EXECUTOR == ProfilingMode.PROFILING) 205torch._C._jit_set_profiling_executor(GRAPH_EXECUTOR != ProfilingMode.LEGACY) 206 207def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None): 208 hx, cx = hidden 209 gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh) 210 211 ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) 212 ingate = torch.sigmoid(ingate) 213 forgetgate = torch.sigmoid(forgetgate) 214 cellgate = torch.tanh(cellgate) 215 outgate = torch.sigmoid(outgate) 216 217 cy = (forgetgate * cx) + (ingate * cellgate) 218 hy = outgate * torch.tanh(cy) 219 return hy, cy 220 221 222def LSTMCellC(*args, **kwargs): 223 hy, cy = LSTMCellF(*args, **kwargs) 224 return torch.cat((hy, cy)) 225 226 227def LSTMCellS(x, 
hx, cx, w_ih, w_hh, b_ih, b_hh): 228 gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh 229 ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) 230 ingate = torch.sigmoid(ingate) 231 forgetgate = torch.sigmoid(forgetgate) 232 cellgate = torch.tanh(cellgate) 233 outgate = torch.sigmoid(outgate) 234 cy = (forgetgate * cx) + (ingate * cellgate) 235 hy = outgate * torch.tanh(cy) 236 return hy, cy 237 238 239# Code reference: https://github.com/pytorch/translate/blob/master/pytorch_translate/rnn_cell.py#L27:44 240def MiLSTMCell(x, hx, cx, w_ih, w_hh, alpha, beta_i, beta_h, bias): 241 Wx = x.mm(w_ih.t()) 242 Uz = hx.mm(w_hh.t()) 243 # Section 2.1 in https://arxiv.org/pdf/1606.06630.pdf 244 gates = alpha * Wx * Uz + beta_i * Wx + beta_h * Uz + bias 245 # Same as LSTMCell after this point 246 ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) 247 ingate = ingate.sigmoid() 248 forgetgate = forgetgate.sigmoid() 249 cellgate = cellgate.tanh() 250 outgate = outgate.sigmoid() 251 cy = (forgetgate * cx) + (ingate * cellgate) 252 hy = outgate * cy.tanh() 253 return hy, cy 254 255 256 257def get_lstm_inputs(device, training=False, seq_length=None): 258 input_shape = (3, 10) if seq_length is None else (seq_length, 3, 10) 259 input = torch.randn(*input_shape, dtype=torch.float, device=device, requires_grad=training) 260 hx = torch.randn(3, 20, dtype=torch.float, device=device, requires_grad=training) 261 cx = torch.randn(3, 20, dtype=torch.float, device=device, requires_grad=training) 262 module = nn.LSTMCell(10, 20).to(device, torch.float) # Just to allocate weights with correct sizes 263 if training: 264 params = tuple(module.parameters()) 265 else: 266 params = tuple(p.requires_grad_(False) for p in module.parameters()) 267 return (input, hx, cx) + params 268 269 270def get_milstm_inputs(device, training=False): 271 minibatch = 3 272 input_size = 10 273 hidden_size = 20 274 x = torch.randn(minibatch, input_size, device=device, dtype=torch.float) 275 hx = 
torch.randn(minibatch, hidden_size, device=device, dtype=torch.float) 276 cx = torch.randn(minibatch, hidden_size, device=device, dtype=torch.float) 277 278 ih = torch.randn(4 * hidden_size, input_size, device=device, dtype=torch.float, requires_grad=training) 279 hh = torch.randn(4 * hidden_size, hidden_size, device=device, dtype=torch.float, requires_grad=training) 280 alpha = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training) 281 ibeta = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training) 282 hbeta = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training) 283 bias = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training) 284 return x, hx, cx, ih, hh, alpha, ibeta, hbeta, bias 285 286 287def get_fn(file_name, script_path): 288 import importlib.util 289 spec = importlib.util.spec_from_file_location(file_name, script_path) 290 module = importlib.util.module_from_spec(spec) 291 spec.loader.exec_module(module) 292 fn = module.fn 293 return fn 294 295def get_grad_executor(plan_state, diff_graph_idx=None, skip_check=False): 296 if diff_graph_idx is None: 297 nodes = list(plan_state.graph.nodes()) 298 299 if not skip_check: 300 nodes = list(filter(lambda n : n.kind() != "prim::BailOut" and n.kind() != "prim::BailoutTemplate", nodes)) 301 if len(nodes) == 1 or (len(nodes) == 2 and nodes[1].kind() == "prim::TupleConstruct"): 302 pass 303 elif len(nodes) == 2 and nodes[0].kind() == "prim::RequiresGradCheck" and nodes[1].kind() == "prim::If": 304 pass 305 else: 306 raise RuntimeError("Can't get a grad_executor for a non-differentiable graph") 307 grad_executors = list(plan_state.code.grad_executor_states()) 308 return grad_executors[diff_graph_idx or 0] 309 310 311def all_backward_graphs(script_module, diff_graph_idx=None): 312 # Note: for Python 2 the order seems to be unstable 313 ge_state = script_module.get_debug_state() 314 fwd_plan = 
get_execution_plan(ge_state) 315 grad_executor_state = get_grad_executor(fwd_plan, diff_graph_idx=diff_graph_idx) 316 bwd_plans = list(grad_executor_state.execution_plans.values()) 317 return [p.graph.copy() for p in bwd_plans] 318 319 320def backward_graph(script_module, diff_graph_idx=None, skip_check=False): 321 ge_state = script_module.get_debug_state() 322 fwd_plan = get_execution_plan(ge_state) 323 grad_executor_state = get_grad_executor(fwd_plan, diff_graph_idx=diff_graph_idx, skip_check=skip_check) 324 bwd_plan = get_execution_plan(grad_executor_state) 325 # Running JIT passes requires that we own the graph (with a shared_ptr). 326 # The debug state struct does not own its graph so we make a copy of it. 327 return bwd_plan.graph.copy() 328 329 330# helper function to get sum of List[Tensor] 331def _sum_of_list(tensorlist): 332 s = 0 333 for t in tensorlist: 334 s += t.sum() 335 return s 336 337 338# has to be at top level or Pickle complains 339class FooToPickle(torch.nn.Module): 340 def __init__(self) -> None: 341 super().__init__() 342 self.bar = torch.jit.ScriptModule() 343 344 345class TestJitProfiler(JitTestCase): 346 """ 347 This runs tests that requires setting some global states like torch._C._set_graph_executor_optimize 348 and restore the values afterward, i.e. test_profiler. This is to address the flaky issue in 349 https://github.com/pytorch/pytorch/issues/91483 in which test_profiler was flaky and failed in the 350 middle without the chance to restore torch._C._set_graph_executor_optimize to its original value. 351 This causes issues for all future tests running after. 352 353 Using a separate test class here, so that there is no need to run setup and teardown for all tests 354 in TestJit. 
355 """ 356 357 def setUp(self): 358 super().setUp() 359 self.graph_executor_optimize_opt = torch._C._get_graph_executor_optimize() 360 361 def tearDown(self): 362 super().tearDown() 363 # Resetting 364 torch._C._set_graph_executor_optimize( 365 self.graph_executor_optimize_opt 366 ) 367 368 def test_profiler(self): 369 torch._C._set_graph_executor_optimize(False) 370 371 def other_fn(x): 372 return x * 2 373 374 x = torch.rand(3, 4) 375 traced_other_fn = torch.jit.trace(other_fn, x) 376 377 def fn(x): 378 y = traced_other_fn(x) 379 fut = torch.jit._fork(traced_other_fn, x) 380 y = torch.jit._wait(fut) 381 return y 382 383 traced_fn = torch.jit.trace(fn, x) 384 with torch.autograd.profiler.profile() as prof: 385 traced_fn(x) 386 387 # expecting to see other_fn TS function call 388 # with cpu time >= mul cpu time and 389 # a forked other_fn 390 391 mul_events = defaultdict(int) 392 other_fn_events = defaultdict(int) 393 for e in prof.function_events: 394 if e.name == "aten::mul": 395 self.assertTrue(e.thread not in mul_events) 396 mul_events[e.thread] = e.time_range.elapsed_us() 397 elif e.name == "other_fn": 398 self.assertTrue(e.thread not in other_fn_events) 399 other_fn_events[e.thread] = e.time_range.elapsed_us() 400 401 self.assertTrue(len(mul_events) == 2) 402 self.assertTrue(len(other_fn_events) == 2) 403 404 for thread, mul_time in mul_events.items(): 405 self.assertTrue(thread in other_fn_events) 406 self.assertTrue(other_fn_events[thread] >= mul_time) 407 408 409class TestJit(JitTestCase): 410 @unittest.skip("Requires a lot of RAM") 411 def test_big(self): 412 m = torch.jit.ScriptModule() 413 gig = int(1024 * 1024 * 1024 / 4) 414 # a small tensor in the first 4GB 415 m.v0 = nn.Parameter(torch.full((2,), 1, dtype=torch.float)) 416 # a large tensor in the first 4GB that ends outside of it 417 m.v1 = nn.Parameter(torch.full((5, gig), 2, dtype=torch.float)) 418 # a small tensor in >4GB space 419 m.v2 = nn.Parameter(torch.full((2,), 3, dtype=torch.float)) 420 
# s large tensor in the > 4GB space 421 m.v3 = nn.Parameter(torch.full((5, gig), 4, dtype=torch.float)) 422 423 m2 = self.getExportImportCopy(m) 424 425 self.assertEqual(tuple(m.parameters()), tuple(m2.parameters())) 426 427 def test_inferred_as_tensor(self): 428 with self.assertRaisesRegex(RuntimeError, "Inferred the value for argument 'dim' to be of type 'Tensor' " 429 "because it was not annotated with an explicit type"): 430 @torch.jit.script 431 def dot(points, query, dim): 432 return (points * query).sum(dim) 433 434 def test_constants_pkl(self): 435 # This test asserts that the serialization archive includes a `constants.pkl` 436 # file. This file is used by `torch.load` to determine whether a zip file 437 # is a normal eager-mode serialization zip or a jit serialization zip. If 438 # you are deleting `constants.pkl`, make sure to update `torch.serialization.load` 439 # so it is still able to figure out which is which. 440 @torch.jit.script 441 def fn(x): 442 return x 443 444 buf = io.BytesIO() 445 torch.jit.save(fn, buf) 446 buf.seek(0) 447 448 files = zipfile.ZipFile(buf).filelist 449 self.assertTrue(any('archive/constants.pkl' == f.filename for f in files)) 450 451 def test_script_fn_pkl(self): 452 with self.assertRaisesRegex(pickle.PickleError, "ScriptFunction cannot be pickled"): 453 454 @torch.jit.script 455 def fn(x: torch.Tensor) -> torch.Tensor: 456 return x 457 458 pkl_fn = pickle.dumps(fn, protocol=0) 459 460 def test_restore_device(self): 461 class M(torch.jit.ScriptModule): 462 def __init__(self, cpu_device_str): 463 super().__init__() 464 self.p0 = nn.Parameter(torch.tensor([0.3], dtype=torch.float, 465 device=cpu_device_str)) 466 self.b0 = torch.tensor([0.9], dtype=torch.float, 467 device=cpu_device_str) 468 469 # main purpose is checking map_location works 470 m = M("cpu") 471 m2 = self.getExportImportCopy(m) 472 self.assertEqual(tuple(m.parameters()), tuple(m2.parameters())) 473 self.assertEqual(tuple(m.buffers()), tuple(m2.buffers())) 474 
self.assertFalse(m2.p0.is_cuda) 475 self.assertFalse(m2.b0.is_cuda) 476 477 @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA") 478 def test_restore_device_cuda(self): 479 class MyModule(torch.jit.ScriptModule): 480 def __init__(self) -> None: 481 super().__init__() 482 self.b0 = nn.Buffer(torch.randn(1, 3)) 483 self.p0 = nn.Parameter(torch.randn(2, 3)) 484 485 @torch.jit.script_method 486 def forward(self, x): 487 return x + self.b0 + self.p0 488 489 m = MyModule() 490 m.cuda(torch.cuda.device_count() - 1) 491 cuda_device_str = 'cuda:' + str(torch.cuda.device_count() - 1) 492 493 self.assertTrue(m.p0.is_cuda) 494 self.assertTrue(m.b0.is_cuda) 495 496 # restore to the saved devices 497 m2 = self.getExportImportCopy(m) 498 self.assertEqual(tuple(m.parameters()), tuple(m2.parameters())) 499 self.assertEqual(tuple(m.buffers()), tuple(m2.buffers())) 500 self.assertEqual(str(m2.p0.device), cuda_device_str) 501 self.assertEqual(str(m2.b0.device), cuda_device_str) 502 503 # restore all to cpu using string 504 cpu_device_str = 'cpu' 505 m3 = self.getExportImportCopy(m, map_location=cpu_device_str) 506 self.assertEqual(str(m3.p0.device), cpu_device_str) 507 self.assertEqual(str(m3.b0.device), cpu_device_str) 508 509 # restore all to first gpu using device 510 m4 = self.getExportImportCopy( 511 m3, map_location=torch.device('cuda:0')) 512 self.assertEqual(str(m4.p0.device), 'cuda:0') 513 self.assertEqual(str(m4.b0.device), 'cuda:0') 514 515 # compute and compare the results 516 input = torch.rand(2, 3).cuda(torch.cuda.device_count() - 1) 517 origin_result = m(input) 518 self.assertEqual(origin_result, m2(input)) 519 self.assertEqual(origin_result, m3(input.cpu())) 520 self.assertEqual(origin_result, m4(input.cuda(0))) 521 522 def test_trace_retains_train(self): 523 class M(torch.nn.Module): 524 def forward(self, x): 525 return x 526 m = M() 527 m.eval() 528 tm = torch.jit.trace(m, (torch.rand(3))) 529 self.assertEqual(tm.training, m.training) 530 531 
@unittest.skipIf(not RUN_CUDA, "restore device requires CUDA") 532 def test_restore_shared_storage_on_cuda(self): 533 class Foo(torch.jit.ScriptModule): 534 def __init__(self) -> None: 535 super().__init__() 536 whole_tensor = torch.randn(4, 5, dtype=torch.float, device='cpu') 537 self.p0 = nn.Parameter(whole_tensor.narrow(0, 0, 1)) 538 self.b0 = nn.Buffer(whole_tensor.narrow(0, 3, 1)) 539 540 m = Foo() 541 m2 = self.getExportImportCopy(m, map_location=torch.device('cuda:0')) 542 self.assertEqual(tuple(m.parameters()), tuple(m2.parameters())) 543 self.assertEqual(tuple(m.buffers()), tuple(m2.buffers())) 544 self.assertTrue(m2.p0.is_cuda) 545 self.assertTrue(m2.b0.is_cuda) 546 self.assertTrue(m2.p0.is_shared()) 547 self.assertTrue(m2.b0.is_shared()) 548 self.assertEqual(m2.b0.storage().data_ptr(), m2.p0.storage().data_ptr()) 549 550 def test_add_relu_fusion(self): 551 class M(torch.nn.Module): 552 def __init__(self, relu_op): 553 super().__init__() 554 self.relu_op = relu_op 555 556 def forward(self, a, b, c): 557 tmp = torch.add(a, b) 558 x = self.relu_op(tmp) 559 d = torch.add(a, c) 560 return x + d 561 a = torch.rand((7, 11)) 562 a = a * -10 563 a = a + 5 564 b = torch.rand((7, 11)) 565 c = torch.rand((7, 11)) 566 m = torch.jit.script(M(torch.relu)) 567 orig_res = m(a, b, c) 568 torch._C._jit_pass_fuse_add_relu(m.graph) 569 buffer = io.BytesIO() 570 torch.jit.save(m, buffer) 571 buffer.seek(0) 572 m = torch.jit.load(buffer) 573 new_res = m(a, b, c) 574 FileCheck().check_not("aten::relu(") \ 575 .check("aten::_add_relu(") \ 576 .run(m.graph) 577 torch.testing.assert_close(orig_res, new_res) 578 579 # add, relu_ 580 a = torch.rand((7, 11)) 581 a = a * -10 582 a = a + 5 583 b = torch.rand((7, 11)) 584 c = torch.rand((7, 11)) 585 m = torch.jit.script(M(torch.relu_)) 586 orig_res = m(a, b, c) 587 torch._C._jit_pass_fuse_add_relu(m.graph) 588 buffer = io.BytesIO() 589 torch.jit.save(m, buffer) 590 buffer.seek(0) 591 m = torch.jit.load(buffer) 592 new_res = m(a, b, c) 
593 FileCheck().check_not("aten::relu_(") \ 594 .check("aten::_add_relu(") \ 595 .run(m.graph) 596 torch.testing.assert_close(orig_res, new_res) 597 598 class Madd_(torch.nn.Module): 599 def __init__(self, relu_op): 600 super().__init__() 601 self.relu_op = relu_op 602 603 def forward(self, a, b): 604 x = a.add_(b) 605 x = self.relu_op(x) 606 return x 607 608 # add_, relu_ 609 a = torch.rand((7, 11)) 610 a = a * -10 611 a = a + 5 612 b = torch.rand((7, 11)) 613 # Because in place add_ will overwrite a 614 a_copy = a.clone() 615 m = torch.jit.script(Madd_(torch.relu_)) 616 orig_res = m(a, b) 617 torch._C._jit_pass_fuse_add_relu(m.graph) 618 buffer = io.BytesIO() 619 torch.jit.save(m, buffer) 620 buffer.seek(0) 621 m = torch.jit.load(buffer) 622 new_res = m(a_copy, b) 623 FileCheck().check_not("aten::add_(") \ 624 .check_not("aten::relu_(") \ 625 .check("aten::_add_relu_(") \ 626 .run(m.graph) 627 torch.testing.assert_close(orig_res, new_res) 628 # Since _add_relu_ does inplace mutation ensure 629 # a_copy is modified 630 torch.testing.assert_close(orig_res, a_copy) 631 632 class Madd_out(torch.nn.Module): 633 def __init__(self, relu_op): 634 super().__init__() 635 self.relu_op = relu_op 636 637 def forward(self, a, b): 638 x = torch.add(a, b, out=a) 639 x = self.relu_op(x) 640 return x 641 a = torch.rand((7, 11)) 642 a = a * -10 643 a = a + 5 644 b = torch.rand((7, 11)) 645 646 # add_out, relu_ 647 a = torch.rand((7, 11)) 648 a = a * -10 649 a = a + 5 650 b = torch.rand((7, 11)) 651 # Because in place add_ will overwrite a 652 a_copy = a.clone() 653 m = torch.jit.script(Madd_out(torch.relu_)) 654 orig_res = m(a, b) 655 torch._C._jit_pass_fuse_add_relu(m.graph) 656 buffer = io.BytesIO() 657 torch.jit.save(m, buffer) 658 buffer.seek(0) 659 m = torch.jit.load(buffer) 660 new_res = m(a_copy, b) 661 FileCheck().check_not("aten::add(") \ 662 .check_not("aten::relu_(") \ 663 .check("aten::_add_relu(") \ 664 .run(m.graph) 665 torch.testing.assert_close(orig_res, new_res) 
666 # Since _add_relu_ with out=a does inplace mutation ensure 667 # a_copy is modified 668 torch.testing.assert_close(orig_res, a_copy) 669 670 def test_repeat_interleave_script(self): 671 def fn(input: torch.Tensor, repeats: torch.Tensor) -> torch.Tensor: 672 output = input.repeat_interleave(repeats) 673 return output 674 fn_scripted = torch.jit.script(fn) 675 676 input = torch.tensor([5, 7], dtype=torch.int64) 677 repeats = torch.tensor([3, 6], dtype=torch.int64) 678 679 output = fn(input, repeats) 680 output_scripted = fn_scripted(input, repeats) 681 self.assertEqual(output_scripted, output) 682 683 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Simple executor doesn't have shape information") 684 def test_peephole_optimize_shape_ops(self): 685 def test_input(func, input, result): 686 # if result == 2 we will trigger a bailout and 687 # the unprofiled graph should return the correct result 688 self.assertEqual(func(input, profile_and_replay=True), result) 689 gre = func.graph_for(input) 690 FileCheck().check_not("prim::If").run(gre) 691 692 def test_dim(): 693 @torch.jit.script 694 def func(x): 695 if x.dim() == 1: 696 return 1 697 else: 698 return 2 699 700 test_input(func, torch.tensor([0.5]), 1) 701 test_input(func, torch.tensor([[0.5]]), 2) 702 test_dim() 703 704 def test_size_index(): 705 @torch.jit.script 706 def func(x): 707 if x.size(0) == 1: 708 return 1 709 else: 710 return 2 711 712 test_input(func, torch.rand([1, 2]), 1) 713 test_input(func, torch.rand([1, 3]), 1) 714 715 @torch.jit.script 716 def neg_index(x): 717 if x.size(-2) == 1: 718 return 1 719 else: 720 return 2 721 722 test_input(neg_index, torch.rand([1, 2]), 1) 723 test_input(neg_index, torch.rand([1, 3]), 1) 724 725 if GRAPH_EXECUTOR == ProfilingMode.PROFILING: 726 test_size_index() 727 728 def test_dtype(): 729 @torch.jit.script 730 def func(x): 731 if x.dtype == torch.float32: 732 return 1 733 else: 734 return 2 735 736 test_input(func, torch.tensor(0.5, 
dtype=torch.float32), 1) 737 test_input(func, torch.tensor(0.5, dtype=torch.int64), 2) 738 test_dtype() 739 740 def test_is_floating_poiint(): 741 @torch.jit.script 742 def func(x): 743 if x.is_floating_point(): 744 return 1 745 else: 746 return 2 747 748 test_input(func, torch.tensor(0.5, dtype=torch.float32), 1) 749 test_input(func, torch.tensor(0.5, dtype=torch.int64), 2) 750 test_is_floating_poiint() 751 752 def test_device(): 753 @torch.jit.script 754 def func_1(x): 755 if x.device == torch.device('cuda:0'): 756 a = 0 757 else: 758 a = 1 759 return a 760 761 @torch.jit.script 762 def func_2(x): 763 if x.is_cuda: 764 a = 0 765 else: 766 a = 1 767 return a 768 769 test_input(func_1, torch.tensor(0.5), 1) 770 test_input(func_2, torch.tensor(0.5), 1) 771 772 if RUN_CUDA: 773 test_input(func_1, torch.tensor(0.5, device="cuda:0"), 0) 774 test_input(func_2, torch.tensor(0.5, device="cuda:0"), 0) 775 776 test_device() 777 778 def test_attrs(self): 779 def foo(x): 780 return ( 781 # x.dtype, TODO: dtype long -> instance conversion 782 x.device, 783 x.shape, 784 x.is_cuda, 785 x.is_mkldnn, 786 x.is_quantized, 787 x.requires_grad, 788 x.T, 789 x.mT, 790 x.H, 791 x.mH 792 # x.layout TODO: layout long -> instance conversion 793 ) 794 795 scripted = torch.jit.script(foo) 796 x = torch.rand(3, 4) 797 self.assertEqual(scripted(x), foo(x)) 798 799 def test_layout(self): 800 @torch.jit.script 801 def check(x, y): 802 return x.layout == y.layout 803 804 x = torch.rand(3, 4) 805 y = torch.rand(3, 4) 806 807 self.assertTrue(check(x, y)) 808 809 def test_matrix_transpose(self): 810 @torch.jit.script 811 def check(x): 812 return torch.equal(x.mT, x.transpose(-2, -1)) 813 814 x = torch.rand(3, 4) 815 self.assertTrue(check(x)) 816 817 def test_transpose(self): 818 @torch.jit.script 819 def check(x): 820 return torch.equal(x.T, x.t()) 821 822 x = torch.rand(3, 4) 823 self.assertTrue(check(x)) 824 825 def test_matrix_conj_transpose(self): 826 @torch.jit.script 827 def check(x): 828 
return torch.equal(x.mH, x.transpose(-2, -1).conj()) 829 830 x = torch.rand(3, 4) 831 self.assertTrue(check(x)) 832 833 x = make_tensor((3, 4), device="cpu", dtype=torch.complex64) 834 self.assertTrue(check(x)) 835 836 def test_conj_transpose(self): 837 @torch.jit.script 838 def check(x): 839 return torch.equal(x.H, x.t().conj()) 840 841 x = torch.rand(3, 4) 842 self.assertTrue(check(x)) 843 844 x = make_tensor((3, 4), device="cpu", dtype=torch.complex64) 845 self.assertTrue(check(x)) 846 847 def test_T_mT_H_mH(self): 848 def T(x): 849 return x.mT 850 851 def mT(x): 852 return x.mT 853 854 def H(x): 855 return x.H 856 857 def mH(x): 858 return x.mH 859 860 x = torch.rand(3, 4) 861 y = make_tensor((3, 4), device="cpu", dtype=torch.complex64) 862 863 self.checkScript(T, (x, )) 864 self.checkScript(mT, (x, )) 865 self.checkScript(H, (x, )) 866 self.checkScript(mH, (x, )) 867 self.checkScript(T, (y, )) 868 self.checkScript(mT, (y, )) 869 self.checkScript(H, (y, )) 870 self.checkScript(mH, (y, )) 871 872 def test_nn_conv(self): 873 class Mod(nn.Module): 874 def __init__(self, conv): 875 super().__init__() 876 self.conv = conv 877 878 def forward(self, input): 879 return self.conv(input) 880 881 inputs = [ 882 # Conv 883 (Mod(nn.Conv1d(16, 33, 3, stride=2)), torch.randn(20, 16, 5)), 884 (Mod(nn.Conv2d(16, 33, 3, stride=2)), torch.randn(20, 16, 5, 10)), 885 (Mod(nn.Conv3d(16, 33, 3, stride=2)), torch.randn(20, 16, 3, 5, 4)), 886 # ConvTransposed 887 (Mod(nn.ConvTranspose1d(16, 33, 3, stride=2)), torch.randn(20, 16, 5)), 888 (Mod(nn.ConvTranspose2d(16, 33, 3, stride=2)), torch.randn(20, 16, 5, 10)), 889 (Mod(nn.ConvTranspose3d(16, 33, 3, stride=2)), torch.randn(20, 16, 3, 5, 4)), 890 ] 891 892 for m, inp in inputs: 893 self.checkModule(m, (inp,)) 894 895 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, 'Not implemented for Simple or Legacy') 896 def test_debug_flush_compilation_cache(self): 897 def foo(x): 898 return x + 2 899 900 class Mod(nn.Module): 901 def 
forward(self, t): 902 return t + 2 903 904 m = torch.jit.script(Mod()) 905 x = torch.rand(1, 10) 906 907 with enable_profiling_mode_for_profiling_tests(): 908 jitted = self.checkScript(foo, (x,)) 909 # shouldn't throw 910 states = jitted.get_debug_state() 911 912 # after flushing there shouldn't be 913 # no opt plan 914 jitted._debug_flush_compilation_cache() 915 with self.assertRaisesRegex(RuntimeError, "INTERNAL ASSERT FAILED"): 916 states = jitted.get_debug_state() 917 918 NUM_RUNS = 1 919 with num_profiled_runs(NUM_RUNS): 920 m(x) 921 m(x) 922 fwd = m._c._get_method("forward") 923 states = m.get_debug_state() 924 925 # after flushing there shouldn't be 926 # no opt plan 927 fwd._debug_flush_compilation_cache() 928 with self.assertRaisesRegex(RuntimeError, "INTERNAL ASSERT FAILED"): 929 states = m.get_debug_state() 930 931 def test_numel(self): 932 @torch.jit.script 933 def get_numel_script(x): 934 return x.numel() 935 936 x = torch.rand(3, 4) 937 numel = get_numel_script(x) 938 self.assertEqual(numel, x.numel()) 939 940 def test_element_size(self): 941 @torch.jit.script 942 def get_element_size_script(x): 943 return x.element_size() 944 945 x = torch.rand(3, 4) 946 element_size = get_element_size_script(x) 947 self.assertEqual(element_size, x.element_size()) 948 949 def test_Sequential(self): 950 class Seq(nn.Module): 951 def __init__(self) -> None: 952 super().__init__() 953 self.seq = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 30)) 954 955 @torch.jit.script_method 956 def forward(self, x): 957 for l in self.seq: 958 x = l(x) 959 return x 960 961 m = torch.jit.script(Seq()) 962 assert m.graph # ensure jit was able to compile 963 964 def test_ModuleList(self): 965 class Mod(nn.Module): 966 def __init__(self) -> None: 967 super().__init__() 968 self.model = nn.ModuleList([nn.Linear(10, 10) for _ in range(10)]) 969 self.model += (nn.Linear(10, 20),) 970 self.model.append(nn.Linear(20, 30)) 971 self.model.extend([nn.Linear(30, 40), nn.Linear(40, 50)]) 972 973 
def forward(self, v): 974 for m in self.model: 975 v = m(v) 976 return v 977 978 m = torch.jit.script(Mod()) 979 assert m.graph # ensure jit was able to compile 980 981 def test_disabled(self): 982 torch.jit._state.disable() 983 try: 984 def f(x, y): 985 return x + y 986 987 self.assertIs(torch.jit.trace(f, (torch.randn(2, 2), torch.randn(2, 2))), f) 988 self.assertIs(torch.jit.script(f), f) 989 990 class MyModule(torch.jit.ScriptModule): 991 @torch.jit.script_method 992 def method(self, x): 993 return x 994 995 # XXX: Unfortunately ScriptModule won't simply become Module now, 996 # because that requires disabling the JIT at startup time, which 997 # we can't do in here. 998 # We need to or those two conditions to make it work with all versions of Python 999 self.assertTrue(inspect.ismethod(MyModule.method) or inspect.isfunction(MyModule.method)) 1000 finally: 1001 torch.jit._state.enable() 1002 1003 def test_train_eval(self): 1004 class Sub(nn.Module): 1005 def forward(self, input): 1006 if self.training: 1007 return input 1008 else: 1009 return -input 1010 1011 class MyModule(torch.jit.ScriptModule): 1012 def __init__(self, module): 1013 super().__init__() 1014 self.module = module 1015 1016 @torch.jit.script_method 1017 def forward(self, input): 1018 return self.module(input) + 1 1019 1020 m = MyModule(Sub()) 1021 input = torch.rand(3, 4) 1022 self.assertEqual(input + 1, m(input)) 1023 m.eval() 1024 self.assertEqual(-input + 1, m(input)) 1025 1026 # test batchnorm and dropout train/eval 1027 input = torch.randn(6, 10) 1028 batchnorm = nn.BatchNorm1d(10) 1029 dropout = nn.Dropout(p=0.2) 1030 1031 m_batchnorm = MyModule(batchnorm) 1032 self.assertEqual(batchnorm(input) + 1, m_batchnorm(input)) 1033 batchnorm.eval() 1034 m_batchnorm.eval() 1035 self.assertEqual(batchnorm(input) + 1, m_batchnorm(input)) 1036 1037 m_dropout = MyModule(dropout) 1038 dropout.eval() 1039 m_dropout.eval() 1040 self.assertEqual(dropout(input) + 1, m_dropout(input)) 1041 1042 def 
test_nn_lp_pool2d(self): 1043 class Mod(torch.nn.Module): 1044 def __init__(self) -> None: 1045 super().__init__() 1046 self.l = torch.nn.LPPool2d(2, 3) 1047 self.n = torch.nn.LPPool2d(2, (7, 1)) 1048 1049 def forward(self, x): 1050 return (self.l(x), 1051 self.n(x), 1052 torch.nn.functional.lp_pool2d(x, float(2), 3), 1053 torch.nn.functional.lp_pool2d(x, 2, 3), 1054 torch.nn.functional.lp_pool2d(x, float(2), (7, 1))) 1055 1056 self.checkModule(Mod(), (torch.rand(1, 3, 7, 7),)) 1057 1058 def test_nn_lp_pool1d(self): 1059 class Mod(torch.nn.Module): 1060 def __init__(self) -> None: 1061 super().__init__() 1062 self.l = torch.nn.LPPool1d(2, 3) 1063 self.n = torch.nn.LPPool1d(2, 7) 1064 1065 def forward(self, x): 1066 return (self.l(x), 1067 self.n(x), 1068 torch.nn.functional.lp_pool1d(x, float(2), 3), 1069 torch.nn.functional.lp_pool1d(x, 2, 3), 1070 torch.nn.functional.lp_pool1d(x, float(2), 7)) 1071 1072 self.checkModule(Mod(), (torch.rand(1, 3, 7),)) 1073 1074 def test_nn_padding_functional(self): 1075 class Mod(nn.Module): 1076 def __init__(self, *pad): 1077 super().__init__() 1078 self.pad = pad 1079 1080 def forward(self, x): 1081 return F.pad(x, self.pad, mode='constant', value=3.5) 1082 1083 inputs = [ 1084 (Mod(1, 2), torch.randn(1, 3, 4)), # 1D 1085 (Mod(1, 2, 3, 4), torch.randn(1, 3, 4)), # 2D 1086 (Mod(1, 2, 3, 4, 5, 6), torch.randn(1, 3, 4)), # 3D 1087 ] 1088 1089 for m, inp in inputs: 1090 self.checkModule(m, (inp,)) 1091 1092 def test_nn_padding(self): 1093 class Mod(nn.Module): 1094 def __init__(self, padding): 1095 super().__init__() 1096 self.padding = padding 1097 1098 def forward(self, input): 1099 return self.padding(input) 1100 1101 inputs = [ 1102 (Mod(nn.ConstantPad1d(2, 3.5)), torch.randn(1, 2, 4)), 1103 (Mod(nn.ConstantPad2d(2, 3.5)), torch.randn(1, 2, 2)), 1104 (Mod(nn.ConstantPad3d(3, 3.5)), torch.randn(16, 3, 10, 20, 30)), 1105 (Mod(nn.ReflectionPad1d(2)), torch.arange(8, dtype=torch.float).reshape(1, 2, 4)), 1106 
(Mod(nn.ReflectionPad2d(2)), torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)), 1107 (Mod(nn.ReflectionPad3d(3)), torch.randn(16, 3, 8, 32, 48)), 1108 (Mod(nn.ReplicationPad1d(2)), torch.arange(8, dtype=torch.float).reshape(1, 2, 4)), 1109 (Mod(nn.ReplicationPad2d(2)), torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)), 1110 (Mod(nn.ReplicationPad3d(3)), torch.randn(16, 3, 8, 32, 48)), 1111 (Mod(nn.ZeroPad2d(2)), torch.randn(1, 1, 3, 3)) 1112 ] 1113 1114 for m, inp in inputs: 1115 self.checkModule(m, (inp,)) 1116 1117 def test_script_autograd_grad(self): 1118 def test_simple_grad(x, y): 1119 # type: (Tensor, Tensor) -> List[Optional[Tensor]] 1120 z = x + 2 * y + x * y 1121 return torch.autograd.grad((z.sum(), ), (x, y)) 1122 1123 def test_simple_grad_with_grad_outputs(x, y): 1124 # type: (Tensor, Tensor) -> List[Optional[Tensor]] 1125 z = x + 2 * y + x * y 1126 grad_outputs = torch.jit.annotate(List[Optional[torch.Tensor]], [torch.ones((2, 2)), ]) 1127 return torch.autograd.grad((z, ), (x, y), grad_outputs) 1128 1129 def test_one_output_not_requires_grad(x, y): 1130 # type: (Tensor, Tensor) -> List[Optional[Tensor]] 1131 z = 2 * y + y 1132 return torch.autograd.grad((z.sum(),), (x, y), allow_unused=True) 1133 1134 def test_retain_graph(x, y): 1135 # type: (Tensor, Tensor) -> None 1136 z = x + 2 * y + x * y 1137 torch.autograd.grad((z.sum(), ), (x, y), retain_graph=True) 1138 torch.autograd.grad((z.sum(), ), (x, y)) 1139 1140 x = torch.randn(2, 2, requires_grad=True) 1141 y = torch.randn(2, 2, requires_grad=True) 1142 self.checkScript(test_simple_grad, (x, y), inputs_requires_grad=True) 1143 self.checkScript(test_simple_grad_with_grad_outputs, (x, y), inputs_requires_grad=True) 1144 self.checkScript(test_one_output_not_requires_grad, (x, y), inputs_requires_grad=True) 1145 self.checkScript(test_retain_graph, (x, y), inputs_requires_grad=True) 1146 1147 def test_script_backward(self): 1148 def checkBackwardScript(fn, inputs): 1149 scripted_fn = 
torch.jit.script(fn) 1150 FileCheck().check("torch.autograd.backward").run(scripted_fn.code) 1151 recording_inputs = do_input_map(lambda t: t.detach().requires_grad_(), inputs) 1152 1153 fn(*inputs) 1154 scripted_fn(*recording_inputs) 1155 1156 for inp1, inp2 in zip(inputs, recording_inputs): 1157 self.assertEqual(inp1.grad, inp2.grad) 1158 1159 def test_tensor_backward(input): 1160 # type: (Tensor) -> None 1161 output = torch.relu(input) 1162 output = output.softmax(0) 1163 sum_out = output.sum() 1164 sum_out.backward() 1165 1166 def test_torch_autograd_backward(input): 1167 # type: (Tensor) -> None 1168 output = torch.relu(input) 1169 output = output.softmax(0) 1170 torch.autograd.backward(output.sum()) 1171 1172 def test_torch_autograd_backward_with_grad_tensors(input): 1173 # type: (Tensor) -> None 1174 output = torch.relu(input) 1175 output = output.softmax(0) 1176 grad_outputs = torch.jit.annotate(List[Optional[torch.Tensor]], [torch.ones((2, 2)), ]) 1177 torch.autograd.backward((output,), grad_outputs) 1178 1179 inp = torch.randn(2, 2, requires_grad=True) 1180 checkBackwardScript(test_tensor_backward, (inp,)) 1181 checkBackwardScript(test_torch_autograd_backward, (inp,)) 1182 checkBackwardScript(test_torch_autograd_backward_with_grad_tensors, (inp,)) 1183 1184 def test_script_backward_twice(self): 1185 def checkBackwardTwiceScript(fn, inputs, retain_graph_=False): 1186 class jit_profiling_executor_false: 1187 def __enter__(self): 1188 torch._C._jit_set_profiling_executor(False) 1189 1190 def __exit__(self, *args): 1191 torch._C._jit_set_profiling_executor(GRAPH_EXECUTOR != ProfilingMode.LEGACY) 1192 1193 with jit_profiling_executor_false(), torch.jit.optimized_execution(True): 1194 scripted_fn = torch.jit.script(fn, inputs) 1195 FileCheck().check("prim::DifferentiableGraph").run(scripted_fn.graph_for(*inputs)) 1196 1197 result = scripted_fn(*inputs) 1198 result.sum().backward(retain_graph=retain_graph_) 1199 if not retain_graph_: 1200 
self.assertRaisesRegex(RuntimeError, 'Specify retain_graph=True', 1201 lambda: result.sum().backward()) 1202 else: 1203 result.sum().backward() 1204 1205 def test_script_backward_twice_with_saved_values(input1, input2): 1206 # type: (Tensor, Tensor) -> Tensor 1207 tmp1 = torch.mul(input1, input2) 1208 tmp2 = torch.abs(tmp1) 1209 if torch.equal(input1, input2): 1210 tmp2 = torch.acos(tmp2) 1211 else: 1212 tmp2 = torch.atan(tmp2) 1213 result = torch.add(tmp2, input2) 1214 return result 1215 1216 inp1 = torch.randn(2, 2, requires_grad=True) 1217 inp2 = torch.randn(2, 2, requires_grad=True) 1218 checkBackwardTwiceScript(test_script_backward_twice_with_saved_values, (inp1, inp2), False) 1219 checkBackwardTwiceScript(test_script_backward_twice_with_saved_values, (inp1, inp2), True) 1220 1221 def test_diff_subgraph_clones_constants(self): 1222 @torch.jit.script 1223 def f(x, y): 1224 return x + x + y + x + y + x + y + x + y + x 1225 1226 def count_constants(graph): 1227 return sum(node.kind() == 'prim::Constant' for node in graph.nodes()) 1228 1229 graph = f.graph.copy() 1230 self.run_pass('cse', graph) 1231 self.run_pass('create_autodiff_subgraphs', graph) 1232 nodes = list(graph.nodes()) 1233 self.assertEqual(count_constants(graph), 1) 1234 self.assertEqual(count_constants(nodes[1].g('Subgraph')), 1) 1235 1236 # TODO: adapt this test to check that GraphExecutor treats them differently 1237 @unittest.skip("Need to be adjusted to Graph Executor") 1238 def test_arg_configurations(self): 1239 """Different arg configurations should trigger different traces""" 1240 x = Variable(torch.FloatTensor(4, 4).uniform_()) 1241 x_double = Variable(x.data.double()) 1242 x_grad = Variable(x.data.clone(), requires_grad=True) 1243 y = Variable(torch.randn(4)) 1244 1245 configurations = [ 1246 (x,), 1247 (x_double,), 1248 (x_grad,), 1249 (y,), 1250 ([x, x],), 1251 ([x, y],), 1252 ] 1253 if torch.cuda.is_available(): 1254 x_cuda = Variable(x.data.cuda()) 1255 configurations += [ 1256 
(x_cuda,), 1257 ([x, x_cuda],), 1258 ([x_cuda, x],), 1259 ([[x_cuda, x]],), 1260 ] 1261 if torch.cuda.device_count() > 1: 1262 x_cuda_1 = Variable(x.data.cuda(1)) 1263 configurations += [ 1264 (x_cuda_1,), 1265 ([x_cuda, x_cuda_1],), 1266 ] 1267 1268 @torch.jit.compile(nderivs=0) 1269 def fn(*args): 1270 in_vars, _ = torch._C._jit_flatten(args) 1271 return in_vars[0] + 1 1272 1273 for i, config in enumerate(configurations): 1274 self.assertFalse(fn.has_trace_for(*config)) 1275 fn(*config) 1276 self.assertTrue(fn.has_trace_for(*config)) 1277 for unk_config in configurations[i + 1:]: 1278 self.assertFalse(fn.has_trace_for(*unk_config)) 1279 self.assertEqual(fn.hits, 0) 1280 1281 def test_torch_sum(self): 1282 def fn(x): 1283 return torch.sum(x) 1284 1285 def fn1(x, dim: int): 1286 return torch.sum(x, dim) 1287 1288 x = torch.randn(3, 4) 1289 self.checkScript(fn, (x, )) 1290 self.checkScript(fn1, (x, 1, )) 1291 self.checkScript(fn1, (x, 0, )) 1292 1293 def test_cse(self): 1294 x = torch.tensor([0.4, 0.3], requires_grad=True) 1295 y = torch.tensor([0.7, 0.5], requires_grad=True) 1296 1297 def fn(x, y): 1298 w = (x + y) * (x + y) * (x + y) 1299 t = torch.tanh(w) + torch.tanh(w) 1300 z = (x + y) * (x + y) * (x + y) + t 1301 return z 1302 1303 g, _ = torch.jit._get_trace_graph(fn, (x, y)) 1304 self.run_pass('cse', g) 1305 do_exactly = True 1306 FileCheck().check_count("add", 1).check_count("mul", 2, do_exactly) \ 1307 .check_count("tanh", 1, do_exactly).check_count("add", 2, do_exactly).check_next("return") \ 1308 .run(str(g)) 1309 1310 self.assertExportImport(g, (x, y)) 1311 1312 def test_cse_not_introduce_aliasing(self): 1313 @torch.jit.script 1314 def tensor_alias_outputs(x): 1315 return x + x, x + x 1316 1317 self.run_pass('cse', tensor_alias_outputs.graph) 1318 FileCheck().check_count("aten::add", 2).run(tensor_alias_outputs.graph) 1319 1320 @torch.jit.script 1321 def ints_alias_outputs(x): 1322 # type: (int) -> Tuple[int, int] 1323 return x + x, x + x 1324 1325 # 
non-aliasing types can be CSEd 1326 self.run_pass('cse', ints_alias_outputs.graph) 1327 FileCheck().check_count("aten::add", 1, exactly=True).run(ints_alias_outputs.graph) 1328 1329 def test_recursive_cse(self): 1330 input_str = """ 1331graph(%x : Tensor, 1332 %y : Tensor, 1333 %20 : int): 1334 %2 : int = prim::Constant[value=1]() 1335 %3 : Tensor = aten::add(%x, %y, %2) 1336 %4 : int = aten::add(%2, %20) 1337 %5 : bool = aten::Bool(%4) 1338 %z : int = prim::If(%5) 1339 # CHECK: block 1340 block0(): 1341 # CHECK-NOT: aten::add 1342 %z.1 : int = aten::add(%2, %20) 1343 -> (%z.1) 1344 block1(): 1345 -> (%2) 1346 return (%z) 1347""" 1348 graph = parse_ir(input_str) 1349 self.run_pass('cse', graph) 1350 FileCheck().run(input_str, graph) 1351 1352 def test_pattern_based_rewrite(self): 1353 # mul(mul(mul(mul(x,y),z),x),y) --> mul(mul(mulmul(x,y,z), x), y) --> 1354 # --> mulmul(mulmul(x,y,z), x, y) 1355 input_str = """ 1356graph(%x, %y, %z): 1357 # CHECK-NOT: aten::mul 1358 # CHECK: my::fused_mulmul 1359 %t = aten::mul(%x, %y) 1360 %p = aten::mul(%t, %z) 1361 # CHECK: my::fused_mulmul 1362 %u = aten::mul(%p, %x) 1363 %o = aten::mul(%u, %y) 1364 return (%o)""" 1365 graph = parse_ir(input_str) 1366 torch._C._jit_pass_custom_pattern_based_rewrite_graph(""" 1367graph(%a, %b, %c): 1368 %q = aten::mul(%a, %b) 1369 %r = aten::mul(%q, %c) 1370 return (%r)""", """ 1371graph(%a, %b, %c): 1372 %r = my::fused_mulmul(%a, %b, %c) 1373 return (%r)""", graph) 1374 FileCheck().run(input_str, graph) 1375 1376 # Check that overlapping matches are handled correctly 1377 # mul(mul(mul(x,y),z),x) --> mul(mulmul(x,y,z), x) 1378 input_str = """ 1379graph(%x, %y, %z): 1380 # CHECK-NOT: aten::mul 1381 # CHECK: my::fused_mulmul 1382 %t = aten::mul(%x, %y) 1383 %p = aten::mul(%t, %z) 1384 # CHECK-NEXT: aten::mul 1385 %u = aten::mul(%p, %x) 1386 return (%u)""" 1387 graph = parse_ir(input_str) 1388 torch._C._jit_pass_custom_pattern_based_rewrite_graph(""" 1389graph(%a, %b, %c): 1390 %q = aten::mul(%a, 
%b) 1391 %r = aten::mul(%q, %c) 1392 return (%r)""", """ 1393graph(%a, %b, %c): 1394 %r = my::fused_mulmul(%a, %b, %c) 1395 return (%r)""", graph) 1396 FileCheck().run(input_str, graph) 1397 1398 # Check add(mul(x,y),z) --> muladd(x,y,z) replacement 1399 input_str = """ 1400graph(%x, %y, %z): 1401 # CHECK-NOT: aten::mul 1402 # CHECK-NOT: aten::add 1403 %c = prim::Const[value=1]() 1404 %t = aten::mul(%x, %y) 1405 %p = aten::add(%t, %z, %c) 1406 # CHECK: my::muladd 1407 # CHECK-NEXT: return 1408 return (%p)""" 1409 graph = parse_ir(input_str) 1410 torch._C._jit_pass_custom_pattern_based_rewrite_graph(""" 1411graph(%a, %b, %c, %d): 1412 %q = aten::mul(%a, %b) 1413 %r = aten::add(%q, %c, %d) 1414 return (%r)""", """ 1415graph(%a, %b, %c, %d): 1416 %r = my::muladd(%a, %b, %c, %d) 1417 return (%r)""", graph) 1418 FileCheck().run(input_str, graph) 1419 1420 # Check add(mul(x,y),z) --> sub(add(x,y),z) replacement 1421 input_str = """ 1422graph(%x, %y, %z): 1423 # CHECK-NOT: aten::mul 1424 %c = prim::Const[value=1]() 1425 # CHECK: aten::add 1426 %t = aten::mul(%x, %y) 1427 # CHECK-NEXT: aten::sub 1428 %p = aten::add(%t, %z, %c) 1429 # CHECK-NOT: aten::add 1430 # CHECK-NEXT: return 1431 return (%p)""" 1432 graph = parse_ir(input_str) 1433 torch._C._jit_pass_custom_pattern_based_rewrite_graph(""" 1434graph(%a, %b, %c, %d): 1435 %q = aten::mul(%a, %b) 1436 %r = aten::add(%q, %c, %d) 1437 return (%r)""", """ 1438graph(%a, %b, %c, %d): 1439 %q = aten::add(%a, %b, %d) 1440 %r = aten::sub(%q, %c, %d) 1441 return (%r)""", graph) 1442 FileCheck().run(input_str, graph) 1443 1444 # Check mul(x,y) --> x replacement 1445 input_str = """ 1446graph(%x, %y, %z): 1447 %c = prim::Const[value=1]() 1448 # CHECK-NOT: aten::mul 1449 %t = aten::mul(%x, %y) 1450 # CHECK: aten::add(%x, %z 1451 %p = aten::add(%t, %z, %c) 1452 # CHECK-NEXT: return 1453 return (%p)""" 1454 graph = parse_ir(input_str) 1455 torch._C._jit_pass_custom_pattern_based_rewrite_graph(""" 1456graph(%Pa, %Pb): 1457 %Pq = 
aten::mul(%Pa, %Pb) 1458 return (%Pq)""", """ 1459graph(%Ra, %Rb): 1460 return (%Ra)""", graph) 1461 FileCheck().run(input_str, graph) 1462 1463 @_tmp_donotuse_dont_inline_everything 1464 def test_pattern_based_module_rewrite(self): 1465 # Check match::module behavior 1466 class Test(torch.nn.Module): 1467 def __init__(self) -> None: 1468 super().__init__() 1469 self.conv = torch.nn.Conv2d(1, 20, 5, 1) 1470 self.bn = torch.nn.BatchNorm2d(num_features=20) 1471 1472 def forward(self, x): 1473 x = self.conv(x) 1474 x = self.bn(x) 1475 return x 1476 m = torch.jit.script(Test()) 1477 torch._C._jit_pass_custom_pattern_based_rewrite_graph(""" 1478 graph(%self, %x): 1479 %conv = match::module[name="Conv2d"](%self) 1480 %y = prim::CallMethod[name="forward"](%conv, %x) 1481 %bn = match::module[name="BatchNorm2d"](%self) 1482 %z = prim::CallMethod[name="forward"](%bn, %y) 1483 return (%z)""", """ 1484 graph(%self, %x): 1485 %z = my::matched_conv_bn(%self, %x) 1486 return (%z)""", m._c._get_method("forward").graph) 1487 1488 FileCheck().check("my::matched_conv_bn").run(m._c._get_method("forward").graph) 1489 1490 def test_pattern_based_rewrite_with_source_range_preserved(self): 1491 class TestModule1(torch.nn.Module): 1492 def forward(self, x, y, z, w): 1493 x = x + y 1494 x = x * z 1495 return w - x 1496 1497 input_pattern = """ 1498 graph(%x, %y, %z, %const): 1499 %t = aten::add(%x, %y, %const) 1500 %o = aten::mul(%t, %z) 1501 return (%o)""" 1502 replacement_pattern = """ 1503 graph(%x, %y, %z, %const): 1504 %o = my::add_mul(%x, %y, %z, %const) 1505 return (%o)""" 1506 scripted_model = torch.jit.script(TestModule1()) 1507 graph = scripted_model.graph 1508 value_mappings = [("o", "t")] 1509 for node in graph.nodes(): 1510 if node.kind() == "aten::add": 1511 source_range_1 = node.sourceRange() 1512 torch._C._jit_pass_custom_pattern_based_rewrite_graph( 1513 input_pattern, replacement_pattern, scripted_model.graph, value_name_pairs=value_mappings) 1514 graph = 
scripted_model.graph 1515 for node in graph.nodes(): 1516 if node.kind() == "my::add_mul": 1517 source_range_2 = node.sourceRange() 1518 self.assertTrue(source_range_1 == source_range_2) 1519 1520 class TestModule2(torch.nn.Module): 1521 def forward(self, x, y, z, w): 1522 x = x + y 1523 x = x + z 1524 x = x * z 1525 x = x * w 1526 return x - 2 1527 1528 # Check source range preservation for two node transforms add -> my_add 1529 input_pattern = """ 1530 graph(%x, %y, %const): 1531 %o = aten::add(%x, %y, %const) 1532 return (%o)""" 1533 replacement_pattern = """ 1534 graph(%x, %y, %const): 1535 %o = my::add(%x, %y, %const) 1536 return (%o)""" 1537 scripted_model = copy.deepcopy(torch.jit.script(TestModule2())) 1538 graph_copy = scripted_model.graph.copy() 1539 value_mappings = [("o", "o")] 1540 source_range_add_1 = None 1541 for node in graph_copy.nodes(): 1542 if source_range_add_1 is None and node.kind() == "aten::add": 1543 source_range_add_1 = node.sourceRange() 1544 if source_range_add_1 is not None and node.kind() == "aten::add": 1545 source_range_add_2 = node.sourceRange() 1546 torch._C._jit_pass_custom_pattern_based_rewrite_graph( 1547 input_pattern, replacement_pattern, graph_copy, value_name_pairs=value_mappings) 1548 source_range_my_add_1 = None 1549 for node in graph_copy.nodes(): 1550 if source_range_my_add_1 is None and node.kind() == "my::add": 1551 source_range_my_add_1 = node.sourceRange() 1552 if source_range_my_add_1 is not None and node.kind() == "my::add": 1553 source_range_my_add_2 = node.sourceRange() 1554 self.assertTrue(source_range_add_1 == source_range_my_add_1) 1555 self.assertTrue(source_range_add_2 == source_range_my_add_2) 1556 1557 # Check source range preservation for add-add -> double_add transform 1558 # fuse nodes 1559 input_pattern = """ 1560 graph(%x, %y, %z, %const): 1561 %t = aten::add(%x, %y, %const) 1562 %o = aten::add(%t, %z, %const) 1563 return (%o)""" 1564 replacement_pattern = """ 1565 graph(%x, %y, %z, %const): 1566 %o 
= my::double_add(%x, %y, %z, %const) 1567 return (%o)""" 1568 scripted_model = torch.jit.script(TestModule2()) 1569 graph_copy = scripted_model.graph.copy() 1570 value_mappings = [("o", "t")] 1571 source_range_1 = None 1572 source_range_2 = None 1573 for node in graph_copy.nodes(): 1574 if node.kind() == "aten::add": 1575 source_range_1 = node.sourceRange() 1576 break 1577 torch._C._jit_pass_custom_pattern_based_rewrite_graph( 1578 input_pattern, replacement_pattern, graph_copy, value_name_pairs=value_mappings) 1579 for node in graph_copy.nodes(): 1580 if node.kind() == "my::double_add": 1581 source_range_2 = node.sourceRange() 1582 self.assertTrue(source_range_1 == source_range_2) 1583 1584 # Check source range preservation for mul -> add + add transform 1585 # split node 1586 input_pattern = """ 1587 graph(%x, %y): 1588 %t = aten::mul(%x, %y) 1589 return (%t)""" 1590 replacement_pattern = """ 1591 graph(%x, %y): 1592 %t = my::add(%x, %y) 1593 %o = my::add(%t, %y) 1594 return (%o)""" 1595 scripted_model = torch.jit.script(TestModule2()) 1596 graph_copy = scripted_model.graph.copy() 1597 value_mappings = [("t", "t"), ("o", "t")] 1598 source_range_mul_1 = None 1599 for node in graph_copy.nodes(): 1600 if source_range_mul_1 is None and node.kind() == "aten::mul": 1601 source_range_mul_1 = node.sourceRange() 1602 if source_range_mul_1 is not None and node.kind() == "aten::mul": 1603 source_range_mul_2 = node.sourceRange() 1604 torch._C._jit_pass_custom_pattern_based_rewrite_graph( 1605 input_pattern, replacement_pattern, graph_copy, value_name_pairs=value_mappings) 1606 source_range_add_1 = None 1607 for node in graph_copy.nodes(): 1608 if source_range_add_1 is None and node.kind() == "my::add": 1609 source_range_add_1 = node.sourceRange() 1610 if source_range_add_1 is not None and node.kind() == "my::add": 1611 source_range_add_2 = node.sourceRange() 1612 self.assertTrue(source_range_mul_1 == source_range_add_1) 1613 self.assertTrue(source_range_mul_2 == 
source_range_add_2) 1614 1615 # Check lack of source range preservation for mul-mul-> double_mul transform 1616 input_pattern = """ 1617 graph(%x, %y, %z): 1618 %t = aten::mul(%x, %y) 1619 %o = aten::mul(%t, %z) 1620 return (%o)""" 1621 replacement_pattern = """ 1622 graph(%x, %y, %z): 1623 %o = my::double_mul(%x, %y, %z) 1624 return (%o)""" 1625 scripted_model = torch.jit.script(TestModule2()) 1626 graph_copy = scripted_model.graph.copy() 1627 for node in graph_copy.nodes(): 1628 if node.kind() == "aten::mul": 1629 source_range_1 = node.sourceRange() 1630 torch._C._jit_pass_custom_pattern_based_rewrite_graph(input_pattern, replacement_pattern, graph_copy) 1631 for node in graph_copy.nodes(): 1632 if node.kind() == "my::double_mul": 1633 source_range_2 = node.sourceRange() 1634 self.assertFalse(source_range_1 == source_range_2) 1635 1636 def test_expand_quantlint(self): 1637 pass 1638 1639 def test_expand_fold_quant_inputs(self): 1640 pass 1641 1642 def test_shape_analysis_broadcast(self): 1643 def broadcast(a, b): 1644 return a + b 1645 1646 x = torch.randn(3, 1, 5, requires_grad=True) 1647 y = torch.randn(4, 1, 8, 5, requires_grad=True) 1648 1649 graph = torch.jit.script(broadcast).graph 1650 torch._C._jit_pass_complete_shape_analysis(graph, (x, y), False) 1651 FileCheck().check("Float(4, 3, 8, 5, strides=[120, 40, 5, 1], device=cpu)").run(str(graph)) 1652 1653 def test_shape_analysis_unsqueeze_in_loop(self): 1654 input_str = """graph(%x.1 : Tensor): 1655 %4 : bool = prim::Constant[value=1]() 1656 %1 : int = prim::Constant[value=2]() 1657 %7 : int = prim::Constant[value=0]() 1658 # CHECK: FloatTensor(requires_grad=0, device=cpu) = prim::Loop 1659 %x : Tensor = prim::Loop(%1, %4, %x.1) 1660 # CHECK: : FloatTensor(requires_grad=0, device=cpu)): 1661 block0(%i : int, %x.6 : Tensor): 1662 # CHECK: FloatTensor(requires_grad=0, device=cpu) = aten::unsqueeze 1663 %x.3 : Tensor = aten::unsqueeze(%x.6, %7) 1664 -> (%4, %x.3) 1665 return (%x)""" 1666 graph = 
parse_ir(input_str) 1667 torch._C._jit_pass_complete_shape_analysis(graph, (torch.zeros(2, 2, dtype=torch.float32),), False) 1668 FileCheck().run(input_str, graph) 1669 1670 def test_script_tensor_type(self): 1671 def foo(x, t: torch.dtype): 1672 return x.type(t) 1673 scr = torch.jit.script(foo) 1674 x = torch.rand(3, 4) 1675 for t in [torch.int8, torch.float64, torch.float32, 1676 torch.bfloat16, torch.complex64, torch.complex128, torch.bool]: 1677 self.assertEqual(scr(x, t), foo(x, t)) 1678 1679 def test_script_bool_literal_conversion(self): 1680 def foo(x): 1681 return torch.mul(x, True) 1682 scr = torch.jit.script(foo) 1683 x = torch.rand(3, 4) 1684 self.assertEqual(scr(x), foo(x)) 1685 1686 def test_shape_analysis_masked_select(self): 1687 input_str = """graph(%0 : Float(), 1688 %1 : Bool()): 1689 # CHECK: Float(*, requires_grad=0, device=cpu) = aten::masked_select 1690 %2 : Tensor = aten::masked_select(%0, %1) # test/test_jit.py:15261:0 1691 return (%2)""" 1692 graph = parse_ir(input_str) 1693 x = torch.ones(1, dtype=torch.float32)[0] 1694 mask = x.ge(0.5) 1695 torch._C._jit_pass_complete_shape_analysis(graph, (x, mask), False) 1696 FileCheck().run(input_str, graph) 1697 1698 # TODO: update verify to work with GraphExecutors 1699 @unittest.skip("verify needs to be updated to work with GraphExecutors") 1700 def test_verify(self): 1701 x = torch.tensor([0.4], requires_grad=True) 1702 y = torch.tensor([0.7], requires_grad=True) 1703 1704 @torch.jit.compile 1705 def f(x, y): 1706 z = torch.sigmoid(x * (x + y)) 1707 w = torch.abs(x * x * x + y) + Variable(torch.ones(1)) 1708 return z, w 1709 1710 torch.jit.verify(f, (x, y), loss_fn=lambda z, w: z * w, devices=[]) 1711 1712 # TODO: adapt to a GraphExecutor test 1713 @unittest.skip("Need to instrument GraphExecutors a bit more") 1714 def test_flags(self): 1715 x, y = torch.randn(2, 2) 1716 y = Variable(torch.randn(2, 2)) 1717 1718 @torch.jit.compile 1719 def fn(x, y): 1720 return (x * x + y * y + x * y).sum() 1721 
1722 grads = {} 1723 for rx, ry in product((True, False), repeat=2): 1724 x.requires_grad = rx 1725 y.requires_grad = ry 1726 1727 self.assertFalse(fn.has_trace_for(x, y)) 1728 out = fn(x, y) 1729 1730 self.assertFalse(fn.has_trace_for(x, y)) 1731 for v, name, compute in [(x, 'x', rx), (y, 'y', ry)]: 1732 if not compute: 1733 continue 1734 grad_v, = torch.autograd.grad(out, v, retain_graph=True) 1735 expected_grad = grads.setdefault(name, grad_v) 1736 self.assertEqual(grad_v, expected_grad) 1737 self.assertEqual(fn.has_trace_for(x, y), rx or ry) 1738 1739 def test_python_ir(self): 1740 x = torch.tensor([0.4], requires_grad=True) 1741 y = torch.tensor([0.7], requires_grad=True) 1742 1743 def doit(x, y): 1744 return torch.sigmoid(torch.tanh(x * (x + y))) 1745 1746 g, _ = torch.jit._get_trace_graph(doit, (x, y)) 1747 self.run_pass('dce', g) 1748 self.run_pass('canonicalize', g) 1749 g2 = torch._C.Graph() 1750 g_to_g2 = {} 1751 for node in g.inputs(): 1752 g_to_g2[node] = g2.addInput() 1753 for node in g.nodes(): 1754 n_ = g2.createClone(node, lambda x: g_to_g2[x]) 1755 g2.appendNode(n_) 1756 for o, no in zip(node.outputs(), n_.outputs()): 1757 g_to_g2[o] = no 1758 1759 for node in g.outputs(): 1760 g2.registerOutput(g_to_g2[node]) 1761 1762 t_node = g2.create("prim::TensorTest").t_("a", torch.ones([2, 2])) 1763 self.assertEqual(t_node.attributeNames(), ["a"]) 1764 g2.appendNode(t_node) 1765 self.assertTrue(torch.equal(torch.ones(2, 2), t_node.t("a"))) 1766 for node in g.nodes(): 1767 self.assertTrue(g2.findNode(node.kind()) is not None) 1768 1769 @unittest.skipIf(IS_SANDCASTLE, "gtest runs these in sandcastle") 1770 @unittest.skipIf(RUN_CUDA, "covered by test_cpp_cuda") 1771 @unittest.skipIf(not torch._C._jit_has_cpp_tests(), "Tests were not built, use BUILD_TEST=1") 1772 def test_cpp(self): 1773 from cpp.jit import tests_setup 1774 tests_setup.setup() 1775 torch._C._jit_run_cpp_tests() 1776 tests_setup.shutdown() 1777 1778 def test_batchnorm(self): 1779 x = 
torch.ones(2, 2, 2, 2) 1780 g, outputs, inputs = torch.jit._get_trace_graph(nn.BatchNorm2d(2), x, 1781 _force_outplace=True, return_inputs=True) 1782 m = self.createFunctionFromGraph(g) 1783 self.assertEqual(outputs, m(*inputs)) 1784 1785 def test_dropout(self): 1786 x = torch.ones(2, 2) 1787 with torch.random.fork_rng(devices=[]): 1788 g, outputs, inputs = torch.jit._get_trace_graph(nn.Dropout(0.6), x, return_inputs=True) 1789 with torch.random.fork_rng(devices=[]): 1790 m = self.createFunctionFromGraph(g) 1791 self.assertEqual(outputs, m(*inputs)) 1792 1793 @unittest.skipIf(not RUN_CUDA, "test requires CUDA") 1794 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "skip if profiling isn't enabled") 1795 def test_native_dropout_corner_case(self): 1796 with disable_autodiff_subgraph_inlining(): 1797 def t(x, p: float, t: bool): 1798 o = torch.dropout(x, p, t) 1799 return o 1800 1801 jit_t = torch.jit.script(t) 1802 x = torch.randn(5).requires_grad_() 1803 FileCheck().check("prim::DifferentiableGraph").run(jit_t.graph_for(x, 1.0, True, profile_and_replay=True)) 1804 1805 for train in [True, False]: 1806 for p in [0.0, 1.0]: 1807 for device in ["cuda", "cpu"]: 1808 x = torch.randn(5).to(device=device).requires_grad_() 1809 x_ref = x.detach().requires_grad_() 1810 o = jit_t(x, p, train) 1811 o_ref = t(x_ref, p, train) 1812 o.sum().backward() 1813 o_ref.sum().backward() 1814 assert o.equal(o_ref) 1815 assert x.grad.equal(x_ref.grad) 1816 1817 @slowTest 1818 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, 'Testing differentiable graph') 1819 def test_dropout_module_requires_grad(self): 1820 with enable_profiling_mode_for_profiling_tests(): 1821 class MyModule(torch.nn.Module): 1822 def __init__(self, M): 1823 super().__init__() 1824 self.dropout = torch.nn.Dropout(0.5) 1825 self.linear = torch.nn.Linear(M, M) 1826 1827 def forward(self, input): 1828 input = self.dropout(input) 1829 output = self.linear(input) 1830 return output 1831 1832 def 
profile(func, X): 1833 with torch.autograd.profiler.profile() as prof: 1834 func(X) 1835 return [e.name for e in prof.function_events] 1836 1837 M = 1000 1838 scripted = torch.jit.script(MyModule(M)) 1839 # To reduce confusion about expected behaviors: 1840 # requires_grad controls whether dropout is symbolically differentiated. 1841 # training controls whether bernoulli_ is called inside symbolic differentiation of dropout. 1842 # * When requires_grad == training, the expected behaviors are obvious. 1843 # * When requires_grad=True and training=False, bernoulli_ might still show up in the graph. 1844 # But it's in a branch that's not called. That's why we have separate checks for autograd 1845 # profiler to make sure it's not run. 1846 # * When requires_grad=False and training=True, bernoulli_ must be run since it's the expected 1847 # behavior for the dropout layer in training mode. It's independent of whether graph requires 1848 # gradient. In fact bernoulli_ comes from autograd instead of autodiff in this case. 
1849 for training in (True, False): 1850 if training: 1851 scripted.train() 1852 else: 1853 scripted.eval() 1854 for requires_grad in (True, False): 1855 X = torch.randn(M, M, requires_grad=requires_grad) 1856 if requires_grad: 1857 FileCheck().check("aten::native_dropout").run(scripted.graph_for(X, profile_and_replay=True)) 1858 self.assertEqual(training, 'aten::bernoulli_' in profile(scripted, X)) 1859 1860 @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.SIMPLE, 'Testing differentiable graph') 1861 @skipIfTorchDynamo("Torchdynamo cannot correctly handle profiler.profile calls") 1862 def test_dropout_func_requires_grad(self): 1863 def dropout_training(input): 1864 return F.dropout(input, 0.5, training=True) 1865 1866 def dropout_eval(input): 1867 return F.dropout(input, 0.5, training=False) 1868 1869 def profile(func, X): 1870 with torch.autograd.profiler.profile() as prof: 1871 func(X) 1872 return [e.name for e in prof.function_events] 1873 1874 M = 1000 1875 scripted_training = torch.jit.script(dropout_training) 1876 scripted_eval = torch.jit.script(dropout_eval) 1877 # See comments in test_dropout_module_requires_grad. 
1878 with disable_autodiff_subgraph_inlining(): 1879 for requires_grad in (True, False): 1880 X = torch.randn(M, M, requires_grad=requires_grad) 1881 if requires_grad: 1882 FileCheck().check("aten::native_dropout").run(scripted_training.graph_for(X, profile_and_replay=True)) 1883 self.assertIn('aten::bernoulli_', profile(scripted_training, X)) 1884 self.assertNotIn('aten::bernoulli_', profile(scripted_eval, X)) 1885 1886 @unittest.skipIf(not RUN_CUDA, "test_dropout_cuda require CUDA") 1887 def test_dropout_cuda(self): 1888 # Dropout AD is dispatched to _fused_dropout in CUDA case, 1889 # which is not included in TestJitGeneratedFunctional 1890 def _zero_rate(t): 1891 return torch.true_divide((t == 0).sum(), t.numel()) 1892 1893 x = torch.ones(1000, 1000).cuda().requires_grad_() 1894 1895 with enable_profiling_mode_for_profiling_tests(): 1896 @torch.jit.script 1897 def func(x): 1898 return torch.nn.functional.dropout(x) 1899 1900 with freeze_rng_state(): 1901 out_ref = torch.nn.functional.dropout(x) 1902 grad_ref = torch.autograd.grad(out_ref.sum(), x) 1903 1904 with freeze_rng_state(): 1905 out = func(x) 1906 grad = torch.autograd.grad(out.sum(), x) 1907 1908 # TODO(#40882): previously we assert exact matches between eager and JIT result: 1909 # self.assertEqual(out, out_ref) 1910 # self.assertEqual(grad, grad_ref) 1911 # This test was disabled during legacy -> profiling executor transition. 1912 # Currently JIT fused results doesn't match eager result exactly due to some changes merged in between. 1913 # We temporarily only check statstical difference but it should be reverted once the issue is fixed. 
1914 self.assertEqual(_zero_rate(out), _zero_rate(out_ref), rtol=1e-3, atol=1e-4) 1915 self.assertEqual(_zero_rate(grad[0]), _zero_rate(grad_ref[0]), rtol=1e-3, atol=1e-4) 1916 1917 def test_torch_ops_overloaded(self): 1918 with self.assertRaisesRegex(RuntimeError, "failed to match any schema"): 1919 torch.ops.aten.add("a", 1) 1920 self.assertEqual("ab", torch.ops.aten.add("a", "b")) 1921 a, b = torch.rand(3, 4), torch.rand(3, 4) 1922 self.assertEqual(a + b, torch.ops.aten.add(a, b)) 1923 self.assertEqual(a + 1, torch.ops.aten.add(a, 1)) 1924 1925 def test_torch_ops_kwonly(self): 1926 a, b = torch.rand(3, 4), torch.rand(3, 4) 1927 with self.assertRaisesRegex(RuntimeError, "positional argument"): 1928 torch.ops.aten.add(a, b, 2) 1929 # h/t Chillee for this ambiguous case 1930 self.assertEqual(a.prod(1), torch.ops.aten.prod(a, 1)) 1931 1932 def test_torch_complex(self): 1933 def fn(real, img): 1934 return torch.complex(real, img) 1935 1936 def fn_out(real, img, out): 1937 return torch.complex(real, img, out=out) 1938 self.checkScript(fn, (torch.rand(3, 4), torch.rand(3, 4), )) 1939 self.checkScript(fn, (torch.ones(5, 1, 4), torch.ones(5, 1, 4), )) 1940 self.checkScript(fn, (torch.zeros(1, 6), torch.ones(6, 1), )) 1941 self.checkScript(fn, (torch.zeros(1, 6), torch.zeros(6, 1), )) 1942 self.checkScript(fn, (torch.empty(3, 4), torch.empty(3, 4), )) 1943 1944 real = torch.tensor([1, 2], dtype=torch.float32) 1945 img = torch.tensor([3, 4], dtype=torch.float32) 1946 out = torch.empty([3, 4], dtype=torch.complex64) 1947 self.checkScript(fn_out, (real, img, out, )) 1948 1949 real = torch.tensor([5, 2], dtype=torch.float64) 1950 img = torch.tensor([3, 4], dtype=torch.float64) 1951 out = torch.empty([5, 2], dtype=torch.complex128) 1952 self.checkScript(fn_out, (real, img, out, )) 1953 1954 real = torch.ones([1, 2]) 1955 img = torch.ones([1, 2]) 1956 out = torch.empty([1, 2], dtype=torch.complex64) 1957 self.checkScript(fn_out, (real, img, out, )) 1958 1959 real = 
torch.ones([3, 8, 7]) 1960 img = torch.ones([3, 8, 7]) 1961 out = torch.empty([3, 8, 7], dtype=torch.complex64) 1962 self.checkScript(fn_out, (real, img, out, )) 1963 1964 real = torch.empty([3, 2, 6]) 1965 img = torch.empty([3, 2, 6]) 1966 out = torch.empty([3, 2, 6], dtype=torch.complex64) 1967 self.checkScript(fn_out, (real, img, out, )) 1968 1969 real = torch.zeros([1, 3]) 1970 img = torch.empty([3, 1]) 1971 out = torch.empty([3, 3], dtype=torch.complex64) 1972 self.checkScript(fn_out, (real, img, out, )) 1973 1974 real = torch.ones([2, 5]) 1975 img = torch.empty([2, 1]) 1976 out = torch.empty([2, 5], dtype=torch.complex64) 1977 self.checkScript(fn_out, (real, img, out, )) 1978 1979 real = torch.ones([2, 5]) 1980 img = torch.zeros([2, 1]) 1981 out = torch.empty([2, 5], dtype=torch.complex64) 1982 self.checkScript(fn_out, (real, img, out, )) 1983 1984 def test_einsum(self): 1985 def check(fn, jitted, *args): 1986 self.assertGraphContains(jitted.graph, kind='aten::einsum') 1987 self.assertEqual(fn(*args), jitted(*args)) 1988 1989 def equation_format(x, y): 1990 return torch.einsum('i,j->ij', (x, y)) 1991 1992 def equation_format_varargs(x, y): 1993 return torch.einsum('i,j->ij', x, y) 1994 1995 def sublist_format(x, y): 1996 return torch.einsum(x, [0], y, [1], [0, 1]) 1997 1998 x = make_tensor((5,), dtype=torch.float32, device="cpu") 1999 y = make_tensor((10,), dtype=torch.float32, device="cpu") 2000 2001 for fn in [equation_format, equation_format_varargs, sublist_format]: 2002 check(fn, torch.jit.script(fn), x, y) 2003 check(fn, torch.jit.trace(fn, (x, y)), x, y) 2004 2005 @skipIfTorchDynamo("TorchDynamo fails with unknown reason") 2006 def test_python_ivalue(self): 2007 # Test if pure python object can be hold as IValue and conversion 2008 # between IValue and PyObject are correct 2009 # test for numpy object 2010 py_array = np.arange(15) 2011 ret_py_obj = torch._C._ivalue_debug_python_object(py_array) 2012 self.assertEqual(py_array, ret_py_obj) 2013 2014 # 
test for function object 2015 ret_py_obj = torch._C._ivalue_debug_python_object(F.relu) 2016 self.assertEqual(F.relu, ret_py_obj) 2017 2018 # test for memory management 2019 # we need to ensure IValue correctly call incref/decref to avoid 2020 # dangling behavior and potential memory leaks during conversions 2021 def test_func_scope_helper(inp): 2022 # create a scope and do the conversion -> ivalue -> pyobject 2023 # this func return a new pyobject that refcount + 1 2024 inp_refcount = sys.getrefcount(inp) 2025 ivalue_holder = torch._C._ivalue_debug_python_object(inp) 2026 self.assertEqual(inp_refcount + 1, sys.getrefcount(ivalue_holder)) 2027 return ivalue_holder + 1 2028 2029 test_input = 2200 2030 before_count = sys.getrefcount(test_input) 2031 test_func_scope_helper(test_input) 2032 after_count = sys.getrefcount(test_input) 2033 2034 # after the test_func_scope_helper_call, the refcount of 2035 # test_input should be equal to the original refcount 2036 # otherwise we get either dangling pointer or memory leak! 
2037 self.assertEqual(before_count, after_count) 2038 2039 def test_decompose_addmm(self): 2040 def does_decompose(): 2041 @torch.jit.script 2042 def addmm(mat, mat1, mat2): 2043 a = mat.addmm(mat1, mat2) 2044 b = mat.addmm(mat1, mat2, alpha=1.0, beta=1.0) 2045 return a + b 2046 2047 mat = torch.randn(2, 2) 2048 mat1 = torch.randn(2, 4) 2049 mat2 = torch.randn(4, 2) 2050 2051 out_ref = addmm(mat, mat1, mat2) 2052 self.run_pass('decompose_ops', addmm.graph) 2053 out_test = addmm(mat, mat1, mat2) 2054 self.assertEqual(out_ref, out_test) 2055 FileCheck().check_not("addmm").run(str(addmm.graph)) 2056 2057 def doesnt_decompose(): 2058 @torch.jit.script 2059 def addmm(mat, mat1, mat2, alpha, beta): 2060 a = mat.addmm(mat1, mat2, alpha=4.20, beta=2.0) 2061 b = mat.addmm(mat1, mat2, alpha=int(alpha), beta=int(beta)) 2062 2063 return a + b 2064 2065 orig = str(addmm.graph) 2066 self.run_pass('decompose_ops', addmm.graph) 2067 self.assertTrue(orig == str(addmm.graph)) 2068 2069 does_decompose() 2070 doesnt_decompose() 2071 2072 @suppress_warnings 2073 def test_sparse_tensors(self): 2074 @torch.jit.ignore 2075 def get_sparse(): 2076 return torch.sparse_coo_tensor((2, 3), dtype=torch.float32) 2077 2078 @torch.jit.script 2079 def test_is_sparse(input): 2080 # type: (Tensor) -> bool 2081 return input.is_sparse 2082 2083 script_out_is_sparse = test_is_sparse(get_sparse()) 2084 script_out_is_dense = test_is_sparse(torch.randn(2, 3)) 2085 self.assertEqual(script_out_is_sparse, True) 2086 self.assertEqual(script_out_is_dense, False) 2087 2088 def test_basic_sparse(input): 2089 output = get_sparse() 2090 return output, input 2091 2092 self.checkScript(test_basic_sparse, (get_sparse(),)) 2093 self.checkScript(test_basic_sparse, (torch.tensor([1]),)) 2094 2095 def test_sparse_sum(input): 2096 return torch.sparse.sum(input) 2097 2098 self.checkScript(test_sparse_sum, (get_sparse(),)) 2099 2100 def test_sparse_mm(input1, input2): 2101 return torch.sparse.mm(input1, input2) 2102 2103 
self.checkScript(test_sparse_mm, (get_sparse(), torch.randn(3, 4))) 2104 2105 def test_sparse_addmm(input, input1, input2): 2106 return torch.sparse.addmm(input, input1, input2) 2107 2108 def test_sparse_addmm_alpha_beta(input, input1, input2): 2109 return torch.sparse.addmm(input, input1, input2, alpha=1.3, beta=1.5) 2110 2111 self.checkScript(test_sparse_addmm, (torch.randn(2, 4), get_sparse(), torch.randn(3, 4))) 2112 self.checkScript(test_sparse_addmm_alpha_beta, (torch.randn(2, 4), get_sparse(), torch.randn(3, 4))) 2113 2114 @suppress_warnings 2115 def test_sparse_csr_tensors(self): 2116 @torch.jit.ignore 2117 def get_sparse_csr(): 2118 return torch.randn(3, 3).to_sparse_csr() 2119 2120 @torch.jit.script 2121 def test_is_sparse_csr(input): 2122 # type: (Tensor) -> bool 2123 return input.is_sparse_csr 2124 2125 script_out_is_sparse_csr = test_is_sparse_csr(get_sparse_csr()) 2126 script_out_is_dense_csr = test_is_sparse_csr(torch.randn(3, 3)) 2127 2128 self.assertEqual(script_out_is_sparse_csr, True) 2129 self.assertEqual(script_out_is_dense_csr, False) 2130 2131 @unittest.skipIf(not RUN_CUDA, "requires CUDA") 2132 def test_device_not_equal(self): 2133 2134 def compare_device(x: torch.device): 2135 return x != torch.device("cuda:0") 2136 2137 def compare_two_device(x: torch.device, y: torch.device): 2138 return x != y 2139 2140 self.checkScript(compare_device, (torch.device("cuda:0"),)) 2141 self.checkScript(compare_two_device, (torch.device("cuda:0"), torch.device("cuda:1"), )) 2142 2143 def test_constant_prop_simple(self): 2144 @torch.jit.script 2145 def constant_prop(input_int): 2146 # type: (int) -> int 2147 a = 2 * 3 2148 b = a + 2 2149 return b - input_int 2150 2151 out_ref = constant_prop(2) 2152 self.run_pass('constant_propagation', constant_prop.graph) 2153 out_test = constant_prop(2) 2154 self.assertEqual(out_ref, out_test) 2155 graph_str = str(constant_prop.graph) 2156 self.assertTrue("aten::add" not in graph_str and "aten::mul" not in graph_str) 2157 
const = constant_prop.graph.findNode("prim::Constant").output().toIValue() 2158 self.assertEqual(const, 8) 2159 2160 def test_constant_prop_nested(self): 2161 @torch.jit.script 2162 def constant_prop(a): 2163 b = 2 + 1 2164 if bool(a < 2): 2165 c = b + 2 2166 else: 2167 c = b - 2 2168 return c 2169 out_ref = constant_prop(torch.tensor(2)) 2170 self.run_pass('constant_propagation', constant_prop.graph) 2171 out_test = constant_prop(torch.tensor(2)) 2172 self.assertEqual(out_ref, out_test) 2173 if_node = constant_prop.graph.findNode("prim::If") 2174 for block in if_node.blocks(): 2175 for node in block.nodes(): 2176 self.assertTrue(node.kind() == "prim::Constant") 2177 2178 def test_constant_prop_print(self): 2179 @torch.jit.script 2180 def constant_prop(input_tensor): 2181 a = 2 * 3 2182 print(a) 2183 b = a + 2 2184 return b + input_tensor 2185 2186 self.run_pass('constant_propagation', constant_prop.graph) 2187 graph = constant_prop.graph 2188 print_node = graph.findNode("prim::Print") 2189 self.assertTrue(print_node.input().toIValue() == 6) 2190 2191 def test_constant_prop_rand(self): 2192 @torch.jit.script 2193 def constant_prop(): 2194 a = torch.randn([3]) 2195 b = a + 2 2196 return b 2197 2198 self.run_pass('constant_propagation', constant_prop.graph) 2199 self.assertTrue("aten::randn" in str(constant_prop.graph)) 2200 2201 def test_constant_prop_none(self): 2202 @torch.jit.script 2203 def typed_none(): 2204 # type: () -> Optional[int] 2205 return None 2206 2207 @torch.jit.script 2208 def constant_prop(): 2209 a = typed_none() 2210 b = typed_none() 2211 if (a is None and b is None): 2212 a = 2 2213 else: 2214 a = 1 2215 return a 2216 2217 self.run_pass('constant_propagation', constant_prop.graph) 2218 FileCheck().check("prim::Constant").run(constant_prop.graph) 2219 2220 def test_constant_prop_if_inline(self): 2221 @torch.jit.script 2222 def constant_prop(): 2223 cond = True 2224 a = 1 2225 if cond: 2226 a = 1 * 2 2227 else: 2228 a = 1 // 0 2229 return a 2230 
2231 # testing that 1 // 0 error is not thrownn 2232 self.run_pass('constant_propagation', constant_prop.graph) 2233 2234 def test_constant_prop_exception(self): 2235 # checking y = a[4] does not error in constant propagation 2236 def bad_index(x): 2237 # type: (bool) 2238 y = 0 2239 if x: 2240 a = [1, 2, 3] 2241 y = a[4] 2242 return y 2243 2244 self.checkScript(bad_index, (False,)) 2245 2246 def test_constant_prop_aliasing_type(self): 2247 @torch.jit.script 2248 def foo(): 2249 return len([1]), len(torch.tensor([2])) 2250 2251 FileCheck().check_dag("aten::tensor").check_dag("aten::len").run(foo.graph) 2252 2253 @torch.jit.script 2254 def fn(): 2255 if 1 == 1: 2256 return 1 2257 else: 2258 return 2 2259 2260 FileCheck().check_not("prim::If").run(fn.graph) 2261 2262 def test_unchecked_cast(self): 2263 def test(cond): 2264 # type: (bool) 2265 a = torch.tensor([10]) 2266 if cond: 2267 b = None 2268 else: 2269 b = a 2270 if b is not None: 2271 b[0] = 5 2272 return a.int() 2273 2274 self.checkScript(test, (True,)) 2275 self.checkScript(test, (False,)) 2276 2277 def test_constant_prop_if_constant(self): 2278 @torch.jit.script 2279 def constant_prop(a, b): 2280 c0 = 1 2281 c1 = 1 2282 c2 = 1 2283 if bool(a): # -> c0, c1 2284 if bool(b): # -> c0 2285 if 1 == 1: # -> c0 2286 c0 = c0 + 1 2287 if 1 == 2: 2288 c1 = c1 + 1 2289 c2 = c2 + 1 2290 else: # -> c0, c1 2291 c1 = c1 + 1 2292 2293 if 1 == 1: # inlined 2294 c0 = c0 + 1 # dynamic 2295 c2 = c2 + 4 # set to 5 2296 return a + c0 + c1 + c2 2297 2298 graph = constant_prop.graph 2299 self.run_pass('constant_propagation', graph) 2300 ifs = graph.findAllNodes("prim::If", recurse=False) 2301 snd_if_inlined = len(ifs) == 1 2302 self.assertTrue(snd_if_inlined) 2303 first_if = ifs[0] 2304 self.assertTrue(first_if.outputsSize() == 2) 2305 second_if = first_if.findNode("prim::If", recurse=False) 2306 self.assertTrue(second_if.outputsSize() == 1) 2307 self.assertTrue(second_if.findNode("prim::If") is None) 2308 2309 def 
test_constant_prop_loop_constant(self): 2310 @torch.jit.script 2311 def constant_prop(cond, iter): 2312 # type: (bool, int) -> int 2313 b = 0 2314 while True: 2315 print("stays") 2316 for _ in range(2): 2317 print("stays") 2318 for _ in range(iter): 2319 print("stays") 2320 while cond: 2321 print("stays") 2322 while False: 2323 print("removed") 2324 for _i in range(0): 2325 print("removed") 2326 for _i in range(-4): 2327 print("removed") 2328 return b 2329 2330 self.run_pass('constant_propagation', constant_prop.graph) 2331 graph = canonical(constant_prop.graph) 2332 self.assertTrue(graph.count("removed") == 0) 2333 self.assertTrue(graph.count("stays") == 1) # constant gets pooled 2334 self.assertTrue(graph.count("prim::Print") == 4) 2335 2336 def test_constant_prop_remove_output(self): 2337 @torch.jit.script 2338 def constant_prop(iter): 2339 # type: (int) -> None 2340 a = 1 2341 b = 1 2342 c = 1 2343 for i in range(iter): 2344 if 1 == 2: 2345 a = 10 2346 if i == 5: 2347 b = 2 2348 c = 3 2349 print(a, b, c) 2350 2351 graph = constant_prop.graph 2352 self.run_pass('constant_propagation', graph) 2353 self.assertTrue(graph.findNode("prim::Loop").outputsSize() == 2) 2354 2355 # TODO(gmagogsfm): Refactor this test to reduce complexity. 
2356 def test_constant_insertion(self): 2357 funcs_template = dedent(''' 2358 def func(): 2359 return {constant_constructor} 2360 ''') 2361 2362 # constants: primitives: int, double, bool, str, lists of primitives, 2363 # and tuples 2364 def check_constant(constant_constructor): 2365 scope = {} 2366 funcs_str = funcs_template.format(constant_constructor=constant_constructor) 2367 execWrapper(funcs_str, globals(), scope) 2368 cu = torch.jit.CompilationUnit(funcs_str) 2369 f_script = cu.func 2370 self.run_pass('constant_propagation', f_script.graph) 2371 FileCheck().check_count("prim::Constant", 1, exactly=True).run(f_script.graph) 2372 self.assertEqual(scope['func'](), f_script()) 2373 imported = self.getExportImportCopy(f_script) 2374 self.assertEqual(imported(), f_script()) 2375 2376 constants = ["None", "-.5", "0", "1", "True", "False", "''", "'a'", "'b'", "torch.tensor(1)", 2377 "[True, False]", "[0., .5]", "[torch.tensor(4), torch.tensor(2)]", "[0, 1]", "['0', '1']", 2378 "[True, None]", "[.5, None, .2]"] 2379 2380 for type in ["Tensor", "str", "int", "float", "bool"]: 2381 constants.append("torch.jit.annotate(List[ " + type + "], [])") 2382 2383 for constant in constants: 2384 check_constant(constant) 2385 2386 for key_type in ["str", "int", "float"]: 2387 for value_type in ["Tensor", "bool", "str", "int", "float"]: 2388 check_constant("torch.jit.annotate(Dict[ " + key_type + ", " + value_type + "], {})") 2389 check_constant("torch.jit.annotate(Dict[ " + key_type + ", Optional[" + value_type + "]], {})") 2390 2391 for i in range(len(constants)): 2392 for j in range(i + 1, len(constants)): 2393 tup_constant = constants[i] + ", " + constants[j] 2394 check_constant(tup_constant) 2395 2396 dict_constants = [] 2397 for i in range(len(constants)): 2398 # check_constant constructs the second dict with another Tensor 2399 # which fails the comparison 2400 if not isinstance(eval(constants[i]), (str, int, float)): 2401 continue 2402 for j in range(len(constants)): 2403 
dict_constant = "{ " + constants[i] + ": " + constants[j] + "}" 2404 check_constant(dict_constant) 2405 dict_constants.append(dict_constant) 2406 constants = constants + dict_constants 2407 2408 # testing node hashing 2409 funcs_template = dedent(''' 2410 def func(): 2411 print({constant_constructor}) 2412 ''') 2413 single_elem_tuples = ("(" + x + ",)" for x in constants) 2414 input_arg = ", ".join(single_elem_tuples) 2415 scope = {} 2416 funcs_str = funcs_template.format(constant_constructor=input_arg) 2417 execWrapper(funcs_str, globals(), scope) 2418 cu = torch.jit.CompilationUnit(funcs_str) 2419 f_script = cu.func 2420 self.run_pass('constant_propagation', f_script.graph) 2421 # prim::None return adds one constant 2422 self.assertEqual(len(constants) + 1, str(f_script.graph).count("prim::Constant")) 2423 self.run_pass('cse', f_script.graph) 2424 # node hashing correctly working, no CSE occurs 2425 self.assertEqual(len(constants) + 1, str(f_script.graph).count("prim::Constant")) 2426 2427 funcs_template = dedent(''' 2428 def func(): 2429 a = {constant_constructor} 2430 print(a) 2431 b = {constant_constructor} 2432 print(b) 2433 ''') 2434 2435 # generate dicts with built-in types (excluding torch.Tensor) 2436 xprod = itertools.product(constants, constants) 2437 2438 # test that equal tuples and dicts correctly work with node hashing 2439 for tup in ("(" + x + ",)" for x in constants): 2440 funcs_str = funcs_template.format(constant_constructor=tup) 2441 scope = {} 2442 execWrapper(funcs_str, globals(), scope) 2443 cu = torch.jit.CompilationUnit(funcs_str) 2444 f_script = cu.func 2445 self.run_pass('constant_propagation_immutable_types', f_script.graph) 2446 num_constants = str(f_script.graph).count("prim::Constant") 2447 self.run_pass('cse', f_script.graph) 2448 FileCheck().check_count("prim::Constant", num_constants, exactly=True).run(f_script.graph) 2449 2450 @unittest.skipIf(not RUN_CUDA, "requires CUDA") 2451 def test_cuda_export_restore(self): 2452 class 
Sub(torch.jit.ScriptModule): 2453 def __init__(self) -> None: 2454 super().__init__() 2455 self.weight = nn.Parameter(torch.randn(3, 4)) 2456 2457 @torch.jit.script_method 2458 def forward(self, thing): 2459 return self.weight + thing 2460 2461 class M(torch.jit.ScriptModule): 2462 def __init__(self) -> None: 2463 super().__init__() 2464 self.mod = Sub() 2465 2466 @torch.jit.script_method 2467 def forward(self, v): 2468 return self.mod(v) 2469 m = M() 2470 m.cuda() 2471 m2 = self.getExportImportCopy(m) 2472 m2.cuda() 2473 input = torch.rand(3, 4).cuda() 2474 self.assertEqual(m(input), m2(input)) 2475 2476 @slowTest 2477 def test_export_batchnorm(self): 2478 for mode in ['eval', 'train']: 2479 for clazz in [ 2480 torch.nn.BatchNorm1d(100), 2481 torch.nn.BatchNorm1d(100, affine=False), 2482 torch.nn.BatchNorm2d(100), 2483 torch.nn.BatchNorm2d(100, affine=False)]: 2484 getattr(clazz, mode)() 2485 input = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \ 2486 torch.randn(20, 100, 35, 45) 2487 traced = torch.jit.trace(clazz, (input,)) 2488 imported = self.getExportImportCopy(traced) 2489 x = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \ 2490 torch.randn(20, 100, 35, 45) 2491 self.assertEqual(traced(x), imported(x)) 2492 2493 def test_export_rnn(self): 2494 for clazz in [nn.RNN(10, 20, 2), nn.GRU(10, 20, 2)]: 2495 class RNNTest(torch.nn.Module): 2496 def __init__(self) -> None: 2497 super().__init__() 2498 self.rnn = clazz 2499 2500 def forward(self, x, lengths, h0): 2501 packed = torch.nn.utils.rnn.pack_padded_sequence(x, lengths) 2502 out, h = self.rnn(packed, h0) 2503 padded_outs, _ = torch.nn.utils.rnn.pad_packed_sequence(out) 2504 return padded_outs 2505 2506 test = RNNTest() 2507 2508 traced = torch.jit.trace(test, (torch.randn(5, 3, 10), torch.LongTensor([3, 2, 1]), torch.randn(2, 3, 20))) 2509 imported = self.getExportImportCopy(traced) 2510 # NB: We make sure to pass in a batch with a different max sequence 2511 # 
length to ensure that the argument stashing for pad_packed works 2512 # properly. 2513 x, lengths, h0 = torch.randn(7, 4, 10), torch.LongTensor([7, 3, 2, 1]), torch.randn(2, 4, 20) 2514 self.assertEqual(traced(x, lengths, h0), imported(x, lengths, h0)) 2515 2516 def test_export_lstm(self): 2517 class LSTMTest(torch.nn.Module): 2518 def __init__(self) -> None: 2519 super().__init__() 2520 self.rnn = nn.LSTM(10, 20, 2) 2521 2522 def forward(self, x, lengths, hiddens): 2523 h0, c0 = hiddens 2524 packed = torch.nn.utils.rnn.pack_padded_sequence(x, lengths) 2525 out, (h, c) = self.rnn(packed, (h0, c0)) 2526 padded_outs, _ = torch.nn.utils.rnn.pad_packed_sequence(out) 2527 return padded_outs 2528 2529 test = LSTMTest() 2530 2531 traced = torch.jit.trace(test, (torch.randn(5, 3, 10), 2532 torch.LongTensor([3, 2, 1]), 2533 (torch.randn(2, 3, 20), torch.randn(2, 3, 20)))) 2534 imported = self.getExportImportCopy(traced) 2535 x, lengths, h0, c0 = \ 2536 torch.randn(7, 3, 10), torch.LongTensor([7, 5, 2]), torch.randn(2, 3, 20), torch.randn(2, 3, 20) 2537 self.assertEqual(traced(x, lengths, (h0, c0)), imported(x, lengths, (h0, c0))) 2538 2539 def test_unique_state_dict(self): 2540 class MyModule(torch.nn.Module): 2541 def __init__(self) -> None: 2542 super().__init__() 2543 shared_param = torch.nn.Parameter(torch.ones(1)) 2544 self.register_parameter('w1', shared_param) 2545 self.register_parameter('w2', shared_param) 2546 2547 def forward(self, input): 2548 return input + self.w1 + self.w2 2549 2550 model = MyModule() 2551 unittest.TestCase.assertEqual( 2552 self, len(torch.jit._unique_state_dict(model, keep_vars=False)), 1) 2553 unittest.TestCase.assertEqual( 2554 self, len(torch.jit._unique_state_dict(model, keep_vars=True)), 1) 2555 2556 def test_export_dropout(self): 2557 test = torch.nn.Dropout() 2558 test.eval() 2559 2560 traced = torch.jit.trace(test, (torch.rand(3, 4),), check_trace=False) 2561 imported = self.getExportImportCopy(traced) 2562 x = torch.randn(3, 4) 
2563 self.assertEqual(traced(x), imported(x)) 2564 2565 def test_pretty_printer(self): 2566 @torch.jit.script 2567 def if_test(a, b): 2568 # FIXME: use 0 instead of a. 2569 # c = 0 2570 c = a 2571 if bool(a < b): 2572 c = b 2573 else: 2574 c = a 2575 return c 2576 2577 @torch.jit.script 2578 def if_one(a, b): 2579 c = b 2580 if bool(a < b): 2581 c = a 2582 return c 2583 2584 @torch.jit.script 2585 def while_test(a, i): 2586 while bool(i < 3): 2587 a *= a 2588 i += 1 2589 return a 2590 2591 @torch.jit.script 2592 def while_if_test(a, b): 2593 c = 0 2594 while bool(a < 10): 2595 a = a + 1 2596 b = b + 1 2597 if bool(a > b): 2598 c = 2 2599 else: 2600 c = 3 2601 return a + 1 + c 2602 2603 @torch.jit.script 2604 def loop_use_test(y): 2605 x = y + 1 2606 z = x + 5 2607 while bool(y < 8): 2608 y += 1 2609 z = x 2610 return x, z 2611 2612 @torch.jit.ignore 2613 def python_fn(x): 2614 return x + 10 2615 2616 @torch.jit.script 2617 def python_op_name_test(y): 2618 return python_fn(y) 2619 2620 @torch.jit.script 2621 def empty_int_list_test(y): 2622 x = torch.jit.annotate(List[int], []) 2623 return x[0] 2624 2625 @torch.jit.script 2626 def empty_float_list_test(y): 2627 return [1.0, 2.0, 3.0] 2628 2629 @torch.jit.script 2630 def print_weird_test(y): 2631 print("hi\016") 2632 2633 self.assertExpected(if_test.code, "if_test") 2634 self.assertExpected(if_one.code, "if_one") 2635 self.assertExpected(while_test.code, "while_test") 2636 self.assertExpected(while_if_test.code, "while_if_test") 2637 self.assertExpected(loop_use_test.code, "loop_use_test") 2638 self.assertExpected(python_op_name_test.code, "python_op_name_test") 2639 self.assertExpected(empty_int_list_test.code, "empty_int_list_test") 2640 self.assertExpected(empty_float_list_test.code, "empty_float_list_test") 2641 self.assertExpected(print_weird_test.code, "print_weird_test") 2642 2643 def test_cu_escaped_number(self): 2644 cu = torch.jit.CompilationUnit(''' 2645 def foo(a): 2646 print("hi\016") 2647 ''') 2648 
self.assertExpected(cu.foo.code) 2649 2650 def test_import_method(self): 2651 with torch._jit_internal._disable_emit_hooks(): 2652 class Foo(torch.jit.ScriptModule): 2653 @torch.jit.script_method 2654 def forward(self, x, y): 2655 return 2 * x + y 2656 2657 foo = Foo() 2658 buffer = io.BytesIO() 2659 torch.jit.save(foo, buffer) 2660 2661 buffer.seek(0) 2662 foo_loaded = torch.jit.load(buffer) 2663 self.assertExpected(foo_loaded.forward.code) 2664 2665 @unittest.skip("temporarily disable the test for fwd compatibility") 2666 def test_non_ascii_string(self): 2667 class Foo(torch.jit.ScriptModule): 2668 def __init__(self) -> None: 2669 super().__init__() 2670 self.a = "Over \u0e55\u0e57 57" 2671 2672 @torch.jit.script_method 2673 def forward(self, x, y): 2674 return self.a + "hi\xA1" 2675 2676 foo = Foo() 2677 buffer = io.BytesIO() 2678 torch.jit.save(foo, buffer) 2679 2680 buffer.seek(0) 2681 foo_loaded = torch.jit.load(buffer) 2682 self.assertExpected(foo_loaded.forward.code) 2683 2684 def test_function_default_values(self): 2685 outer_var = torch.tensor(20) 2686 outer_var2 = torch.tensor(30) 2687 a = torch.tensor(0.5) 2688 b = torch.tensor(10) 2689 2690 @torch.jit.script 2691 def simple_fn(x, a=a, b=b, c=outer_var + outer_var2): 2692 return x + a + b + c 2693 2694 self.assertEqual( 2695 simple_fn(torch.ones(1)), 2696 torch.ones(1) + 0.5 + 10 + (20 + 30)) 2697 self.assertEqual( 2698 simple_fn(torch.ones(1), torch.tensor(1), torch.tensor(3), torch.tensor(4)), 2699 torch.ones(1) + 1 + 3 + 4) 2700 2701 outer_c = torch.tensor(9) 2702 outer_flag = torch.tensor(False) 2703 2704 @torch.jit.script 2705 def bool_fn(x, a=outer_c, flag=outer_flag): 2706 if bool(flag): 2707 result = x 2708 else: 2709 result = x + a 2710 return result 2711 2712 self.assertEqual(bool_fn(torch.ones(1)), torch.ones(1) + 9) 2713 self.assertEqual( 2714 bool_fn(torch.ones(1), torch.tensor(1), torch.tensor(True)), 2715 torch.ones(1)) 2716 2717 @torch.jit.script 2718 def none_fn(x=None): 2719 # type: 
(Optional[int]) -> Optional[int] 2720 return x 2721 2722 self.assertEqual(none_fn(), None) 2723 self.assertEqual(none_fn(1), 1) 2724 2725 @torch.jit.script 2726 def hints(x, a=0.5, b=10): 2727 # type: (Tensor, float, int) -> Tensor 2728 return x + a + b 2729 2730 self.assertEqual(hints(torch.ones(1)), torch.ones(1) + 0.5 + 10) 2731 2732 with self.assertRaisesRegex(RuntimeError, "Expected a default value"): 2733 2734 @torch.jit.script 2735 def hints_bad_types(x, a=10, b=0.5): # noqa: T484 2736 # type: (Tensor, float, int) -> Tensor 2737 return x + a + b 2738 with self.assertRaisesRegex(RuntimeError, "Expected a default value"): 2739 @torch.jit.script 2740 def bad_no_optional(x=None): 2741 # type: (Dict[str, int]) -> Dict[str, int] 2742 return x 2743 2744 2745 def test_module_default_values(self): 2746 four = torch.tensor(4) 2747 2748 class Test(torch.jit.ScriptModule): 2749 @torch.jit.script_method 2750 def forward(self, input, other=four): 2751 return input + other 2752 2753 t = Test() 2754 self.assertEqual(t(torch.ones(1)), torch.ones(1) + 4) 2755 2756 def test_mutable_default_values(self): 2757 with self.assertRaisesRegex(Exception, "Mutable default parameters"): 2758 @torch.jit.script 2759 def foo(x=(1, [])): 2760 # type: (Tuple[int, List[Tensor]]) 2761 return x 2762 2763 class Test(torch.nn.Module): 2764 def forward(self, input=[]): # noqa: B006 2765 return input 2766 2767 with self.assertRaisesRegex(Exception, "Mutable default parameters"): 2768 torch.jit.script(Test()) 2769 2770 @skipIfTorchDynamo("TorchDynamo fails with unknown reason") 2771 def test_warnings(self): 2772 import warnings 2773 2774 def fn(x): 2775 if bool(x < 2): 2776 warnings.warn("x is less than 2") 2777 return x 2778 2779 class M(torch.nn.Module): 2780 def forward(self, x): 2781 if bool(x < 2): 2782 warnings.warn("x is less than 2") 2783 return x 2784 2785 2786 scripted_mod = torch.jit.script(M()) 2787 scripted_fn = torch.jit.script(fn) 2788 2789 with warnings.catch_warnings(record=True) as 
warns: 2790 fn(torch.ones(1)) 2791 2792 with warnings.catch_warnings(record=True) as script_warns: 2793 scripted_fn(torch.ones(1)) 2794 2795 with warnings.catch_warnings(record=True) as script_mod_warns: 2796 scripted_mod(torch.ones(1)) 2797 2798 self.assertEqual(str(warns[0]), str(script_warns[0])) 2799 self.assertEqual(len(script_mod_warns), 1) 2800 self.assertEqual(str(warns[0].message), str(script_mod_warns[0].message)) 2801 2802 def test_no_erroneous_warnings(self): 2803 import warnings 2804 2805 def fn(x): 2806 if bool(x > 0): 2807 warnings.warn('This should NOT be printed') 2808 x += 1 2809 return x 2810 2811 with warnings.catch_warnings(record=True) as warns: 2812 fn_script = torch.jit.script(fn) 2813 fn_script(torch.tensor(0)) 2814 warns = [str(w.message) for w in warns] 2815 self.assertEqual(len(warns), 0) 2816 2817 @unittest.skipIf(True, "TODO: re-enable with https://github.com/pytorch/pytorch/pull/29339") 2818 def test_torch_load_error(self): 2819 class J(torch.jit.ScriptModule): 2820 @torch.jit.script_method 2821 def forward(self, input): 2822 return input + 100 2823 2824 j = J() 2825 with TemporaryFileName() as fname: 2826 j.save(fname) 2827 with self.assertRaisesRegex(RuntimeError, "is a zip"): 2828 torch.load(fname) 2829 2830 def test_torch_load_zipfile_check(self): 2831 @torch.jit.script 2832 def fn(x): 2833 return x + 10 2834 2835 with TemporaryFileName() as fname: 2836 fn.save(fname) 2837 with open(fname, 'rb') as f: 2838 self.assertTrue(torch.serialization._is_zipfile(f)) 2839 2840 def test_python_bindings(self): 2841 lstm_cell = torch.jit.script(LSTMCellS) 2842 2843 def lstm(x, hx, cx, w_ih, w_hh, b_ih, b_hh): 2844 for i in range(x.size(0)): 2845 hx, cx = lstm_cell(x[i], hx, cx, w_ih, w_hh, b_ih, b_hh) 2846 return hx 2847 2848 slstm = torch.jit.script(lstm) 2849 2850 inputs = get_lstm_inputs('cpu', training=True, seq_length=10) 2851 slstm(*inputs).sum().backward() 2852 global fw_graph 2853 fw_graph = slstm.graph_for(*inputs) 2854 nodes = 
list(fw_graph.nodes()) 2855 tested_blocks = False 2856 for node in nodes: 2857 for output in node.outputs(): 2858 self.assertTrue(hasattr(output, 'type')) 2859 self.assertTrue(output.type() is not None) 2860 for input in node.inputs(): 2861 self.assertTrue(hasattr(input, 'type')) 2862 self.assertTrue(input.type() is not None) 2863 for block in node.blocks(): 2864 tested_blocks = True 2865 self.assertTrue(hasattr(block, 'inputs')) 2866 self.assertTrue(hasattr(block, 'outputs')) 2867 for output in block.outputs(): 2868 self.assertTrue(hasattr(output, 'type')) 2869 self.assertTrue(output.type() is not None) 2870 for input in block.inputs(): 2871 self.assertTrue(hasattr(input, 'type')) 2872 self.assertTrue(input.type() is not None) 2873 self.assertTrue(hasattr(block, 'returnNode')) 2874 self.assertTrue(type(block.returnNode()) == torch._C.Node) 2875 self.assertTrue(hasattr(block, 'paramNode')) 2876 self.assertTrue(type(block.paramNode()) == torch._C.Node) 2877 self.assertTrue(tested_blocks) 2878 2879 def test_export_opnames(self): 2880 class Foo(torch.jit.ScriptModule): 2881 def one(self, x, y): 2882 # type: (Tensor, Tensor) -> Tensor 2883 return x + y 2884 2885 def two(self, x): 2886 # type: (Tensor) -> Tensor 2887 return 2 * x 2888 2889 @torch.jit.script_method 2890 def forward(self, x): 2891 # type: (Tensor) -> Tensor 2892 return self.one(self.two(x), x) 2893 2894 class Bar(torch.jit.ScriptModule): 2895 def __init__(self) -> None: 2896 super().__init__() 2897 self.sub = Foo() 2898 2899 @torch.jit.script_method 2900 def forward(self, x): 2901 # type: (Tensor) -> Tensor 2902 return self.sub.forward(x) 2903 2904 bar = Bar() 2905 ops = torch.jit.export_opnames(bar) 2906 expected = ['aten::add.Tensor', 'aten::mul.Scalar'] 2907 self.assertTrue(set(expected).issubset(set(ops))) 2908 2909 def test_pytorch_jit_env_off(self): 2910 import subprocess 2911 env = os.environ.copy() 2912 env['PYTORCH_JIT'] = '0' 2913 try: 2914 subprocess.check_output([sys.executable, '-c', 'import 
torch'], env=env) 2915 except subprocess.CalledProcessError as e: 2916 raise RuntimeError("Could not 'import torch' with PYTORCH_JIT=0") from e 2917 2918 def test_print_op_module(self): 2919 # Issue #19351: python2 and python3 go through different paths. 2920 # python2 returns '<module 'torch.ops' (built-in)>' 2921 # python3 uses __file__ and return 2922 # '<module 'torch.ops' from '/scratch/ailzhang/pytorch/torch/_ops.py'>' 2923 s = str(torch.ops) 2924 self.assertRegex(s, r'ops') 2925 2926 def test_print_classes_module(self): 2927 s = str(torch.classes) 2928 self.assertRegex(s, r'classes') 2929 2930 def test_print_torch_ops_modules(self): 2931 s = str(torch._ops.ops.quantized) 2932 self.assertRegex(s, r'torch.ops') 2933 s = str(torch._ops.ops.atan) 2934 self.assertRegex(s, r'torch.ops') 2935 2936 def test_hide_source_ranges_context_manager(self): 2937 @torch.jit.script 2938 def foo(x): 2939 return torch.add(x, x) 2940 2941 graph = foo.graph 2942 source_range_regex = "# .*\\.py" 2943 self.assertRegex(graph.__repr__(), source_range_regex) 2944 with torch.jit._hide_source_ranges(): 2945 self.assertNotRegex(graph.__repr__(), source_range_regex) 2946 self.assertRegex(graph.str(print_source_ranges=True), source_range_regex) 2947 self.assertRegex(graph.__repr__(), source_range_regex) 2948 2949 2950class TestFrontend(JitTestCase): 2951 2952 def test_instancing_error(self): 2953 @torch.jit.ignore 2954 class MyScriptClass: 2955 def unscriptable(self): 2956 return "a" + 200 2957 2958 2959 class TestModule(torch.nn.Module): 2960 def forward(self, x): 2961 return MyScriptClass() 2962 2963 with self.assertRaises(torch.jit.frontend.FrontendError) as cm: 2964 torch.jit.script(TestModule()) 2965 2966 checker = FileCheck() 2967 checker.check("Cannot instantiate class") 2968 checker.check("def forward") 2969 checker.run(str(cm.exception)) 2970 2971 def test_dictionary_as_example_inputs_for_jit_trace(self): 2972 class TestModule_v1(torch.nn.Module): 2973 def forward(self, key2=None, 
key3=None, key4=None, key5=None, key1=None, key6=None): 2974 return key1 + key2 + key3 2975 2976 class TestModule_v2(torch.nn.Module): 2977 def forward(self, x, y): 2978 return x + y 2979 2980 def test_func(x, y): 2981 return x + y 2982 model_1 = TestModule_v1() 2983 model_2 = TestModule_v2() 2984 value1 = torch.ones(1) 2985 value2 = torch.ones(1) 2986 value3 = torch.ones(1) 2987 example_input_dict = {'key1': value1, 'key2': value2, 'key3': value3} 2988 example_input_dict_func = {'x': value1, 'y': value2} 2989 traced_model_1 = torch.jit.trace(model_1, example_kwarg_inputs=example_input_dict, strict=False) 2990 traced_model_1_m = torch.jit.trace_module( 2991 model_1, {'forward': example_input_dict}, example_inputs_is_kwarg=True, strict=False) 2992 traced_model_2 = torch.jit.trace(model_2, example_kwarg_inputs={'x': torch.rand([2]), 'y': torch.rand([2])}) 2993 traced_func = torch.jit.trace(test_func, example_kwarg_inputs=example_input_dict_func, strict=False) 2994 res_1 = traced_model_1(**example_input_dict) 2995 res_1_m = traced_model_1_m(**example_input_dict) 2996 self.assertEqual(res_1, 3 * torch.ones(1)) 2997 self.assertEqual(res_1_m, 3 * torch.ones(1)) 2998 res_func = traced_func(**example_input_dict_func) 2999 self.assertEqual(res_func, 2 * torch.ones(1)) 3000 with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'x'."): 3001 res_2 = traced_model_2(**{'z': torch.rand([2]), 'y': torch.rand([2])}) # noqa: PIE804 3002 with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'y'."): 3003 res_2 = traced_model_2(**{'x': torch.rand([2]), 'z': torch.rand([2])}) # noqa: PIE804 3004 3005 3006class TestScript(JitTestCase): 3007 3008 # Tests that calling torch.jit.script repeated on function is allowed. 
3009 def test_repeated_script_on_function(self): 3010 @torch.jit.script 3011 @torch.jit.script 3012 def fn(x): 3013 return x 3014 3015 torch.jit.script(torch.jit.script(fn)) 3016 3017 def test_pretty_print_function(self): 3018 @torch.jit.script 3019 def foo(x): 3020 return torch.nn.functional.interpolate(x) 3021 3022 FileCheck().check("interpolate").run(foo.code) 3023 3024 def test_inlined_graph(self): 3025 """ 3026 Check that the `inlined_graph` property correctly returns an inlined 3027 graph, both through function calls and method calls. 3028 """ 3029 @torch.jit.script 3030 def foo(x): 3031 return torch.add(x, x) 3032 3033 class MyNestedMod(torch.nn.Module): 3034 def forward(self, x): 3035 return torch.sub(x, x) 3036 3037 3038 class MyMod(torch.nn.Module): 3039 def __init__(self) -> None: 3040 super().__init__() 3041 self.nested = MyNestedMod() 3042 3043 def forward(self, x): 3044 x = self.nested(x) # sub 3045 x = foo(x) # add 3046 return torch.mul(x, x) 3047 3048 m = torch.jit.script(MyMod()) 3049 FileCheck().check("aten::sub") \ 3050 .check("aten::add") \ 3051 .check("aten::mul") \ 3052 .run(m.inlined_graph) 3053 3054 def test_static_method_on_module(self): 3055 """ 3056 Check that the `@staticmethod` annotation on a function on a module works. 3057 """ 3058 class MyCell(torch.nn.Module): 3059 @staticmethod 3060 def do_it(x, h): 3061 new_h = torch.tanh(x + h) 3062 return new_h, new_h 3063 3064 def forward(self, x, h): 3065 return self.do_it(x, h) 3066 3067 my_cell = torch.jit.script(MyCell()) 3068 x = torch.rand(3, 4) 3069 h = torch.rand(3, 4) 3070 jitted_cell = my_cell(x, h) 3071 non_jitted_cell = MyCell().do_it(x, h) 3072 3073 self.assertEqual(jitted_cell, non_jitted_cell) 3074 3075 def test_code_with_constants(self): 3076 """ 3077 Check that the `code_with_constants` property correctly returns graph CONSTANTS in the 3078 CONSTANTS.cN format used in the output of the `code` property. 
3079 """ 3080 @torch.jit.script 3081 def foo(x=torch.ones(1)): 3082 return x 3083 3084 class Moddy(torch.nn.Module): 3085 def forward(self, x): 3086 return foo() 3087 3088 m = torch.jit.script(Moddy()) 3089 src, CONSTANTS = m.code_with_constants 3090 3091 self.assertEqual(CONSTANTS.c0, torch.ones(1)) 3092 self.assertEqual(src, m.code) 3093 3094 def test_code_with_constants_restore(self): 3095 """ 3096 Check that the `code_with_constants` property correctly works on restoration after save() + load() 3097 """ 3098 @torch.jit.script 3099 def foo(x=torch.ones(1)): 3100 return x 3101 3102 class Moddy(torch.nn.Module): 3103 def forward(self, x): 3104 return foo() 3105 3106 m = torch.jit.script(Moddy()) 3107 src, CONSTANTS = m.code_with_constants 3108 eic = self.getExportImportCopy(m) 3109 3110 src_eic, CONSTANTS_eic = eic.code_with_constants 3111 3112 self.assertEqual(src, src_eic) 3113 self.assertEqual(CONSTANTS.c0, CONSTANTS_eic.c0) 3114 3115 3116 def test_oneline_func(self): 3117 def fn(x): return x # noqa: E704 3118 3119 self.checkScript(fn, (torch.ones(2, 2), )) 3120 3121 def test_request_bailout(self): 3122 with enable_profiling_mode_for_profiling_tests(): 3123 3124 def fct_loop(x): 3125 for i in range(3): 3126 x = torch.cat((x, x), 0) 3127 return x 3128 3129 x = torch.ones(2, 3, 4, dtype=torch.float32) 3130 expected = fct_loop(x) 3131 jitted = torch.jit.script(fct_loop) 3132 # profile 3133 jitted(x) 3134 # optimize 3135 jitted(x) 3136 dstate = jitted.get_debug_state() 3137 eplan = get_execution_plan(dstate) 3138 num_bailouts = eplan.code.num_bailouts() 3139 3140 for i in range(0, num_bailouts): 3141 eplan.code.request_bailout(i) 3142 self.assertEqual(jitted(x), expected) 3143 3144 @unittest.skip("bailouts are being deprecated") 3145 def test_dominated_bailout(self): 3146 with enable_profiling_mode_for_profiling_tests(): 3147 # functional dominated guard 3148 @torch.jit.script 3149 def foo(x): 3150 dim = x.dim() 3151 if dim == 0: 3152 y = int(x) 3153 else: 3154 y = 
x.size()[dim - 1] 3155 return y 3156 3157 x = torch.zeros(2) 3158 self.assertEqual(foo(x), 2) 3159 self.assertEqual(foo(x), 2) 3160 g = torch.jit.last_executed_optimized_graph() 3161 g_s = str(g) 3162 g_s = g_s[0:g_s.find("return")] 3163 FileCheck().check_count("prim::BailOut[", 1, exactly=True).run(g_s) 3164 3165 # dominated guard of non-functional value 3166 @torch.jit.script 3167 def foo(x): 3168 dim = x.dim() 3169 x.add_(3) 3170 if dim == 0: 3171 return 0 3172 else: 3173 return x.size()[dim - 1] 3174 3175 x = torch.zeros(2) 3176 self.assertEqual(foo(x), 2) 3177 self.assertEqual(foo(x), 2) 3178 g = torch.jit.last_executed_optimized_graph() 3179 FileCheck().check("prim::BailOut[").check("aten::add_").check_next("prim::BailOut[").check("return").run(g) 3180 3181 with torch.enable_grad(): 3182 @torch.jit.ignore 3183 def disable_grad(): 3184 torch.set_grad_enabled(False) 3185 3186 @torch.jit.ignore 3187 def enable_grad(): 3188 torch.set_grad_enabled(True) 3189 3190 @torch.jit.script 3191 def foo(x): 3192 x = x + 1 3193 dim = x.dim() 3194 disable_grad() 3195 if dim == 0: 3196 y = int(x) 3197 else: 3198 y = x.size()[dim - 1] 3199 enable_grad() 3200 return y 3201 3202 x = torch.zeros(2, requires_grad=True) 3203 self.assertEqual(foo(x), 2) 3204 self.assertEqual(foo(x), 2) 3205 g = torch.jit.last_executed_optimized_graph() 3206 # there should still be a Bailout after disable_grad call 3207 FileCheck().check("disable_grad").check("BailOut[").check("BailoutTemplate").run(g) 3208 3209 @skipIfTorchDynamo("Torchdynamo cannot correctly handle profiler.profile calls") 3210 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "skip if profiling isn't enabled") 3211 def test_profiling_merge(self): 3212 @torch.jit.script 3213 def test_not_const(x): 3214 if x.size(0) == 1: 3215 return 1 3216 else: 3217 return 2 3218 3219 with enable_profiling_mode_for_profiling_tests(): 3220 with num_profiled_runs(2): 3221 test_not_const(torch.rand([1, 2])) 3222 test_not_const(torch.rand([2, 
2])) 3223 3224 graph_str = torch.jit.last_executed_optimized_graph() 3225 FileCheck().check("profiled_type=Float(*, 2, strides=[2, 1], requires_grad=0, device=cpu").run(graph_str) 3226 FileCheck().check_not("profiled_type=Float(1, 2, strides=[2, 1], requires_grad=0, device=cpu").run(graph_str) 3227 3228 3229 def test_nested_bailouts(self): 3230 @torch.jit.script 3231 def fct_loop(x): 3232 for i in range(3): 3233 x = torch.cat((x, x), 0) 3234 return x 3235 3236 x = torch.ones(2, 3, 4, dtype=torch.float32) 3237 out = fct_loop(x) 3238 jit_trace = torch.jit.trace(fct_loop, x) 3239 out_trace = jit_trace(x) 3240 3241 def test_no_self_arg_ignore_function(self): 3242 class MyModule(nn.Module): 3243 @torch.jit.ignore # noqa: B902 3244 def call_np(): # noqa: B902 3245 # type: () -> int 3246 return np.random.choice(2, p=[.95, .05]) 3247 3248 def forward(self): 3249 return self.call_np() 3250 3251 with self.assertRaisesRegex(Exception, "does not have a self argument"): 3252 torch.jit.script(MyModule()) 3253 3254 def test_loop_liveness(self): 3255 with enable_profiling_mode_for_profiling_tests(): 3256 @torch.jit.script 3257 def f(i): 3258 # type: (int) -> Tensor 3259 l = [] 3260 for n in [2, 1]: 3261 l.append(torch.zeros(n, i)) 3262 3263 return l[0] 3264 3265 f(2) 3266 f(1) 3267 3268 def test_bailout_loop_carried_deps_name_clash(self): 3269 with enable_profiling_mode_for_profiling_tests(): 3270 NUM_ITERATIONS = 10 3271 3272 @torch.jit.script 3273 def fct_loop(z, size): 3274 # type: (int, int) -> Tuple[Tensor, List[int]] 3275 counters = torch.jit.annotate(List[int], []) 3276 j = 0 3277 y = torch.ones(2) 3278 for i in range(size): 3279 counters.append(i + j) 3280 y = torch.cat((y, torch.ones(z)), 0) 3281 j = j + 1 3282 return y, counters 3283 3284 inputs = [1, 2, 3, 4] 3285 expected = [x * 2 for x in range(NUM_ITERATIONS)] 3286 for inp in inputs: 3287 results = fct_loop(inp, NUM_ITERATIONS) 3288 self.assertEqual(results[1], expected) 3289 3290 def 
test_bailout_loop_counter_transition(self): 3291 with enable_profiling_mode_for_profiling_tests(): 3292 NUM_ITERATIONS = 10 3293 3294 @torch.jit.script 3295 def fct_loop(z, size): 3296 # type: (int, int) -> Tuple[Tensor, List[int]] 3297 counters = torch.jit.annotate(List[int], []) 3298 y = torch.ones(2) 3299 for i in range(size): 3300 counters.append(i) 3301 y = torch.cat((y, torch.ones(z)), 0) 3302 return y, counters 3303 3304 inputs = [1, 2, 3, 4] 3305 expected = list(range(NUM_ITERATIONS)) 3306 for inp in inputs: 3307 results = fct_loop(inp, NUM_ITERATIONS) 3308 self.assertEqual(results[1], expected) 3309 3310 def test_ignored_method_binding(self): 3311 class Bar(torch.nn.Module): 3312 def __init__(self) -> None: 3313 super().__init__() 3314 self.x : int = 0 3315 3316 @torch.jit.export 3317 def setx(self, x : int): 3318 self.x = x 3319 3320 @torch.jit.export 3321 def getx(self): 3322 return self.x 3323 3324 @torch.jit.ignore 3325 def ignored_getx(self): 3326 return self.x 3327 3328 b = Bar() 3329 b.setx(123) 3330 sb = torch.jit.script(b) 3331 self.assertEqual(sb.getx(), 123) 3332 self.assertEqual(sb.ignored_getx(), 123) 3333 3334 sb.setx(456) 3335 self.assertEqual(sb.getx(), 456) 3336 self.assertEqual(sb.ignored_getx(), 456) 3337 3338 def test_set_attribute_through_optional(self): 3339 class A(torch.nn.Module): 3340 __annotations__ = {"x": Optional[torch.Tensor]} 3341 3342 def __init__(self) -> None: 3343 super().__init__() 3344 self.x = None 3345 3346 @torch.jit.ignore 3347 def foo(self): 3348 if self.x is None: 3349 self.x = torch.tensor([3]) 3350 return self.x 3351 3352 def forward(self, x): 3353 a = self.foo() 3354 return x + 1 3355 3356 m = torch.jit.script(A()) 3357 self.assertEqual(m.x, None) 3358 m(torch.rand(1)) 3359 self.assertEqual(m.x, torch.tensor([3])) 3360 3361 def test_mutate_constant(self): 3362 class M(torch.jit.ScriptModule): 3363 __constants__ = ["foo"] 3364 3365 def __init__(self, foo): 3366 super().__init__() 3367 self.foo = foo 3368 3369 m 
= M(5) 3370 # m has a constant attribute, but we can't 3371 # assign to it 3372 with self.assertRaises(RuntimeError): 3373 m.foo = 6 3374 3375 def test_class_attribute(self): 3376 class M(torch.jit.ScriptModule): 3377 FOO = 0 3378 3379 def __init__(self) -> None: 3380 super().__init__() 3381 self.foo = self.FOO 3382 m = M() 3383 self.assertEqual(m.foo, M.FOO) 3384 3385 def test_class_attribute_in_script(self): 3386 class M(torch.jit.ScriptModule): 3387 FOO = 0 3388 3389 @torch.jit.script_method 3390 def forward(self): 3391 return self.FOO 3392 with self.assertRaises(RuntimeError): 3393 M() 3394 3395 def test_not_initialized_err(self): 3396 class M(torch.jit.ScriptModule): 3397 def __init__(self) -> None: 3398 self.foo = torch.rand(2, 3) 3399 with self.assertRaises(RuntimeError): 3400 M() 3401 3402 def test_attribute_in_init(self): 3403 class M(torch.jit.ScriptModule): 3404 def __init__(self) -> None: 3405 super().__init__() 3406 self.foo = torch.jit.Attribute(0.1, float) 3407 # we should be able to use self.foo as a float here 3408 assert 0.0 < self.foo 3409 M() 3410 3411 def test_scriptable_fn_as_attr(self): 3412 class M(torch.nn.Module): 3413 def __init__(self, fn): 3414 super().__init__() 3415 self.fn = fn 3416 3417 def forward(self, x): 3418 return self.fn(x) 3419 3420 m = M(torch.sigmoid) 3421 inp = torch.rand(2, 3) 3422 self.checkModule(m, (inp, )) 3423 3424 def test_sequence_parsing(self): 3425 tests = [ 3426 ("return [x, x,]", True), 3427 ("return [x x]", "expected ]"), 3428 ("return x, x,", True), 3429 ("return bar(x, x,)", True), 3430 ("return bar()", "Argument x not provided"), 3431 ("for a, b, in x, x,:\n pass", "List of iterables"), 3432 ("a, b, = x, x,\n return a + b", True) 3433 ] 3434 for exp, result in tests: 3435 cu = torch.jit.CompilationUnit() 3436 full = f""" 3437def bar(x, y): 3438 return x + y 3439def foo(x): 3440 {exp} 3441 """ 3442 if isinstance(result, str): 3443 with self.assertRaisesRegex(RuntimeError, result): 3444 cu.define(full) 3445 
else: 3446 cu.define(full) 3447 3448 def test_namedtuple_python(self): 3449 global MyTuple, MyMod # see [local resolution in python] 3450 MyTuple = namedtuple('MyTuple', ['a']) 3451 3452 @torch.jit.unused 3453 def fn(): 3454 # type: () -> MyTuple 3455 return MyTuple(1) 3456 3457 # Only check compilation 3458 @torch.jit.script 3459 def fn2(): 3460 # type: () -> MyTuple 3461 return fn() 3462 3463 FileCheck().check("NamedTuple").run(fn2.graph) 3464 3465 class MyMod(torch.nn.Module): 3466 @torch.jit.unused 3467 def fn(self): 3468 # type: () -> MyTuple 3469 return MyTuple(1) 3470 3471 def forward(self, x): 3472 if 1 == 1: 3473 return MyTuple(torch.rand(2, 3)) 3474 else: 3475 return self.fn() 3476 3477 # shouldn't throw a type error 3478 torch.jit.script(MyMod()) 3479 3480 def test_unused_decorator(self): 3481 class MyMod(torch.nn.Module): 3482 @torch.jit.unused 3483 @torch.no_grad() 3484 def fn(self, x): 3485 # type: (Tensor) -> int 3486 return next(x) # invalid, but should be ignored 3487 3488 def forward(self, x): 3489 return self.fn(x) 3490 3491 torch.jit.script(MyMod()) 3492 3493 @_inline_everything 3494 def test_lazy_script(self): 3495 def untraceable(x): 3496 if x.ndim > 2: 3497 print("hello") 3498 else: 3499 print("goodbye") 3500 return x + 2 3501 3502 # Non-working example 3503 def fn(x): 3504 return untraceable(x) 3505 3506 with self.capture_stdout(): 3507 traced_bad = torch.jit.trace(fn, [torch.ones(2, 2)]) 3508 3509 FileCheck().check_not("goodbye").check_not("hello").run(traced_bad.graph) 3510 3511 # Working example 3512 untraceable = torch.jit.script_if_tracing(untraceable) 3513 3514 def fn2(x): 3515 return untraceable(x) 3516 3517 with self.capture_stdout(): 3518 traced = torch.jit.trace(fn, [torch.ones(2, 2)]) 3519 3520 FileCheck().check("goodbye").run(traced.graph) 3521 3522 def foo(x: int): 3523 return x + 1 3524 3525 @torch.jit.script_if_tracing 3526 def fee(x: int = 2): 3527 return foo(1) + x 3528 3529 # test directly compiling function 3530 
fee_compiled = torch.jit.script(fee) 3531 self.assertEqual(fee_compiled(), fee()) 3532 3533 # test compiling it within another function 3534 @torch.jit.script 3535 def hum(): 3536 return fee(x=3) 3537 3538 self.assertEqual(hum(), 5) 3539 3540 def test_big_int_literals(self): 3541 def ok(): 3542 # signed 64 bit max 3543 a = 9223372036854775807 3544 return a 3545 3546 def toobig(): 3547 a = 9223372036854775808 3548 return a 3549 3550 def waytoobig(): 3551 a = 99999999999999999999 3552 return a 3553 3554 self.checkScript(ok, []) 3555 3556 with self.assertRaisesRegex(RuntimeError, "out of range"): 3557 torch.jit.script(toobig) 3558 3559 with self.assertRaisesRegex(RuntimeError, "out of range"): 3560 torch.jit.script(waytoobig) 3561 3562 def test_hex_literals(self): 3563 def test1(): 3564 return 0xaaaaaa 3565 3566 def test2(): 3567 return 0xaaaaaa 3568 3569 def test3(): 3570 return -0xaaaaaa 3571 3572 self.checkScript(test1, []) 3573 self.checkScript(test2, []) 3574 self.checkScript(test3, []) 3575 3576 def ok(): 3577 a = 0x7FFFFFFFFFFFFFFF 3578 return a 3579 3580 def toobig(): 3581 a = 0xFFFFFFFFFFFFFFFF 3582 return a 3583 3584 def waytoobig(): 3585 a = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF 3586 return a 3587 3588 self.checkScript(ok, []) 3589 3590 with self.assertRaisesRegex(RuntimeError, "out of range"): 3591 torch.jit.script(toobig) 3592 3593 with self.assertRaisesRegex(RuntimeError, "out of range"): 3594 torch.jit.script(waytoobig) 3595 3596 def test_big_float_literals(self): 3597 def ok(): 3598 # Python interprets this as inf 3599 a = 1.2E400 3600 return a 3601 3602 def check(fn): 3603 self.assertTrue(fn() == ok()) 3604 3605 # checkScript doesn't work since assertEqual doesn't consider 3606 # `inf` == `inf` 3607 check(torch.jit.script(ok)) 3608 3609 cu = torch.jit.CompilationUnit() 3610 cu.define(dedent(inspect.getsource(ok))) 3611 check(cu.ok) 3612 3613 def _test_device_type(self, dest): 3614 def fn(x): 3615 # type: (Device) -> Tuple[str, Optional[int]] 3616 return 
x.type, x.index 3617 3618 device = torch.ones(2).to(dest).device 3619 self.checkScript(fn, [device]) 3620 3621 def test_device_type(self): 3622 self._test_device_type('cpu') 3623 3624 @unittest.skipIf(not RUN_CUDA, "Requires CUDA") 3625 def test_device_type_cuda(self): 3626 self._test_device_type('cuda') 3627 3628 def test_string_device_implicit_conversion(self): 3629 @torch.jit.script 3630 def fn(x: torch.device): 3631 return x 3632 3633 self.assertEqual(fn("cpu"), torch.device("cpu")) 3634 3635 with self.assertRaisesRegex(RuntimeError, "Expected one of"): 3636 fn("invalid_device") 3637 3638 def test_eval_python(self): 3639 def _test(m): 3640 self.assertTrue(m(torch.ones(2, 2))) 3641 self.assertTrue(m.training) 3642 self.assertTrue(m._c.getattr('training')) 3643 3644 m.eval() 3645 3646 self.assertFalse(m.training) 3647 self.assertFalse(m._c.getattr('training')) 3648 self.assertFalse(m(torch.ones(2, 2))) 3649 3650 buffer = io.BytesIO() 3651 torch.jit.save(m, buffer) 3652 buffer.seek(0) 3653 3654 loaded = torch.jit.load(buffer) 3655 3656 self.assertFalse(loaded.training) 3657 self.assertFalse(loaded._c.getattr('training')) 3658 3659 class M(nn.Module): 3660 def forward(self, x): 3661 return self.training 3662 3663 class OldM(torch.jit.ScriptModule): 3664 @torch.jit.script_method 3665 def forward(self, x): 3666 return self.training 3667 3668 _test(torch.jit.script(M())) 3669 _test(OldM()) 3670 3671 def test_inherit_method(self): 3672 class A(torch.jit.ScriptModule): 3673 @torch.jit.script_method 3674 def forward(self, x): 3675 return x + self.bar(x) 3676 3677 class B(A): 3678 @torch.jit.script_method 3679 def bar(self, x): 3680 return x * x 3681 3682 with self.assertRaisesRegex(RuntimeError, 'attribute'): 3683 A() # cannot use because bar is not defined 3684 3685 v = torch.rand(3, 4) 3686 b = B() 3687 self.assertEqual(b(v), v + v * v) 3688 3689 class C(torch.jit.ScriptModule): 3690 @torch.jit.script_method 3691 def bar(self, x): 3692 return x 3693 3694 class D(C, B): 
3695 def __init__(self) -> None: 3696 super().__init__() 3697 3698 self.assertEqual(D()(v), v + v) 3699 3700 def test_tensor_subclasses(self): 3701 def check_subclass(x, tensor): 3702 template = dedent(""" 3703 def func(input: {}) -> {}: 3704 return torch.zeros((input.shape[0], 1), dtype=input.dtype) 3705 """) 3706 3707 self._check_code(template.format(x, x), "func", [tensor]) 3708 3709 check_subclass("torch.LongTensor", torch.LongTensor([[1, 2], [3, 4]])) 3710 check_subclass("torch.DoubleTensor", torch.DoubleTensor([[1.2, 2.3], [3.4, 4.5]])) 3711 check_subclass("torch.IntTensor", torch.IntTensor([[1, 2], [3, 4]])) 3712 check_subclass("torch.BoolTensor", torch.BoolTensor([[False, True], [True, False]])) 3713 3714 def check_subclass_warn(input: torch.LongTensor) -> torch.LongTensor: 3715 return torch.zeros((input.shape[0], 1), dtype=input.dtype) 3716 3717 with warnings.catch_warnings(record=True) as warns: 3718 scripted = torch.jit.script(check_subclass_warn) 3719 FileCheck().check("TorchScript will treat type annotations of Tensor").run(str(warns[0])) 3720 3721 def test_first_class_module(self): 3722 class Foo(torch.jit.ScriptModule): 3723 def __init__(self) -> None: 3724 super().__init__() 3725 self.foo = nn.Parameter(torch.rand(3, 4)) 3726 3727 @torch.jit.script_method 3728 def forward(self, input): 3729 self.foo = input 3730 return self.foo 3731 foo = Foo() 3732 input = torch.rand(3, 4) 3733 foo.forward(input) 3734 self.assertEqual(input, foo.foo) 3735 3736 @_tmp_donotuse_dont_inline_everything 3737 def test_first_class_calls(self): 3738 @torch.jit.script 3739 class Foo: 3740 def __init__(self, x): 3741 self.bar = x 3742 3743 def stuff(self, x): 3744 return self.bar + x 3745 3746 @torch.jit.script 3747 def foo(x): 3748 return x * x + Foo(x).stuff(2 * x) 3749 3750 @torch.jit.script 3751 def bar(x): 3752 return foo(x) * foo(x) 3753 3754 x = torch.rand(3, 4) 3755 self.assertEqual(bar(x), (x * x + 3 * x) * (x * x + 3 * x)) 3756 3757 def test_static_methods(self): 
3758 class M(nn.Module): 3759 @staticmethod 3760 def my_method(x): 3761 return x + 100 3762 3763 def forward(self, x): 3764 return x + M.my_method(x) 3765 3766 class N(nn.Module): 3767 @staticmethod 3768 def my_method(x): 3769 return x * 100 3770 3771 def forward(self, x): 3772 return x - M.my_method(x) + N.my_method(x) 3773 3774 self.checkModule(M(), (torch.ones(2, 2),)) 3775 3776 self.checkModule(N(), (torch.ones(2, 2),)) 3777 3778 def test_invalid_prefix_annotation(self): 3779 with self.assertRaisesRegex(RuntimeError, "annotation prefix in line"): 3780 with self.capture_stdout() as captured: 3781 @torch.jit.script 3782 def invalid_prefix_annotation1(a): 3783 #type: (Int) -> Int # noqa: E265 3784 return a + 2 3785 3786 with self.assertRaisesRegex(RuntimeError, "annotation prefix in line"): 3787 with self.capture_stdout() as captured: 3788 @torch.jit.script 3789 def invalid_prefix_annotation2(a): 3790 #type : (Int) -> Int # noqa: E265 3791 return a + 2 3792 3793 with self.assertRaisesRegex(RuntimeError, "annotation prefix in line"): 3794 with self.capture_stdout() as captured: 3795 @torch.jit.script 3796 def invalid_prefix_annotation3(a): 3797 # type: (Int) -> Int 3798 return a + 2 3799 3800 def test_builtin_function_attributes(self): 3801 class Add(nn.Module): 3802 def __init__(self) -> None: 3803 super().__init__() 3804 self.add = torch.add 3805 3806 def forward(self, input): 3807 return self.add(input, input) 3808 3809 self.checkModule(Add(), [torch.randn(2, 2)]) 3810 3811 def test_pybind_type_comparisons(self): 3812 @torch.jit.script 3813 def f(): 3814 return None 3815 3816 node = list(f.graph.nodes())[0] 3817 t = node.outputsAt(0).type() 3818 self.assertIsNotNone(t) 3819 3820 @unittest.skipIf(IS_WINDOWS, 'TODO: need to fix the test case') 3821 def test_unmatched_type_annotation(self): 3822 message1 = re.escape("Number of type annotations (2) did not match the number of function parameters (1):") 3823 message2 = 'def invalid2\\(a\\):\n\\s*~+\\.*\\s+<--- 
HERE\n\\s+# type: \\(Int, Int\\) -> Int\n\\s+return a \\+ 2' 3824 message3 = 'def invalid4\\(a\\):\n\\s*~+\\.*\\s+<--- HERE\n\\s+# type: \\(Int, Int\\) -> Int\n\\s+return a \\+ 2' 3825 with self.assertRaisesRegex(RuntimeError, message1): 3826 @torch.jit.script 3827 def invalid1(a): 3828 # type: (Int, Int) -> Int 3829 return a + 2 3830 3831 with self.assertRaisesRegex(RuntimeError, message2): 3832 @torch.jit.script 3833 def invalid2(a): 3834 # type: (Int, Int) -> Int 3835 return a + 2 3836 3837 with self.assertRaisesRegex(RuntimeError, message1): 3838 def invalid3(a): 3839 # type: (Int, Int) -> Int 3840 return a + 2 3841 torch.jit.script(invalid3) 3842 3843 with self.assertRaisesRegex(RuntimeError, message3): 3844 def invalid4(a): 3845 # type: (Int, Int) -> Int 3846 return a + 2 3847 torch.jit.script(invalid4) 3848 3849 def test_calls_in_type_annotations(self): 3850 with self.assertRaisesRegex(RuntimeError, "Type annotation should not contain calls"): 3851 def spooky(a): 3852 # type: print("Hello") -> Tensor # noqa: F723 3853 return a + 2 3854 print(torch.__file__) 3855 torch.jit.annotations.get_signature(spooky, None, 1, True) 3856 3857 def test_is_optional(self): 3858 ann = Union[List[int], List[float]] 3859 torch._jit_internal.is_optional(ann) 3860 3861 def test_interpreter_fuzz(self): 3862 import builtins 3863 # This test generates random tree-like programs to fuzz test 3864 # that the interpreter does not have a bug in its stack manipulation 3865 # code. An assert in that code ensures individual operators are 3866 # not reordered. 
3867 templates = [ 3868 "torch.rand(3, 4)", 3869 "({} + {})", 3870 "-{}", 3871 "({} * {})", 3872 "torch.tanh({})", 3873 "VAR {}", 3874 ] 3875 3876 def gen_code(): 3877 src_lines = ['def f():'] 3878 exprs = [] 3879 n_variables = 0 3880 3881 def get_expr(idx): 3882 elem = exprs[idx] 3883 exprs[idx] = exprs[-1] 3884 exprs.pop() 3885 return elem 3886 3887 def select_expr_or_var(): 3888 idx = random.randrange(0, len(exprs) + n_variables) 3889 if idx < len(exprs): 3890 return get_expr(idx) 3891 else: 3892 return f'v{idx - len(exprs)}' 3893 3894 for i in range(50): 3895 n = None 3896 while n is None or n > len(exprs) + n_variables: 3897 template = random.choice(templates) 3898 n = template.count('{}') 3899 3900 if 'VAR' in template: 3901 src_lines.append(f' v{n_variables} = {select_expr_or_var()}') 3902 n_variables += 1 3903 else: 3904 exprs.append(template.format(*(select_expr_or_var() for _ in range(n)))) 3905 3906 src_lines.append(' return ({})\n'.format(''.join(f'v{i},' for i in range(n_variables)))) 3907 return '\n'.join(src_lines) 3908 3909 for i in range(100): 3910 g = {'torch': torch} 3911 code = gen_code() 3912 builtins.exec(code, g, None) 3913 cu = torch.jit.CompilationUnit(code) 3914 with freeze_rng_state(): 3915 o1 = g['f']() 3916 with freeze_rng_state(): 3917 o2 = cu.f() 3918 self.assertEqual(o1, o2) 3919 3920 @skipIfTorchDynamo("TorchDynamo fails with unknown reason") 3921 def test_cpp_module_iterator(self): 3922 a = nn.Module() 3923 a.name = 'a' 3924 a.p = nn.Parameter(torch.rand(3, 4)) 3925 a.foo = nn.Module() 3926 a.foo.name = 'foo' 3927 a.foo.b = nn.Buffer(torch.rand(1, 1)) 3928 a.foo.bar = nn.Module() 3929 a.foo.bar.name = 'bar' 3930 a.foo.bar.an_int = 4 3931 a.another = nn.Module() 3932 a.another.name = 'another' 3933 sa = torch.jit.script(a) 3934 result = torch._C._jit_debug_module_iterators(sa._c) 3935 3936 def replace(e): 3937 if e is a.p: 3938 return 'P' 3939 elif e is a.foo.b: 3940 return 'B' 3941 elif isinstance(e, torch._C.ScriptModule): 3942 
return e.getattr('name') 3943 3944 return e 3945 for v in result.values(): 3946 for i in range(len(v)): 3947 if isinstance(v[i], tuple): 3948 n, v2 = v[i] 3949 v[i] = (n, replace(v2)) 3950 else: 3951 v[i] = replace(v[i]) 3952 # module type creation is not deterministic, so we have to sort 3953 # the result 3954 v.sort() 3955 expected = {'buffers': [], 3956 'buffers_r': ['B'], 3957 'children': ['another', 'foo'], 3958 'modules': ['a', 'another', 'bar', 'foo'], 3959 'named_attributes': [('_is_full_backward_hook', None), 3960 ('another', 'another'), 3961 ('foo', 'foo'), 3962 ('name', 'a'), 3963 ('p', 'P'), 3964 ('training', True)], 3965 'named_attributes_r': [('_is_full_backward_hook', None), 3966 ('another', 'another'), 3967 ('another._is_full_backward_hook', None), 3968 ('another.name', 'another'), 3969 ('another.training', True), 3970 ('foo', 'foo'), 3971 ('foo._is_full_backward_hook', None), 3972 ('foo.b', 'B'), 3973 ('foo.bar', 'bar'), 3974 ('foo.bar._is_full_backward_hook', None), 3975 ('foo.bar.an_int', 4), 3976 ('foo.bar.name', 'bar'), 3977 ('foo.bar.training', True), 3978 ('foo.name', 'foo'), 3979 ('foo.training', True), 3980 ('name', 'a'), 3981 ('p', 'P'), 3982 ('training', True)], 3983 'named_buffers': [], 3984 'named_buffers_r': [('foo.b', 'B')], 3985 'named_children': [('another', 'another'), ('foo', 'foo')], 3986 'named_modules': [('', 'a'), 3987 ('another', 'another'), 3988 ('foo', 'foo'), 3989 ('foo.bar', 'bar')], 3990 'named_parameters': [('p', 'P')], 3991 'named_parameters_r': [('p', 'P')], 3992 'parameters': ['P'], 3993 'parameters_r': ['P']} 3994 self.assertEqual(expected, result) 3995 3996 def test_parameter_order(self): 3997 m = nn.Module() 3998 for i, name in enumerate(string.ascii_letters): 3999 setattr(m, name, nn.Parameter(torch.tensor([float(i)]))) 4000 ms = torch.jit.script(m) 4001 print(torch.cat(list(m.parameters()))) 4002 print(torch.cat(list(ms.parameters()))) 4003 self.assertEqual(list(m.parameters()), list(ms.parameters())) 4004 4005 
def test_python_op_builtins(self): 4006 @torch.jit.unused 4007 def fn(x): 4008 # type: (List[int]) -> int 4009 return sum(x) 4010 4011 @torch.jit.script 4012 def script_fn(x): 4013 # type: (List[int]) -> int 4014 return fn(x) 4015 4016 def test_submodule_twice(self): 4017 @torch.jit.script 4018 def foo(x): 4019 return x * x 4020 4021 class What(torch.jit.ScriptModule): 4022 def __init__(self, x): 4023 super().__init__() 4024 self.foo = x 4025 a = What(foo) 4026 c = What(foo) 4027 4028 def test_training_param(self): 4029 class What(torch.jit.ScriptModule): 4030 @torch.jit.script_method 4031 def forward(self, x): 4032 # type: (int) -> int 4033 if self.training: 4034 r = x 4035 else: 4036 r = x + 4 4037 # check double use of training 4038 if self.training: 4039 r = r + 1 4040 return r 4041 4042 w = What() 4043 self.assertEqual(4, w(3)) 4044 w.train(False) 4045 self.assertEqual(7, w(3)) 4046 self.assertFalse("training" in w.state_dict()) 4047 4048 def test_class_as_attribute(self): 4049 @torch.jit.script 4050 class Foo321: 4051 def __init__(self) -> None: 4052 self.x = 3 4053 4054 class FooBar1234(torch.nn.Module): 4055 def __init__(self) -> None: 4056 super().__init__() 4057 self.f = Foo321() 4058 4059 def forward(self, x): 4060 return x + self.f.x 4061 4062 scripted = torch.jit.script(FooBar1234()) 4063 eic = self.getExportImportCopy(scripted) 4064 x = torch.rand(3, 4) 4065 self.assertEqual(scripted(x), eic(x)) 4066 4067 def test_module_str(self): 4068 class Foo(torch.nn.Module): 4069 def forward(self, x): 4070 return torch.relu(x) 4071 4072 f = torch.jit.script(Foo()) 4073 4074 str_f = str(f._c) 4075 self.assertTrue(str_f.startswith('ScriptObject')) 4076 self.assertTrue('__torch__.' 
in str_f) 4077 self.assertTrue('.Foo' in str_f) 4078 4079 def test_jitter_bug(self): 4080 @torch.jit.script 4081 def fn2(input, kernel_size): 4082 # type: (Tensor, List[int]) -> Tensor 4083 if kernel_size[0] > 1: 4084 _stride = [2] 4085 else: 4086 _stride = kernel_size 4087 print(_stride, kernel_size) 4088 return input 4089 4090 @torch.jit.script 4091 def fn(input): 4092 # type: (Tensor) -> Tensor 4093 return fn2(input, [1]) 4094 4095 def test_parser_kwargonly(self): 4096 cu = torch.jit.CompilationUnit(''' 4097 def foo(x, *, y) -> Tuple[Tensor, Tensor]: 4098 return x, x 4099 def bar(x): 4100 return foo(x, y=x) 4101 ''') 4102 self.assertTrue('*' in str(cu.foo.schema)) 4103 with self.assertRaisesRegex(RuntimeError, "not provided"): 4104 torch.jit.CompilationUnit(''' 4105 def foo(x, *, y) -> Tuple[Tensor, Tensor]: 4106 return x, x 4107 def bar(x): 4108 return foo(x, x) 4109 ''') 4110 4111 def test_annoying_doubles(self): 4112 mod = types.ModuleType("temp") 4113 mod.inf = float("inf") 4114 mod.ninf = float("-inf") 4115 mod.nan = float("nan") 4116 4117 with torch._jit_internal._disable_emit_hooks(): 4118 class Foo(torch.jit.ScriptModule): 4119 @torch.jit.script_method 4120 def forward(self): 4121 return math.pi, 0.1, mod.inf, mod.ninf, 2.225073858507201e-308, mod.nan 4122 4123 foo = Foo() 4124 buffer = io.BytesIO() 4125 torch.jit.save(foo, buffer) 4126 4127 buffer.seek(0) 4128 foo_loaded = torch.jit.load(buffer) 4129 4130 r = foo() 4131 r2 = foo_loaded() 4132 # use precise assert, we are checking floating point details 4133 self.assertTrue(r[:-1] == r2[:-1]) 4134 self.assertTrue(math.isnan(r[-1]) and math.isnan(r2[-1])) 4135 4136 def test_type_annotate(self): 4137 4138 def foo(a): 4139 return torch.jit.annotate(torch.Tensor, a) 4140 4141 self.checkScript(foo, (torch.rand(3),)) 4142 4143 def bar(): 4144 a = torch.jit.annotate(List[int], []) 4145 for _ in range(10): 4146 a.append(4) 4147 return a 4148 4149 self.checkScript(bar, ()) 4150 4151 def baz(a): 4152 return 
torch.jit.annotate(float, a) 4153 self.checkScript(baz, (torch.rand(()),)) 4154 4155 # test annotate none types 4156 def annotate_none(): 4157 return torch.jit.annotate(Optional[torch.Tensor], None) 4158 4159 self.checkScript(annotate_none, ()) 4160 4161 4162 def test_robust_op_resolution(self): 4163 neg = torch.add # misleading name to make sure we resolve by function 4164 4165 def stuff(x): 4166 return neg(x, x) 4167 4168 a = (torch.rand(3),) 4169 self.checkScript(stuff, a) 4170 4171 def test_nested_aug_assign(self): 4172 @torch.jit.script 4173 class SomeClass: 4174 def __init__(self) -> None: 4175 self.num = 99 4176 4177 def __iadd__(self, x): 4178 # type: (int) 4179 self.num += x 4180 return self 4181 4182 def __eq__(self, other): 4183 # type: (SomeClass) -> bool 4184 return self.num == other.num 4185 4186 @torch.jit.script 4187 class SomeOutOfPlaceClass: 4188 def __init__(self) -> None: 4189 self.num = 99 4190 4191 def __add__(self, x): 4192 # type: (int) 4193 self.num = x 4194 return self 4195 4196 def __eq__(self, other): 4197 # type: (SomeClass) -> bool 4198 return self.num == other.num 4199 4200 class Child(nn.Module): 4201 def __init__(self) -> None: 4202 super().__init__() 4203 self.x = 2 4204 self.o = SomeClass() 4205 self.oop = SomeOutOfPlaceClass() 4206 self.list = [1, 2, 3] 4207 4208 class A(nn.Module): 4209 def __init__(self) -> None: 4210 super().__init__() 4211 self.child = Child() 4212 4213 def forward(self): 4214 self.child.x += 1 4215 self.child.o += 5 4216 self.child.oop += 5 4217 some_list = [1, 2] 4218 self.child.list += some_list 4219 self.child.list *= 2 4220 return self.child.x, self.child.o, self.child.list, self.child.oop 4221 4222 a = A() 4223 sa = torch.jit.script(A()) 4224 eager_result = a() 4225 script_result = sa() 4226 self.assertEqual(eager_result, script_result) 4227 self.assertEqual(a.child.x, sa.child.x) 4228 self.assertEqual(a.child.o, sa.child.o) 4229 self.assertEqual(a.child.list, sa.child.list) 4230 4231 @torch.jit.script 
4232 class SomeNonAddableClass: 4233 def __init__(self) -> None: 4234 self.num = 99 4235 4236 def __eq__(self, other): 4237 # type: (SomeClass) -> bool 4238 return self.num == other.num 4239 4240 # with self.assertRaisesRegex(RuntimeError, "") 4241 class A(nn.Module): 4242 def __init__(self) -> None: 4243 super().__init__() 4244 self.x = SomeNonAddableClass() 4245 4246 def forward(self): 4247 self.x += SomeNonAddableClass() 4248 return self.x 4249 4250 with self.assertRaisesRegex(RuntimeError, "Cannot emit inplace op"): 4251 torch.jit.script(A()) 4252 4253 def test_var_aug_assign(self): 4254 @torch.jit.script 4255 class SomeNonAddableClass: 4256 def __init__(self) -> None: 4257 self.num = 99 4258 4259 def __eq__(self, other): 4260 # type: (SomeNonAddableClass) -> bool 4261 return self.num == other.num 4262 4263 with self.assertRaisesRegex(RuntimeError, "Cannot emit inplace op"): 4264 @torch.jit.script 4265 def fn(): 4266 a = SomeNonAddableClass() 4267 a += SomeNonAddableClass() 4268 return a 4269 4270 @torch.jit.script 4271 class SomeClass: 4272 def __init__(self) -> None: 4273 self.num = 99 4274 4275 def __iadd__(self, x): 4276 # type: (int) 4277 self.num += x 4278 return self 4279 4280 def __eq__(self, other): 4281 # type: (SomeClass) -> bool 4282 return self.num == other.num 4283 4284 @torch.jit.script 4285 class SomeOutOfPlaceClass: 4286 def __init__(self) -> None: 4287 self.num = 99 4288 4289 def __add__(self, x): 4290 # type: (int) 4291 self.num = x 4292 return self 4293 4294 def __eq__(self, other): 4295 # type: (SomeClass) -> bool 4296 return self.num == other.num 4297 4298 def fn2(): 4299 a = SomeClass() 4300 a_copy = a 4301 a += 20 4302 assert a is a_copy 4303 b = SomeOutOfPlaceClass() 4304 b_copy = b 4305 b += 99 4306 assert b is b_copy 4307 c = [1, 2, 3] 4308 c_copy = c 4309 c *= 2 4310 assert c is c_copy 4311 c += [4, 5, 6] 4312 d = torch.ones(2, 2) 4313 d_copy = d 4314 d += torch.ones(2, 2) 4315 assert d is d_copy 4316 return a, b, c, d 4317 4318 
self.checkScript(fn2, []) 4319 4320 def test_nested_list_construct(self): 4321 def foo(): 4322 return [[4]] + [[4, 5]] 4323 self.checkScript(foo, ()) 4324 4325 def test_file_line_error(self): 4326 def foobar(xyz): 4327 return torch.blargh(xyz) 4328 4329 _, lineno = inspect.getsourcelines(foobar) 4330 with self.assertRaisesRegex(RuntimeError, f'test_jit.py", line {lineno + 1}'): 4331 scripted = torch.jit.script(foobar) 4332 4333 def test_file_line_error_class_defn(self): 4334 class FooBar: 4335 def baz(self, xyz): 4336 return torch.blargh(xyz) 4337 4338 _, lineno = inspect.getsourcelines(FooBar) 4339 with self.assertRaisesRegex(RuntimeError, f'test_jit.py", line {lineno + 2}'): 4340 torch.jit.script(FooBar) 4341 4342 def test_file_line_graph(self): 4343 def foobar(xyz): 4344 return torch.neg(xyz) 4345 4346 scripted = torch.jit.script(foobar) 4347 4348 _, lineno = inspect.getsourcelines(foobar) 4349 fc = FileCheck().check(f'test_jit.py:{lineno + 1}:19') 4350 fc.run(scripted.graph) 4351 fc.run(str(scripted.graph)) 4352 4353 def test_file_line_save_load(self): 4354 class Scripted(torch.jit.ScriptModule): 4355 @torch.jit.script_method 4356 def forward(self, xyz): 4357 return torch.neg(xyz) 4358 4359 scripted = Scripted() 4360 4361 # NB: not using getExportImportCopy because that takes a different 4362 # code path that calls CompilationUnit._import rather than 4363 # going through the full save/load pathway 4364 buffer = scripted.save_to_buffer() 4365 bytesio = io.BytesIO(buffer) 4366 scripted = torch.jit.load(bytesio) 4367 4368 _, lineno = inspect.getsourcelines(Scripted) 4369 fc = FileCheck().check(f':{lineno + 3}') 4370 fc.run(scripted.graph) 4371 fc.run(str(scripted.graph)) 4372 4373 def test_file_line_string(self): 4374 scripted = torch.jit.CompilationUnit(''' 4375def foo(xyz): 4376 return torch.neg(xyz) 4377 ''') 4378 4379 fc = FileCheck().check('<string>:3:11') 4380 fc.run(scripted.foo.graph) 4381 fc.run(str(scripted.foo.graph)) 4382 4383 @skipIfCrossRef 4384 def 
test_file_line_trace(self): 4385 def foobar(xyz): 4386 return torch.neg(xyz) 4387 4388 scripted = torch.jit.trace(foobar, (torch.rand(3, 4))) 4389 4390 _, lineno = inspect.getsourcelines(foobar) 4391 fc = FileCheck().check(f'test_jit.py:{lineno + 1}:0') 4392 fc.run(scripted.graph) 4393 fc.run(str(scripted.graph)) 4394 4395 def test_serialized_source_ranges(self): 4396 4397 class FooTest(torch.jit.ScriptModule): 4398 @torch.jit.script_method 4399 def forward(self, x, w): 4400 return torch.mm(x, w.t()) 4401 4402 ft = FooTest() 4403 loaded = self.getExportImportCopy(ft) 4404 _, lineno = inspect.getsourcelines(FooTest) 4405 4406 with self.assertRaisesRegex(RuntimeError, f'test_jit.py", line {lineno + 3}'): 4407 loaded(torch.rand(3, 4), torch.rand(30, 40)) 4408 4409 def test_serialized_source_ranges_graph(self): 4410 4411 class FooTest3(torch.jit.ScriptModule): 4412 @torch.jit.script_method 4413 def forward(self, x, w): 4414 return torch.mm(x, w.t()) 4415 4416 ft = FooTest3() 4417 loaded = self.getExportImportCopy(ft) 4418 _, lineno = inspect.getsourcelines(FooTest3) 4419 4420 fc = FileCheck().check(f'test_jit.py:{lineno + 3}') 4421 fc.run(loaded.graph) 4422 4423 def test_serialized_source_ranges2(self): 4424 4425 class FooTest2(torch.jit.ScriptModule): 4426 @torch.jit.script_method 4427 def forward(self): 4428 raise RuntimeError('foo') 4429 4430 _, lineno = inspect.getsourcelines(FooTest2) 4431 4432 with self.assertRaisesRegex(torch.jit.Error, f'test_jit.py", line {lineno + 3}'): 4433 ft = FooTest2() 4434 loaded = self.getExportImportCopy(ft) 4435 loaded() 4436 4437 def test_serialized_source_ranges_dont_jitter(self): 4438 class FooTest3(torch.jit.ScriptModule): 4439 @torch.jit.script_method 4440 def forward(self, lim): 4441 first = 1 4442 second = 1 4443 i = 1 4444 somenum = 5 4445 dontmutateme = 3 4446 third = 0 4447 while bool(i < lim): 4448 third = first + second 4449 first = second 4450 second = third 4451 j = 0 4452 while j < 10: 4453 somenum = somenum * 2 4454 j 
= j + 1 4455 i = i + j 4456 i = i + dontmutateme 4457 4458 st = second + third 4459 fs = first + second 4460 return third, st, fs 4461 4462 ft3 = FooTest3() 4463 4464 def debug_records_from_mod(self, mod): 4465 buffer = io.BytesIO() 4466 torch.jit.save(ft3, buffer) 4467 buffer.seek(0) 4468 archive = zipfile.ZipFile(buffer) 4469 files = filter(lambda x: x.startswith('archive/code/'), archive.namelist()) 4470 debug_files = list(filter(lambda f: f.endswith('.debug_pkl'), files)) 4471 self.assertEqual(len(debug_files), 1) 4472 debug_file = archive.open(debug_files[0]) 4473 return pickle.load(debug_file), buffer 4474 4475 records1, buffer = debug_records_from_mod(self, ft3) 4476 4477 buffer.seek(0) 4478 loaded = torch.jit.load(buffer) 4479 records2, buffer = debug_records_from_mod(self, loaded) 4480 4481 buffer.seek(0) 4482 loaded2 = torch.jit.load(buffer) 4483 records3, _ = debug_records_from_mod(self, loaded2) 4484 4485 self.assertEqual(records1, records2) 4486 self.assertEqual(records2, records3) 4487 4488 def test_serialized_source_ranges_no_dups(self): 4489 class FooTest3(torch.jit.ScriptModule): 4490 @torch.jit.script_method 4491 def forward(self, lim): 4492 first = 1 4493 second = 1 4494 i = 1 4495 somenum = 5 4496 dontmutateme = 3 4497 third = 0 4498 while bool(i < lim): 4499 third = first + second 4500 first = second 4501 second = third 4502 j = 0 4503 while j < 10: 4504 somenum = somenum * 2 4505 j = j + 1 4506 i = i + j 4507 i = i + dontmutateme 4508 4509 st = second + third 4510 fs = first + second 4511 return third, st, fs 4512 4513 ft3 = FooTest3() 4514 4515 def debug_records_from_mod(mod): 4516 buffer = io.BytesIO() 4517 torch.jit.save(ft3, buffer) 4518 buffer.seek(0) 4519 archive = zipfile.ZipFile(buffer) 4520 files = list(filter(lambda x: x.startswith('archive/code/'), archive.namelist())) 4521 debug_files = filter(lambda f: f.endswith('.debug_pkl'), files) 4522 debug_files = (archive.open(f) for f in debug_files) 4523 debug_files = (pickle.load(f) for 
f in debug_files) 4524 debug_files = (f[2] for f in debug_files) 4525 return list(debug_files) 4526 4527 debug_files = debug_records_from_mod(ft3) 4528 for debug_file in debug_files: 4529 for i in range(len(debug_file) - 1): 4530 offset, source_range_tag, source_range = debug_file[i] 4531 offset2, source_range_tag2, source_range2 = debug_file[i + 1] 4532 self.assertNotEqual(source_range, source_range2) 4533 4534 def test_circular_dependency(self): 4535 """ 4536 https://github.com/pytorch/pytorch/issues/25871 4537 """ 4538 class A(torch.jit.ScriptModule): 4539 @torch.jit.script_method 4540 def forward(self, x): 4541 return x 4542 4543 class B(torch.jit.ScriptModule): 4544 def __init__(self) -> None: 4545 super().__init__() 4546 self.foo = torch.nn.ModuleList([A()]) 4547 4548 @torch.jit.script_method 4549 def forward(self, x): 4550 for f in self.foo: 4551 x = f(x) 4552 return x 4553 4554 class C(torch.jit.ScriptModule): 4555 def __init__(self) -> None: 4556 super().__init__() 4557 self.foo = torch.nn.Sequential(B()) 4558 4559 @torch.jit.script_method 4560 def forward(self, x): 4561 for f in self.foo: 4562 x = f(x) 4563 return x 4564 self.getExportImportCopy(C()) 4565 4566 def test_serialize_long_lines(self): 4567 class OrderModuleLong(torch.nn.Module): 4568 def forward(self, long_arg_name: List[torch.Tensor]): 4569 return [(long_arg_name[1],), (long_arg_name[0].argmax(),)] 4570 src = str(torch.jit.script(OrderModuleLong()).code) 4571 # make long_arg_name[1] does not get reordered after the argmax 4572 FileCheck().check("long_arg_name[1]").check("argmax").run(src) 4573 4574 def test_tensor_shape(self): 4575 x = torch.empty(34, 56, 78) 4576 4577 def f(x): 4578 return x.shape 4579 4580 self.checkScript(f, (x,)) 4581 4582 4583 def test_block_input_grad_in_loop(self): 4584 4585 x = torch.randn(3, 3, requires_grad=False) 4586 y = torch.randn(3, 3, requires_grad=True) 4587 4588 def grad_in_loop(x, y): 4589 for i in range(100): 4590 x = y @ x 4591 return x 4592 4593 scripted 
= torch.jit.script(grad_in_loop) 4594 outer = scripted.graph_for(x, y) 4595 loop = outer.findNode("prim::Loop") 4596 loop_block = next(loop.blocks()) 4597 param_node = loop_block.paramNode() 4598 x_value = list(param_node.outputs())[1] 4599 self.assertTrue(x_value.requires_grad()) 4600 4601 def test_tensor_grad(self): 4602 x = torch.randn(3, 4, requires_grad=True) 4603 y = torch.randn(3, 4, requires_grad=False) 4604 4605 def f_requires_grad(x): 4606 return x.requires_grad 4607 4608 self.checkScript(f_requires_grad, (x,)) 4609 self.checkScript(f_requires_grad, (y,)) 4610 4611 def f_grad(x): 4612 return x.grad 4613 4614 x.sum().backward() 4615 self.checkScript(f_grad, (x,)) 4616 self.checkScript(f_grad, (y,)) 4617 4618 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "shape analysis is only enabled in Legacy") 4619 def test_prim_grad_undefined(self): 4620 4621 x = torch.ones(2) 4622 4623 def f_grad(x): 4624 return x.grad 4625 4626 scripted = self.checkScript(f_grad, (x,)) 4627 g = scripted.graph_for(x) 4628 4629 prim_grad_node = g.findNode("prim::grad") 4630 self.assertTrue(next(prim_grad_node.outputs()).type().undefined() is None) 4631 4632 def test_tensor_data(self): 4633 x = torch.randn(3, 4, requires_grad=True) 4634 y = torch.randn(4, 5) 4635 4636 def f_data(x): 4637 return x.data 4638 4639 scripted_f_data = torch.jit.script(f_data) 4640 4641 scripted_x = scripted_f_data(x) 4642 self.assertEqual(scripted_x, f_data(x)) 4643 self.assertEqual(scripted_x.requires_grad, False) 4644 4645 scripted_y = scripted_f_data(y) 4646 self.assertEqual(scripted_y, f_data(y)) 4647 self.assertEqual(scripted_x.requires_grad, False) 4648 4649 def test_tensor_dtype(self): 4650 x_byte = torch.empty(34, 56, 78, dtype=torch.uint8) 4651 x_long = torch.empty(34, 56, 78, dtype=torch.long) 4652 x_float32 = torch.empty(34, 56, 78, dtype=torch.float32) 4653 4654 @torch.jit.script 4655 def byte(x): 4656 return x.dtype == torch.uint8 4657 4658 @torch.jit.script 4659 def long(x): 4660 
return x.dtype == torch.long 4661 4662 @torch.jit.script 4663 def float32(x): 4664 return x.dtype == torch.float32 4665 4666 self.assertTrue(byte(x_byte)) 4667 self.assertFalse(byte(x_long)) 4668 self.assertFalse(byte(x_float32)) 4669 self.assertFalse(long(x_byte)) 4670 self.assertTrue(long(x_long)) 4671 self.assertFalse(long(x_float32)) 4672 self.assertFalse(float32(x_byte)) 4673 self.assertFalse(float32(x_long)) 4674 self.assertTrue(float32(x_float32)) 4675 4676 @unittest.skipIf(not RUN_CUDA, "device tests require CUDA") 4677 def test_tensor_device(self): 4678 cpu = torch.empty(34, 56, 78, device='cpu') 4679 gpu = torch.empty(34, 56, 78, device='cuda') 4680 4681 @torch.jit.script 4682 def same_device(x, y): 4683 return x.device == y.device 4684 4685 self.assertTrue(same_device(cpu, cpu)) 4686 self.assertTrue(same_device(gpu, gpu)) 4687 self.assertFalse(same_device(cpu, gpu)) 4688 4689 @unittest.skipIf(not RUN_CUDA, "device tests require CUDA") 4690 def test_tensor_to_device(self): 4691 def to_device(x): 4692 return x.to(device="cuda").to(device=torch.device("cpu")) 4693 4694 self.checkScript(to_device, (torch.ones(3, 4),)) 4695 4696 def test_tensor_to_cpu(self): 4697 def to_cpu(x): 4698 return x.cpu() 4699 4700 x = torch.ones(3, 4) 4701 script_fn = torch.jit.script(to_cpu) 4702 self.assertEqual(to_cpu(x).device, script_fn(x).device) 4703 self.checkScript(to_cpu, (x,)) 4704 4705 @unittest.skipIf(not RUN_CUDA, "device tests require CUDA") 4706 def test_tensor_to_cuda(self): 4707 def to_cuda(x): 4708 return x.cuda() 4709 4710 x = torch.ones(3, 4) 4711 script_fn = torch.jit.script(to_cuda) 4712 self.assertEqual(to_cuda(x).device, script_fn(x).device) 4713 self.checkScript(to_cuda, (x,)) 4714 4715 def test_generic_list_errors(self): 4716 with self.assertRaisesRegex(RuntimeError, "previously matched to type"): 4717 @torch.jit.script 4718 def foo(x): 4719 return [[x]] + [[1]] 4720 4721 def test_script_cu(self): 4722 cu = torch.jit.CompilationUnit(''' 4723 def foo(a): 
4724 b = a 4725 return b 4726 ''') 4727 a = Variable(torch.rand(1)) 4728 self.assertEqual(a, cu.foo(a)) 4729 4730 # because the compilation unit ingests python strings 4731 # to use an escape sequence escape the backslash (\\n = \n) 4732 def test_string_cu(self): 4733 cu = torch.jit.CompilationUnit(''' 4734 def foo(a): 4735 print(a, """a\\n\tb\\n""", 2, "a\ 4736a") 4737 return a 4738 ''') 4739 FileCheck().check("aa").check("a\\n\\tb\\n").run(str(cu.foo.graph)) 4740 4741 def test_function_compilation_caching(self): 4742 def fun(): 4743 return 1 + 2 4744 4745 fun_compiled = torch.jit.script(fun) 4746 # python wrapper around the script function is a different pointer, 4747 # but the underlying script function graph is the same 4748 self.assertIs(fun_compiled.graph, torch.jit.script(fun).graph) 4749 4750 def fun(): 4751 return 3 + 4 4752 4753 num_ref_counts = sys.getrefcount(fun) 4754 4755 # caching doesn't get tripped up by same qualname 4756 fun_compiled_2 = torch.jit.script(fun) 4757 self.assertIsNot(fun_compiled, fun_compiled_2) 4758 self.assertEqual(fun_compiled_2(), 7) 4759 4760 # caching doesnt increase refcounts to function (holds weak reference) 4761 self.assertTrue(sys.getrefcount(fun), num_ref_counts) 4762 4763 def test_string_ops(self): 4764 def foo(): 4765 a = "a" + "b" 4766 return a + a, "ab" == "b", "ab" != "b", "ab" == "ab", "ab" != "ab" 4767 4768 self.checkScript(foo, ()) 4769 4770 def test_string_sorted(self): 4771 def foo(strs: List[str]): 4772 return sorted(strs) 4773 4774 FileCheck() \ 4775 .check("graph") \ 4776 .check_next("str[] = aten::sorted") \ 4777 .check_next("return") \ 4778 .run(str(torch.jit.script(foo).graph)) 4779 4780 inputs = ["str3", "str2", "str1"] 4781 self.checkScript(foo, (inputs,)) 4782 4783 def test_string_sort(self): 4784 def foo(strs: List[str]): 4785 strs.sort() 4786 return strs 4787 4788 inputs = ["str3", "str2", "str1"] 4789 self.checkScript(foo, (inputs,)) 4790 4791 def test_tuple_sorted(self): 4792 def foo(tups: 
List[Tuple[int, int]]): 4793 return sorted(tups) 4794 4795 inputs = [(1, 2), (0, 2), (1, 3)] 4796 self.checkScript(foo, (inputs,)) 4797 4798 def test_tuple_sort(self): 4799 def foo(tups: List[Tuple[int, int]]): 4800 tups.sort() 4801 return tups 4802 4803 inputs = [(1, 2), (0, 2), (1, 3)] 4804 self.checkScript(foo, (inputs,)) 4805 4806 def test_tuple_sort_reverse(self): 4807 def foo(tups: List[Tuple[int, int]]): 4808 tups.sort(reverse=True) 4809 return tups 4810 4811 inputs = [(1, 2), (0, 2), (1, 3)] 4812 self.checkScript(foo, (inputs,)) 4813 4814 def test_tuple_unsortable_element_type(self): 4815 @torch.jit.script 4816 def foo(): 4817 tups = [({1: 2}, {2: 3})] 4818 tups.sort() 4819 return tups 4820 4821 with self.assertRaisesRegexWithHighlight(RuntimeError, "are not sortable", "tups.sort"): 4822 foo() 4823 4824 def test_tuple_unsortable_diff_type(self): 4825 @torch.jit.script 4826 def foo(inputs: List[Any]): 4827 inputs.sort() 4828 return inputs 4829 4830 inputs = [(1, 2), ("foo", "bar")] 4831 with self.assertRaisesRegexWithHighlight(RuntimeError, "Only values of same type can be compared", "inputs.sort"): 4832 foo(inputs) 4833 4834 def test_tuple_nested_sort(self): 4835 def foo(inputs: List[Tuple[int, Tuple[int, str]]]): 4836 inputs.sort() 4837 return inputs 4838 4839 inputs = [(1, (2, "foo")), (1, (2, "bar")), (1, (0, "bar"))] 4840 self.checkScript(foo, (inputs,)) 4841 4842 def test_tuple_unsortable_nested_diff_type(self): 4843 @torch.jit.script 4844 def foo(inputs: List[Any]): 4845 inputs.sort() 4846 return inputs 4847 4848 inputs = [(1, (2, 3)), (2, ("foo", "bar"))] 4849 with self.assertRaisesRegexWithHighlight(RuntimeError, "Only values of same type can be compared", "inputs.sort"): 4850 foo(inputs) 4851 4852 def test_string_new_line(self): 4853 with self.assertRaisesRegex(RuntimeError, "expected a valid token*"): 4854 torch.jit.CompilationUnit(''' 4855 def test_while(a): 4856 print(" 4857 a") 4858 return a 4859 ''') 4860 4861 def 
test_string_single_escape(self): 4862 with self.assertRaisesRegex(RuntimeError, "expected a valid token*"): 4863 torch.jit.CompilationUnit(''' 4864 def test_while(a): 4865 print("\\") 4866 return a 4867 ''') 4868 4869 def test_script_annotation(self): 4870 @torch.jit.script 4871 def foo(a): 4872 return a + a + a 4873 s = Variable(torch.rand(2)) 4874 self.assertEqual(s + s + s, foo(s)) 4875 4876 def test_torch_pow(self): 4877 def func(a, b): 4878 return pow(a, b) 4879 4880 def func2(a, b, c, d): 4881 return pow(pow(c + a, b), d) 4882 4883 def func3(a : int, b : float): 4884 # type: (int, float) -> float 4885 return pow(a, b) 4886 4887 def func4(): 4888 # type: () -> float 4889 return pow(2, -2) 4890 4891 def func5(x, y): 4892 return pow(x.item(), y.item()) 4893 4894 def func6(a : int, b : int): 4895 # type: (int, int) -> float 4896 return pow(a, b) 4897 4898 a = torch.rand(1) 4899 b = torch.rand(1) 4900 c = torch.rand(1) 4901 d = torch.rand(1) 4902 self.checkScript(func, (a, b)) 4903 self.checkScript(func2, (a, b, c, d)) 4904 self.checkScript(func3, (4, -0.5)) 4905 self.checkScript(func4, ()) 4906 self.checkScript(func6, (2, 4)) 4907 4908 inputs = [torch.tensor(2), torch.tensor(-2), torch.tensor(.5), torch.tensor(.2)] 4909 for x in inputs: 4910 for y in inputs: 4911 if x < 0: 4912 continue 4913 else: 4914 self.checkScript(func5, (x, y)) 4915 4916 @unittest.skipIf(not RUN_CUDA, "device tests require CUDA") 4917 def test_pow_scalar_backward_cuda(self): 4918 # see that scalar exponent works with cuda base (#19253) 4919 with enable_profiling_mode_for_profiling_tests(): 4920 for dtype in [torch.float, torch.double]: 4921 @torch.jit.script 4922 def func(a, b): 4923 # type: (Tensor, float) -> Tensor 4924 return (a * 2) ** b 4925 4926 a = torch.rand(1, requires_grad=True, device='cuda', dtype=dtype) 4927 func(a, 1, profile_and_replay=True).backward() 4928 4929 @torch.jit.script 4930 def func(a, b): 4931 # type: (float, Tensor) -> Tensor 4932 return a ** (b * 2 + 1) 4933 
4934 a = torch.rand(1, requires_grad=True, device='cuda', dtype=dtype) 4935 func(2, a, profile_and_replay=True).backward() 4936 4937 def _check_code(self, code_str, fn_name, inputs): 4938 scope = {} 4939 exec(code_str, globals(), scope) 4940 cu = torch.jit.CompilationUnit(code_str) 4941 self.assertEqual(cu.func(*inputs), scope[fn_name](*inputs)) 4942 4943 @unittest.skipIf(not RUN_CUDA, 'no CUDA') 4944 def test_scriptmodule_releases_tensors_cuda(self): 4945 with enable_profiling_mode_for_profiling_tests(): 4946 @torch.jit.script 4947 def fn(x, y): 4948 return x.sigmoid() * y.tanh() 4949 4950 def test(backward=False): 4951 x = torch.randn(3, 3, dtype=torch.double, device='cuda', requires_grad=True) 4952 y = torch.randn(3, 3, dtype=torch.double, device='cuda', requires_grad=True) 4953 out = fn(x, y, profile_and_replay=True) 4954 if backward: 4955 out.sum().backward() 4956 4957 with self.assertLeaksNoCudaTensors(): 4958 test() 4959 test() 4960 test() 4961 4962 if GRAPH_EXECUTOR != ProfilingMode.SIMPLE: 4963 with self.assertLeaksNoCudaTensors(): 4964 test(backward=True) 4965 test(backward=True) 4966 test(backward=True) 4967 4968 @skipIfTorchDynamo("Not a TorchDynamo suitable test") 4969 def test_index(self): 4970 def consec(size, start=0): 4971 numel = torch.tensor(size).prod().item() 4972 return torch.arange(numel).view(size) 4973 4974 def consec_list(size): 4975 return list(range(size)) 4976 4977 def random_string(size): 4978 letters = string.ascii_lowercase 4979 return "".join(random.choice(letters) for i in range(size)) 4980 4981 def check_indexing(indexing, tensor): 4982 template = dedent(""" 4983 def func(x): 4984 return x{} 4985 """) 4986 4987 self._check_code(template.format(indexing), "func", [tensor]) 4988 4989 def check_dynamic_indexing(indexing, tensor, value1, value2): 4990 value1 = torch.tensor(value1) 4991 value2 = torch.tensor(value2) 4992 4993 template = dedent(""" 4994 def func(x, value1, value2): 4995 i = int(value1) 4996 j = int(value2) 4997 return 
x{} 4998 """) 4999 5000 self._check_code(template.format(indexing), "func", [tensor, value1, value2]) 5001 5002 # Torchscript assumes type Tensor by default, so we need this explicit 5003 # declaration. 5004 def check_indexing_list_int(indexing, list): 5005 template = dedent(""" 5006 def func(x): 5007 # type: (List[int]) -> Any 5008 return x{} 5009 """) 5010 5011 self._check_code(template.format(indexing), "func", [list]) 5012 5013 def check_indexing_str(indexing, str): 5014 template = dedent(""" 5015 def func(x): 5016 # type: (str) -> Any 5017 return x{} 5018 """) 5019 5020 self._check_code(template.format(indexing), "func", [str]) 5021 5022 # basic slices 5023 check_indexing('[0]', consec((3, 3))) 5024 check_indexing('[1]', consec((3, 3), 10)) 5025 check_indexing('[2]', consec((3, 3), 19)) 5026 check_indexing('[2]', consec((3,))) 5027 check_indexing('[-1]', consec((3, 3), 19)) 5028 check_indexing('[0:2]', consec((3, 3, 3))) 5029 check_indexing('[1:-1]', consec((3, 3, 3))) 5030 check_indexing('[-3:-1]', consec((6, 3))) 5031 check_indexing('[1:]', consec((3, 3))) 5032 check_indexing('[:1]', consec((3, 3))) 5033 check_indexing('[:]', consec((3, 2))) 5034 5035 # multi-dim: indexes 5036 check_indexing('[0, 1]', consec((3, 3))) 5037 check_indexing('[0, 1]', consec((3, 3, 2))) 5038 check_indexing('[1, 0, 2]', consec((3, 3, 3))) 5039 check_indexing('[2, -1]', consec((3, 3))) 5040 5041 # multi-dim: mixed slicing and indexing 5042 check_indexing('[0, 1:2]', consec((3, 3))) 5043 check_indexing('[0, :1]', consec((3, 3, 2))) 5044 check_indexing('[1, 2:]', consec((3, 3, 3))) 5045 check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3))) 5046 check_indexing('[1:, -1, 0]', consec((3, 3, 3, 3))) 5047 check_indexing('[-1, 2:, 1:2]', consec((3, 3, 3, 3))) 5048 check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3))) 5049 check_indexing('[-1, :, 0, 2]', consec((3, 3, 3, 3))) 5050 5051 # zero-sized slices 5052 check_indexing('[0:0]', consec((2, 2))) 5053 check_indexing('[0:0, 1]', consec((3, 
3))) 5054 5055 # trivial expression usage 5056 check_indexing('[1+1]', consec((3, 3))) 5057 check_indexing('[1:(0 + 2)]', consec((3, 3, 3))) 5058 5059 # None for new dimensions 5060 check_indexing('[None, 0]', consec((3, 3))) 5061 check_indexing('[1, None]', consec((3, 3), 10)) 5062 check_indexing('[None, None, 2]', consec((3, 3), 19)) 5063 check_indexing('[None, 2, None]', consec((3,))) 5064 check_indexing('[0:2, None]', consec((3, 3, 3))) 5065 check_indexing('[None, 1:-1]', consec((3, 3, 3))) 5066 check_indexing('[None, -3:-1, None]', consec((6, 3))) 5067 check_indexing('[-1, None, 2:, None, 1:2]', consec((3, 3, 3, 3))) 5068 check_indexing('[None, -1, None, 2:, None, 1:2, None]', consec((3, 3, 3, 3))) 5069 5070 # dynamic expression usage 5071 check_dynamic_indexing("[i + j]", consec((3, 3)), 0, 1) 5072 check_dynamic_indexing("[i:j, i]", consec((3, 3, 2)), 0, 2) 5073 5074 # positive striding 5075 check_indexing_list_int('[0]', consec_list(6)) 5076 check_indexing_list_int('[1]', consec_list(7)) 5077 check_indexing_list_int('[2]', consec_list(8)) 5078 check_indexing_list_int('[2]', consec_list(9)) 5079 check_indexing_list_int('[-1]', consec_list(10)) 5080 check_indexing_list_int('[0:2]', consec_list(11)) 5081 check_indexing_list_int('[1:-1]', consec_list(12)) 5082 check_indexing_list_int('[-3:-1]', consec_list(13)) 5083 check_indexing_list_int('[1:]', consec_list(15)) 5084 check_indexing_list_int('[:1]', consec_list(16)) 5085 check_indexing_list_int('[:]', consec_list(17)) 5086 check_indexing_list_int('[::]', consec_list(0)) 5087 check_indexing_list_int('[1000::]', consec_list(0)) 5088 check_indexing_list_int('[:1000:]', consec_list(0)) 5089 5090 # negative striding 5091 check_indexing_list_int('[::-1]', consec_list(7)) 5092 check_indexing_list_int('[:3:-1]', consec_list(7)) 5093 check_indexing_list_int('[3::-1]', consec_list(7)) 5094 check_indexing_list_int('[1000::-1]', consec_list(7)) 5095 check_indexing_list_int('[3:0:-1]', consec_list(7)) 5096 
check_indexing_list_int('[3:-1000:-1]', consec_list(7)) 5097 check_indexing_list_int('[0:0:-1]', consec_list(7)) 5098 check_indexing_list_int('[0:-1000:-1]', consec_list(7)) 5099 5100 # only step is specified 5101 check_indexing_list_int('[::-1]', consec_list(0)) 5102 check_indexing_list_int('[::-1]', consec_list(7)) 5103 check_indexing_list_int('[::-2]', consec_list(7)) 5104 check_indexing_list_int('[::2]', consec_list(7)) 5105 check_indexing_list_int('[::42]', consec_list(7)) 5106 check_indexing_list_int('[::-42]', consec_list(7)) 5107 check_indexing_list_int('[::42]', consec_list(0)) 5108 check_indexing_list_int('[::-42]', consec_list(0)) 5109 check_indexing_list_int('[::9223372036854775807]', consec_list(42)) 5110 check_indexing_list_int('[::-9223372036854775807]', consec_list(42)) 5111 with self.assertRaisesRegex(RuntimeError, "out of bounds"): 5112 check_indexing_list_int('[::-9223372036854775808]', consec_list(42)) 5113 with self.assertRaisesRegex(RuntimeError, "should have non-zero step"): 5114 check_indexing_list_int('[::0]', consec_list(42)) 5115 5116 # striding strings 5117 check_indexing_str('[0]', random_string(6)) 5118 check_indexing_str('[1]', random_string(7)) 5119 check_indexing_str('[2]', random_string(8)) 5120 check_indexing_str('[2]', random_string(9)) 5121 check_indexing_str('[-1]', random_string(10)) 5122 check_indexing_str('[0:2]', random_string(11)) 5123 check_indexing_str('[1:-1]', random_string(12)) 5124 check_indexing_str('[-3:-1]', random_string(13)) 5125 check_indexing_str('[1:]', random_string(15)) 5126 check_indexing_str('[:1]', random_string(16)) 5127 check_indexing_str('[:]', random_string(17)) 5128 check_indexing_str('[::]', random_string(0)) 5129 check_indexing_str('[1000::]', random_string(0)) 5130 check_indexing_str('[:1000:]', random_string(0)) 5131 5132 check_indexing_str('[::-1]', random_string(7)) 5133 check_indexing_str('[:3:-1]', random_string(7)) 5134 check_indexing_str('[3::-1]', random_string(7)) 5135 
# Tail of the preceding indexing/striding test; its header and the helpers it
# calls (check_indexing_str, random_string) are defined before this chunk.
check_indexing_str('[1000::-1]', random_string(7))
check_indexing_str('[3:0:-1]', random_string(7))
check_indexing_str('[3:-1000:-1]', random_string(7))
check_indexing_str('[0:0:-1]', random_string(7))
check_indexing_str('[0:-1000:-1]', random_string(7))

check_indexing_str('[::-1]', random_string(0))
check_indexing_str('[::-1]', random_string(7))
check_indexing_str('[::-2]', random_string(7))
check_indexing_str('[::2]', random_string(7))
check_indexing_str('[::42]', random_string(7))
check_indexing_str('[::-42]', random_string(7))
check_indexing_str('[::42]', random_string(0))
check_indexing_str('[::-42]', random_string(0))
check_indexing_str('[::9223372036854775807]', random_string(42))
check_indexing_str('[::-9223372036854775807]', random_string(42))
with self.assertRaisesRegex(RuntimeError, "out of bounds"):
    check_indexing_str('[::-9223372036854775808]', random_string(42))
with self.assertRaisesRegex(RuntimeError, "should have non-zero step"):
    check_indexing_str('[::0]', random_string(42))

def test_module_copy_with_attributes(self):
    # Smoke test: __copy__ on a ScriptModule holding List/Dict/int attributes
    # must not raise (the copy itself is not inspected).
    class Vocabulary(torch.jit.ScriptModule):
        def __init__(self, vocab_list):
            super().__init__()
            self._vocab = torch.jit.Attribute(vocab_list, List[str])
            self.some_idx = torch.jit.Attribute(2, int)
            self.idx = torch.jit.Attribute(
                {word: i for i, word in enumerate(vocab_list)}, Dict[str, int]
            )

        @torch.jit.script_method
        def lookup_indices_1d(self, values):
            # type: (List[str]) -> List[int]
            result = torch.jit.annotate(List[int], [])
            # Direct list iteration not supported
            for i in range(len(values)):
                value = values[i]
                result.append(self.idx.get(value, self.some_idx))
            return result

        @torch.jit.script_method
        def forward(self, values):
            # type: (List[List[str]]) -> List[List[int]]
            result = torch.jit.annotate(List[List[int]], [])
            # Direct list iteration not supported
            for i in range(len(values)):
                result.append(self.lookup_indices_1d(values[i]))
            return result

    v = Vocabulary(list('uabcdefg'))
    v.__copy__()

def test_tuple_to_opt_list(self):
    # Compile-only check: scripting a call that passes a tuple literal where
    # Optional[List[int]] is declared must succeed (tuple_call is never run).
    @torch.jit.script
    def foo(x):
        # type: (Optional[List[int]]) -> int
        return 1

    @torch.jit.script
    def tuple_call():
        return foo((1, 2))

def test_keyword(self):
    # A keyword argument (dim=0) to an aten op must behave as in eager mode.
    @torch.jit.script
    def func(x):
        return torch.sum(x, dim=0)

    x = torch.rand(10, dtype=torch.float, requires_grad=True)
    y = func(x)
    y2 = torch.sum(x, dim=0)
    self.assertEqual(y, y2)

def test_constant_pooling_none(self):
    # Both branches call typed_nones(); the graph must end up with exactly one
    # NoneType constant.
    @torch.jit.script
    def typed_nones(a=None, b=None, c=None):
        # type: (Optional[int], Optional[bool], Optional[Tensor]) -> Tuple[Optional[int], Optional[bool], Optional[Tensor]]
        return a, b, c

    @torch.jit.script
    def test(a):
        # type: (bool) -> None
        if a:
            print(typed_nones())
        else:
            print(typed_nones())

    graph_str = str(test.graph)
    # assertEqual reports the actual count on failure, unlike assertTrue(x == 1)
    self.assertEqual(graph_str.count("NoneType = prim::Constant"), 1)

def test_constant_pooling_same_identity(self):
    # Two constants that are the same object may be pooled into one.
    def foo():
        a = torch.tensor([4])
        b = (a,)
        index = len(a) - 1
        c = b[index]
        d = b[index]
        return c, d

    foo_script = torch.jit.script(foo)
    self.run_pass('constant_propagation', foo_script.graph)
    self.run_pass('constant_pooling', foo_script.graph)
    # Even though c and d escape scope, we are still able to pool them into
    # one constant because they are the same object.
    FileCheck().check_count("prim::Constant", 1, exactly=True).run(foo_script.graph)
    self.assertEqual(foo(), foo_script())

def test_constant_pooling_introduce_aliasing(self):
    # Two equal-valued but distinct tensors must stay distinct objects.
    @torch.jit.script
    def foo():
        a = torch.tensor(1)
        b = torch.tensor(1)
        return a, b

    self.run_pass('constant_propagation', foo.graph)
    self.run_pass('constant_pooling', foo.graph)
    # Don't pool these constants: doing so would make a and b alias each
    # other, which is an observable behavior change.
    a, b = foo()
    self.assertIsNot(a, b)

def test_literal(self):
    # Tuple literal construction and unpacking, including a tuple that is
    # reassigned inside a loop (func3).
    def func1(a, b):
        c = a, b
        d, e = c
        return d + e

    def func2(a, b):
        c = a, (a, b)
        d, e = c
        f, g = e
        return d + f + g

    def func3(a, b):
        # type: (float, float) -> float
        c = 0., (0., 0.)
        x = True
        while x:
            x = False
            c = a, (a, b)
        d, e = c
        f, g = e
        return d + f + g

    a = torch.rand(1, requires_grad=True)
    b = torch.rand(1, requires_grad=True)
    self.checkScript(func1, (a, b), optimize=True)
    self.checkScript(func2, (a, b), optimize=True)
    self.checkScript(func3, (a.item(), b.item()), optimize=True)

def test_expand(self):
    # Broadcasting add: y of shape (3,) is expanded against x of shape (2, 3).
    # y's gradient must therefore be reduced back over dim 0.
    @torch.jit.script
    def func(x, y):
        return x + y

    x = torch.rand(2, 3, dtype=torch.float, requires_grad=True)
    y = torch.rand(3, dtype=torch.float, requires_grad=True)
    out = func(x, y)
    self.assertEqual(func(x, y), x + y)

    grad = torch.randn(2, 3, dtype=torch.float)
    out.backward(grad)
    self.assertEqual(x.grad, grad)
    self.assertEqual(y.grad, grad.sum(dim=0))

def test_sum(self):
    @torch.jit.script
    def func(x):
        return x.sum(dim=[4])

    @torch.jit.script
    def func2(x):
        return x.sum(dim=4)

    # test that shape analysis is written correctly for sum with OptionalIntArrayRef[1] dim argument
    # (smoke test: only checks that propagation does not raise; g/g2 are not inspected)
    self.run_pass('constant_propagation', func.graph)
    self.run_pass('constant_propagation', func2.graph)
    g = _propagate_shapes(func.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
    g2 = _propagate_shapes(func2.graph, (torch.zeros(1, 1, 1, 1, 4),), False)

def test_cat(self):
    # Scripted torch.cat must match eager output (including an empty-sized
    # input). Outside the SIMPLE executor, the test also checks the autodiff
    # node expectations and gradients.
    with enable_profiling_mode_for_profiling_tests():
        @torch.jit.script
        def func(x):
            return torch.cat((x, x), dim=0)

        x = torch.rand(10, dtype=torch.float, requires_grad=True)
        self.assertEqual(func(x, profile_and_replay=True), torch.cat((x, x), dim=0))

        @torch.jit.script
        def func2(x, y):
            return torch.cat((x, x), y)

        with disable_autodiff_subgraph_inlining():
            for sizes in ((2, 2), (0, 2)):
                x = torch.rand(sizes).requires_grad_()
                y = torch.tensor(1)

                output = func2(x, y, profile_and_replay=True)
                output_ref = torch.cat((x, x), y)
                self.assertEqual(output, output_ref)

                if GRAPH_EXECUTOR != ProfilingMode.SIMPLE:
                    self.assertAutodiffNode(func2.graph_for(x, y), True, ['aten::cat'], [])

                    grad = torch.autograd.grad(output.sum(), x)
                    grad_ref = torch.autograd.grad(output_ref.sum(), x)
                    self.assertEqual(grad, grad_ref)

def test_cat_lifts(self):
    # For multi-, zero- and single-element lists, each graph must contain an
    # int value, then a ListConstruct, then aten::cat, in that order.
    @torch.jit.script
    def foo(x):
        return torch.cat([x, x], dim=1)

    @torch.jit.script
    def foo2(x):
        return torch.cat([], dim=1)

    @torch.jit.script
    def foo3(x):
        return torch.cat([x], dim=1)

    for g in [foo.graph, foo2.graph, foo3.graph]:
        FileCheck().check("int =").check("ListConstruct").check("aten::cat").run(str(g))

def test_stack(self):
    # Scripted torch.stack must match eager output. Outside the SIMPLE
    # executor, the test also checks the autodiff node and gradients.
    with enable_profiling_mode_for_profiling_tests():
        @torch.jit.script
        def func(x):
            return torch.stack((x, x), dim=1)
        x = torch.rand(10, 10)
        self.assertEqual(func(x, profile_and_replay=True), torch.stack((x, x), dim=1))

        @torch.jit.script
        def func2(x, y):
            return torch.stack((x, y), dim=0)

        with disable_autodiff_subgraph_inlining():
            x = torch.randn([2, 2]).requires_grad_()
            y = torch.randn([2, 2]).requires_grad_()

            output = func2(x, y, profile_and_replay=True)
            output_ref = torch.stack((x, y), 0)
            self.assertEqual(output, output_ref)
            if GRAPH_EXECUTOR != ProfilingMode.SIMPLE:
                self.assertAutodiffNode(func2.graph_for(x, y), True, ['aten::stack'], [])

                grads = torch.autograd.grad(output.sum(), (x, y))
                grads_ref = torch.autograd.grad(output_ref.sum(), (x, y))
                self.assertEqual(grads, grads_ref)

@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY,
                 "Profiling executor will be using different heuristics for constructing differentiable graphs")
def test_unbind(self):
    # torch.unbind returning List[Tensor]: outputs and gradients must match eager.
    with enable_profiling_mode_for_profiling_tests():
        @torch.jit.script
        def func(x, y):
            # type: (Tensor, int) -> List[Tensor]
            return torch.unbind(x, y)

        with disable_autodiff_subgraph_inlining():
            x = torch.rand([2, 2]).requires_grad_()
            y = 0
            outputs = func(x, y, profile_and_replay=True)
            outputs_ref = torch.unbind(x, dim=y)
            self.assertEqual(outputs, outputs_ref)
            self.assertAutodiffNode(func.graph_for(x, y), True, [], [])

            grad = torch.autograd.grad(_sum_of_list(outputs), x)
            grad_ref = torch.autograd.grad(_sum_of_list(outputs_ref), x)
            self.assertEqual(grad, grad_ref)

@unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.PROFILING,
                 "Profiling executor fails to recognize that tensors in a list require gradients")
def test_meshgrid(self):
    # torch.meshgrid over a List[Tensor]: outputs must match eager. Outside
    # the SIMPLE executor, gradients are checked as well.
    with enable_profiling_mode_for_profiling_tests():
        @torch.jit.script
        def func(a):
            # type: (List[Tensor]) -> List[Tensor]
            return torch.meshgrid(a)
        with disable_autodiff_subgraph_inlining():
            a = torch.tensor([1.0, 2, 3]).requires_grad_()
            b = torch.tensor([1.0, 2, 3, 4]).requires_grad_()
            inputs = [a, b]

            outputs_ref = torch.meshgrid(inputs)
            outputs = func(inputs, profile_and_replay=True)
            self.assertEqual(outputs, outputs_ref)

            if GRAPH_EXECUTOR != ProfilingMode.SIMPLE:
                self.assertAutodiffNode(func.graph_for(inputs), True, [], [])

                grads = torch.autograd.grad(_sum_of_list(outputs), inputs)
                grads_ref = torch.autograd.grad(_sum_of_list(outputs_ref), inputs)
                self.assertEqual(grads, grads_ref)

def test_tensor_len(self):
    # len() of a tensor in TorchScript must match Python.
    def func(x):
        return len(x)

    self.checkScript(func, [torch.ones(4, 5, 6)])

def test_func_call(self):
    # Calls between plain Python helpers that get compiled alongside func.
    def add(a, b):
        return a + b

    def mul(a, x):
        return a * x

    def func(alpha, beta, x, y):
        return add(mul(alpha, x), mul(beta, y))

    alpha = torch.rand(1, dtype=torch.float, requires_grad=True)
    beta = torch.rand(1, dtype=torch.float, requires_grad=True)
    x = torch.rand(3, dtype=torch.float, requires_grad=True)
    y = torch.rand(3, dtype=torch.float, requires_grad=True)

    # NOTE: cannot optimize yet because broadcasts are not inserted before the fuser runs
    self.checkScript(func, [alpha, beta, x, y], optimize=False)

@unittest.skip("bailouts are being deprecated")
def test_profiling_graph_executor(self):
    # Skipped: checks prim::profile / prim::BailOut insertion counts.
    @torch.jit.script
    def def_in_one_branch(x, z):
        # type: (Tensor, bool) -> float
        y = x
        if z is False:
            y = x + 1

        return y.sum()

    a = torch.rand(2, 3)

    with enable_profiling_mode_for_profiling_tests():
        # check prim::profile are inserted
        profiled_graph_str = str(def_in_one_branch.graph_for(a, True))
        FileCheck().check_count("prim::profile", 4).run(profiled_graph_str)
        # this call is optimized for
        # the given shape of (2, 3)
        def_in_one_branch(a, False)
        # change shape to (3)
        # so we go down a bailout path
        a = torch.ones(3)
        # check prim::BailOuts are inserted
        bailout_graph_str = str(def_in_one_branch.graph_for(a, True))
        FileCheck().check_count("prim::BailOut", 3).run(bailout_graph_str)
        # this triggers all 3 bailouts
        self.assertEqual(def_in_one_branch(a, False), 6.0)
        # this triggers 2 bailouts
        self.assertEqual(def_in_one_branch(a, True), 3.0)

@unittest.skip("bailouts are being deprecated")
def test_maxpool_guard_elimination(self):
    # Skipped: expects exactly one prim::BailOut in the optimized graph.
    @torch.jit.script
    def my_maxpool(x):
        return F.max_pool1d(x, kernel_size=[1]) + torch.ones([32, 32, 32])

    a = torch.rand(32, 32, 32)

    with enable_profiling_mode_for_profiling_tests():
        my_maxpool(a)
        bailout_graph_str = str(my_maxpool.graph_for(a))
        FileCheck().check_count("prim::BailOut", 1).run(bailout_graph_str)

@unittest.skip("bailouts are being deprecated")
def test_slice_guard_elimination(self):
    # Skipped: expects exactly one prim::BailOut in the optimized graph.
    @torch.jit.script
    def my_slice(x):
        return x[0:16:2] + x[0:16:2]

    a = torch.rand(32, 4)

    with enable_profiling_mode_for_profiling_tests():
        my_slice(a)
        bailout_graph_str = str(my_slice.graph_for(a))
        FileCheck().check_count("prim::BailOut", 1).run(bailout_graph_str)

@unittest.skip("bailouts are being deprecated")
def test_unsqueeze_guard_elimination(self):
    # Skipped: expects exactly two prim::BailOut nodes in the optimized graph.
    @torch.jit.script
    def my_unsqueeze(x):
        return torch.unsqueeze(x, 0) + torch.unsqueeze(x, 0)

    a = torch.rand(32, 4)

    with enable_profiling_mode_for_profiling_tests():
        my_unsqueeze(a)
        bailout_graph_str = str(my_unsqueeze.graph_for(a))
        FileCheck().check_count("prim::BailOut", 2).run(bailout_graph_str)

def test_resize_input_ops(self):
    # resize_ and resize_as resize the input tensor. because our shape analysis
    # is flow invariant, we set any Tensor that can alias a resized Tensor
    # to the base Tensor Type, without size information.

    # testing that value which is an input of a graph gets handled
    def out_op_graph_input():
        @torch.jit.script
        def test(x, y, z):
            torch.mul(x, y, out=z)
            return z

        graph = _propagate_shapes(test.graph,
                                  (torch.zeros(2, 1), torch.zeros(1, 2), torch.zeros(1, 1, 1)), False)
        self.assertTrue(next(graph.outputs()).type() == TensorType.get())
    out_op_graph_input()

    def test_resize():
        @torch.jit.script
        def test(x):
            after_resize_alias = torch.zeros([2])
            for _i in range(5):
                b = x + 1
                f = [1]
                before_resize_alias = b.sub_(1)
                f.append(1)
                b.resize_(f)
                after_resize_alias = b.add_(1)
            return after_resize_alias

        self.run_pass('constant_propagation', test.graph)
        g = _propagate_shapes(test.graph, (torch.zeros(1, 1),), False)
        resize_node = g.findNode("aten::resize_")
        # first input and output of b.resize_ is b
        self.assertTrue(next(resize_node.inputs()).type() == TensorType.get())
        self.assertTrue(next(resize_node.outputs()).type() == TensorType.get())

        # correctly propagates to b alias set
        before_resize = g.findNode("aten::sub_")
        self.assertTrue(next(before_resize.outputs()).type() == TensorType.get())

        after_resize = g.findNode("aten::add_")
        self.assertTrue(next(after_resize.outputs()).type() == TensorType.get())

    test_resize()

    def test_resize_as():
        @torch.jit.script
        def test(x):
            b = torch.zeros([2, 2])
            b.resize_as_(x)
            return b

        g = test.graph
        self.run_pass('constant_propagation', g)
        g = _propagate_shapes(test.graph, (torch.zeros(1, 1),), False)

        # x doesn't alias a resized op so it shouldn't be set to base Tensor type
        self.assertTrue(next(g.inputs()).type() != TensorType.get())
        # return is resized
        self.assertTrue(next(g.outputs()).type() == TensorType.get())

    test_resize_as()
def test_uninitialized(self):
    # Adding to a prim::Uninitialized int survives export/import; calling the
    # function must raise "expected int".
    graph_str = """graph():
      %1 : int = prim::Uninitialized()
      %2 : int = prim::Constant[value=1]()
      %3 : int = aten::add(%1, %2)
      return (%3)
    """
    g = parse_ir(graph_str)
    m = self.createFunctionFromGraph(g)
    self.getExportImportCopy(m)
    with self.assertRaisesRegex(RuntimeError, "expected int"):
        m()


@unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't use requires_grad information")
@unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.PROFILING, "Peeling is now disabled")
def test_requires_grad_loop(self):
    @torch.jit.script
    def test(x, y, z):
        # type: (Tensor, Tensor, int) -> Tensor
        for _ in range(z):
            x = y
        return x

    # x requires grad, y does not
    # testing that requires grad analysis correctly exits, with its input
    # to the loop (x) requiring grad and its output to the loop not requiring grad
    # and the output of the node conservatively setting grad to true

    inps = (torch.tensor(1.0, requires_grad=True), torch.tensor(1), 10)
    test(*inps, profile_and_replay=True)

    graph = test.graph_for(*inps)
    loop = graph.findNode("prim::Loop")
    loop_body = next(loop.blocks())
    loop_inputs = list(loop_body.inputs())
    loop_outputs = list(loop_body.outputs())

    # NOTE(review): the skip decorators exclude SIMPLE and PROFILING, so only
    # the else-branch currently runs; the PROFILING branch is kept presumably
    # for when loop peeling is re-enabled.
    if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
        # TODO: simplify this test as it's very sensitive
        # the optimized graph will have 3 loops
        # the original loop is peeled
        # peeled loop also gets unrolled
        index_of_x_in_peeled_unrolled_loop = -2
        self.assertTrue(loop_inputs[index_of_x_in_peeled_unrolled_loop].requires_grad())
        bailouts_in_outer_block = graph.findAllNodes("prim::BailOut", False)
        last_bailout_index_on_loops_output = -1
        self.assertFalse(bailouts_in_outer_block[last_bailout_index_on_loops_output].output().requires_grad())
    else:
        self.assertTrue(loop_inputs[1].requires_grad())
        self.assertTrue(loop.output().requires_grad())
        self.assertFalse(loop_outputs[1].requires_grad())

def test_view_shape_prop(self):
    # view(size=[-1]) flattens a 10x10 tensor to 100 elements.
    cu = torch.jit.CompilationUnit('''
    def test_view_shape_prop(a):
        return a.view(size=[-1])
    ''')
    inputs = [torch.zeros(10, 10)]
    outputs = torch.zeros(100)

    real_outs = cu.test_view_shape_prop(*inputs)
    self.assertEqual(real_outs, outputs)

@skipIfTorchDynamo("TorchDynamo fails with unknown reason")
def test_view_listconstruct_shape_prop(self):
    # Shape propagation through view() with sizes read from x.size() keeps
    # the Float scalar type on the output.
    def fn(x):
        B = x.size(0)
        C = x.size(1)
        T = x.size(2)
        return x.view(T, B, C)

    x = torch.randn(3, 1, 5, requires_grad=True)
    fn = torch.jit.script(fn)
    graph = _propagate_shapes(fn.graph, (x,), False)
    self.assertEqual(next(graph.outputs()).type().scalarType(), 'Float')

def test_shape_prop_promotion(self):
    # float + double must propagate to a Double result type.
    @torch.jit.script
    def fn(x, y):
        return x + y

    x, y = torch.rand(3, 4, dtype=torch.float), torch.rand(3, 4, dtype=torch.double)
    graph = _propagate_shapes(fn.graph, (x, y), False)
    FileCheck().check('Double(*, *, device=cpu) = aten::add').run(graph)

def test_shape_prop_promote_scalar_arg(self):
    # A Python float scalar added to a long tensor promotes to the default dtype.
    @torch.jit.script
    def fn(x):
        return math.pi + x

    x = torch.zeros(3, 4, dtype=torch.long)
    graph = _propagate_shapes(fn.graph, (x,), False)
    default = torch.get_default_dtype()
    if default == torch.float:
        FileCheck().check('Float(*, *, requires_grad=0, device=cpu) = aten::add').run(graph)
    else:
        FileCheck().check('Double(*, *, requires_grad=0, device=cpu) = aten::add').run(graph)

def test_integral_shape_inference(self):
    # Multiplying long tensors keeps the long dtype.
    cu = torch.jit.CompilationUnit('''
    def test_integral_shape_inference(a):
        return a * a
    ''')
    inputs = [torch.ones(10, 10, dtype=torch.long)]
    outputs = torch.ones(10, 10, dtype=torch.long)

    self.assertEqual(cu.test_integral_shape_inference(*inputs), outputs)

@unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser')
@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle")
@enable_cpu_fuser
def test_batchnorm_fuser_cpu(self):
    # The CPU fuser's generated kernel for float inputs must call sqrtf.
    code = '''
        graph(%3 : Tensor,
              %7 : Tensor,
              %12 : Float(*, *),
              %13 : Tensor,
              %25 : Tensor):
            %23 : int = prim::Constant[value=1]()
            %22 : float = prim::Constant[value=1e-05]()
            %26 : Tensor = aten::sqrt(%25)
            %24 : Tensor = aten::add(%26, %22, %23)
            %20 : Tensor = aten::reciprocal(%24)
            %norm_invstd : Tensor = aten::mul(%20, %23)
            %15 : Tensor = aten::sub(%12, %13, %23)
            %11 : Tensor = aten::mul(%15, %norm_invstd)
            %8 : Tensor = aten::mul(%11, %7)
            %5 : Tensor = aten::add(%8, %3, %23)
            %1 : Float(*, *) = aten::relu(%5)
            return (%1)
    '''

    graph = parse_ir(code)
    inputs = 5 * [torch.rand(26, 2048, dtype=torch.float)]
    code = torch._C._jit_fuser_get_fused_kernel_code(graph, inputs)
    FileCheck().check('sqrtf').run(code)

@slowTest
@unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser')
@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle")
@enable_cpu_fuser
def test_fuser_double_float_codegen(self):
    # Each math op must dispatch to the C function `fn(` for double inputs and
    # `fnf(` for float inputs in the CPU fuser's generated code.
    fns = ['log', 'log10', 'log1p', 'log2', 'lgamma', 'exp', 'expm1', 'erf',
           'erfc', 'cos', 'acos', 'cosh', 'sin', 'asin', 'sinh', 'tan',
           'atan', 'tanh', 'sqrt', 'ceil', 'floor', 'round', 'trunc',
           'frac']

    def lookup_c_equivalent_fn(aten_fn):
        return aten_fn

    def test_dispatch(op, expects, dtype, binary=False):
        if dtype == torch.double:
            dtype_str = 'Double'
        elif dtype == torch.float:
            dtype_str = 'Float'
        else:
            raise RuntimeError('Unknown dtype')

        if binary:
            code = f'''
                graph(%3 : Tensor, %4 : Tensor):
                    %2 : {dtype_str}(*, *) = aten::{op}(%3, %4)
                    %1 : {dtype_str}(*, *) = aten::relu(%2)
                    return (%1)
            '''
        else:
            code = f'''
                graph(%3 : Tensor):
                    %2 : {dtype_str}(*, *) = aten::{op}(%3)
                    %1 : {dtype_str}(*, *) = aten::relu(%2)
                    return (%1)
            '''

        graph = parse_ir(code)
        inputs = (2 if binary else 1) * [torch.rand(26, 2048, dtype=dtype)]
        code = torch._C._jit_fuser_get_fused_kernel_code(graph, inputs)
        FileCheck().check(expects).run(code)

    for fn in fns:
        test_dispatch(fn, lookup_c_equivalent_fn(fn) + '(', torch.double)
        test_dispatch(fn, lookup_c_equivalent_fn(fn) + 'f(', torch.float)

    # 'min', 'max' were previously tested but are now replaced with ternary expressions
    # instead of fmin() and fmax()
    binary_fns = ['pow']
    for fn in binary_fns:
        test_dispatch(fn, lookup_c_equivalent_fn(fn) + '(', torch.double, binary=True)
        test_dispatch(fn, lookup_c_equivalent_fn(fn) + 'f(', torch.float, binary=True)

@unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser')
@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle")
@enable_cpu_fuser
def test_fuser_double_literal_precision(self):
    # A double constant must be emitted into fused code with full precision.
    code = '''
    graph(%2 : Float(*, *)):
        %4 : int = prim::Constant[value=1]()
        %3 : float = prim::Constant[value=1.282549830161864]()
        %5 : Float(*, *) = aten::add(%2, %3, %4)
        %1 : Float(*, *) = aten::relu(%5)
        return (%1)
    '''

    graph = parse_ir(code)
    code = torch._C._jit_fuser_get_fused_kernel_code(graph, [torch.rand(3, 4)])
    FileCheck().check('1.282549830161864').run(code)

def test_fuser_multiple_blocks(self):
    # Repeated cat inside a while loop: three zero-row inputs grow by one row
    # per iteration for 20 iterations.
    cu = torch.jit.CompilationUnit('''
    def test_fuser_multiple_blocks(this, that, theother, meme):
        i = 0
        while i < 20:
            this = torch.cat([this, meme], dim=0)
            that = torch.cat([that, meme], dim=0)
            theother = torch.cat([theother, meme], dim=0)
            i = i + 1
        return this, that, theother
    ''')

    inputs = [torch.ones(0, 10, 10)] * 3
    inputs += [torch.ones(1, 10, 10)]
    outputs = [torch.ones(20, 10, 10)] * 3

    self.assertEqual(cu.test_fuser_multiple_blocks(*inputs), outputs)

@unittest.skip("RuntimeError: VariableType::ID() not implemented")
def test_cast(self):
    script = '''
    def to_int(x):
        return int(x)
    '''
    x = Variable(torch.FloatTensor([1.1, 2.3]), requires_grad=True)
    out = Variable(torch.IntTensor([1, 2]), requires_grad=True)
    self.checkScript(script, [x], optimize=True, outputs=[out], func='to_int')

def test_str_cast(self):
    # str() of a tuple of ints formats like Python.
    @torch.jit.script
    def to_str(x):
        # type: (int) -> str
        return str((x, x))

    self.assertEqual("(1, 1)", to_str(1))

def test_int_cast(self):
    # int() of a decimal string; hex/binary literals must be rejected.
    @torch.jit.script
    def to_int(x):
        # type: (str) -> int
        return int(x)

    self.assertEqual(5, to_int('5'))
    self.assertEqual(-5, to_int('-5'))
    self.assertEqual(2147483647, to_int('2147483647'))
    self.assertEqual(-2147483648, to_int('-2147483648'))

    with self.assertRaisesRegex(RuntimeError, "invalid literal for int()"):
        to_int('0x20')

    with self.assertRaisesRegex(RuntimeError, "invalid literal for int()"):
        to_int('0b0001')

def test_python_frontend(self):
    # The frontend AST for fn is compared against a stored expect file, so
    # fn's body must not be edited.
    def fn(x, y, z):
        q = None
        q = x + y - z.sigmoid()
        print(q)
        w = -z
        if not x and not y and z:
            m = x if not z else y
        while x < y > z:
            q = x
        assert 1 == 1, "hello"
        return x

    ast = torch.jit.frontend.get_jit_def(fn, fn.__name__)
    self.assertExpected(str(ast))

def test_python_frontend_source_range(self):
    # The printed source range of fn's def must highlight the def line and
    # the raise statement.
    def fn():
        raise Exception("hello")  # noqa: TRY002
    ast = torch.jit.frontend.get_jit_def(fn, fn.__name__)
    FileCheck().check("SourceRange at:") \
        .check("def fn():") \
        .check("~~~~~~~~~") \
        .check('raise Exception("hello")') \
        .check('~~~~~~~~~~~~~~~~~ <--- HERE') \
        .run(str(ast.range()))

def test_python_frontend_py3(self):
    # Frontend AST for a raise statement, compared against an expect file.
    def fn():
        raise Exception("hello")  # noqa: TRY002
    ast = torch.jit.frontend.get_jit_def(fn, fn.__name__)
    self.assertExpected(str(ast))

def _make_scalar_vars(self, arr, dtype):
    # Wrap each value of arr in a 0-dim tensor of the given dtype.
    return [torch.tensor(val, dtype=dtype) for val in arr]


def test_string_print(self):
    # print() with implicitly concatenated string literals of every quote style.
    def func(a):
        print(a, "a" 'b' '''c''' """d""", 2, 1.5)
        return a

    inputs = self._make_scalar_vars([1], torch.int64)
    self.checkScript(func, inputs, capture_output=True)

def test_while(self):
    def func(a, b, max):
        while bool(a < max):
            a = a + 1
            b = b + 1
        c = a + b
        return c

    inputs = self._make_scalar_vars([1, 1, 10], torch.int64)
    self.checkScript(func, inputs, optimize=True)

def test_fibb(self):
    # Nested while loops that carry several loop variables.
    def func(lim):
        first = 1
        second = 1
        i = 1
        somenum = 5
        dontmutateme = 3
        third = 0
        while bool(i < lim):
            third = first + second
            first = second
            second = third
            j = 0
            while j < 10:
                somenum = somenum * 2
                j = j + 1
            i = i + j
            i = i + dontmutateme

        st = second + third
        fs = first + second
        return third, st, fs

    inputs = self._make_scalar_vars([10], torch.int64)
    self.checkScript(func, inputs, optimize=True)

def test_fibb_totally_better(self):
    def fib(x):
        # type: (int) -> int
        prev = 1
        v = 1
        for i in range(0, x):
            save = v
            v = v + prev
            prev = save
        return v

    self.checkScript(fib, (10,))

def test_if(self):
    def func(a, b):
        # type: (int, int) -> int
        d = 3
        if bool(a > 10):
            a = 3 + d
        else:
            b = 3 + d
            d = 4
        c = a + b
        return c

    inputs = self._make_scalar_vars([1, -1], torch.int64)
    self.checkScript(func, inputs, optimize=True)

def test_if_for_in_range(self):
    def func(a, b):
        # type: (int, int) -> int
        d = 3
        for _ in range(20):
            if bool(a > 10):
                a = 3 + d
            else:
                b = 3 + d
                d = 4
            c = a + b
        return d
    inputs = self._make_scalar_vars([1, -1], torch.int64)
    self.checkScript(func, inputs, optimize=True)

def test_if_noelse(self):
    def func(a, b):
        if bool(a > 10):
            a = 3 + b
        c = a + b
        return c

    inputs = self._make_scalar_vars([-1, 1], torch.int64)
    self.checkScript(func, inputs, optimize=True)

def test_if_is_none_dispatch(self):
    # `is None` / `is not None` checks whose outcome is statically known must
    # emit only the taken branch. The test counts the int return constants
    # each graph contains.

    @torch.jit.script
    def test_lhs_none_rhs_none():
        # LHS, RHS both alwaysNone, dispatch always_none_branch
        # only emit one prim::Constant
        if None is None:
            return 1
        elif None is not None:
            return 2
        else:
            return 3

    self.assertEqual(str(test_lhs_none_rhs_none.graph).count(': int = prim::Constant'), 1)

    @torch.jit.script
    def test_lhs_opt_rhs_none(lhs=None):
        # type: (Optional[Tensor]) -> int
        # LHS maybeNone: emit normal if stmt that contains 3 constants
        if lhs is not None:
            return 2
        elif lhs is None:
            return 1
        else:
            return 3

    self.assertEqual(str(test_lhs_opt_rhs_none.graph).count(': int = prim::Constant'), 3)

    @torch.jit.script
    def test_lhs_none_rhs_opt(rhs=None):
        # type: (Optional[Tensor]) -> int
        # RHS maybeNone, emit normal if stmt that contains 3 constants
        if None is rhs:
            return 1
        elif None is not rhs:
            return 2
        else:
            return 3

    # Check this function's own graph. The assertion previously inspected
    # test_lhs_opt_rhs_none.graph, which left this case unverified.
    self.assertEqual(str(test_lhs_none_rhs_opt.graph).count(': int = prim::Constant'), 3)

    @torch.jit.script
    def test_lhs_never_rhs_none(lhs):
        # LHS neverNone, RHS alwaysNone dispatch never_none_branch
        # only emit one prim::Constant
        if lhs is None:
            return 1
        elif lhs is not None:
            return 2
        else:
            return 3

    self.assertEqual(str(test_lhs_never_rhs_none.graph).count(': int = prim::Constant'), 1)

    @torch.jit.script
    def test_lhs_none_rhs_never(rhs):
        # LHS alwaysNone, RHS neverNone dispatch never_none_branch
        # only emit one prim::Constant
        if None is rhs:
            return 1
        elif None is not rhs:
            return 2
        else:
            return 3

    self.assertEqual(str(test_lhs_none_rhs_never.graph).count(': int = prim::Constant'), 1)

    @torch.jit.script
    def test_bool_arith_and(lhs):
        if lhs is None and lhs is not None:
            return 1
        else:
            return 2
    self.assertEqual(test_bool_arith_and(torch.zeros(3)), 2)
    self.assertEqual(str(test_bool_arith_and.graph).count('if'), 0)

    @torch.jit.script
    def test_bool_arith_or(lhs):
        if lhs is None or lhs is not None:
            return 1
        else:
            return 2
    self.assertEqual(test_bool_arith_or(torch.zeros(3)), 1)
    self.assertEqual(str(test_bool_arith_or.graph).count('if'), 0)


    @torch.jit.script
    def test_bool_arith_not(lhs):
        if lhs is not None:
            return 1
        else:
            return 2
    self.assertEqual(test_bool_arith_not(torch.zeros(3)), 1)
    self.assertEqual(str(test_bool_arith_not.graph).count('if'), 0)

def test_conditional_casting(self):
    # Implicit bool conversion of tensors, ints and floats in conditions.
    # Tuples must be rejected.
    def test_bool_cast_tensor(x):
        if x:
            return 1
        else:
            return 0

    for make_one_dim in [True, False]:
        for inp_val in [0.1, 0.0, -0.0, -0.1, -1, 0, 1]:
            inp_val = [inp_val] if make_one_dim else inp_val
            self.checkScript(test_bool_cast_tensor, (torch.tensor(inp_val),))

    self.checkScriptRaisesRegex(test_bool_cast_tensor, (torch.tensor([1, 1]),), Exception,
                                "Boolean value of Tensor with more than one value")

    def test_not_cast(x):
        if not x:
            return 1
        else:
            return 0

    self.checkScript(test_not_cast, (torch.tensor(1),))
    self.checkScript(test_not_cast, (torch.tensor(0),))

    with self.assertRaisesRegex(RuntimeError, r"Could not cast value of type Tuple\[Tensor, Tensor\]"):  # noqa: W605
        @torch.jit.script
        def test_mult(x, y):
            return not (x, y)

    def test_cast_int(x):
        # type: (int) -> int
        if x:
            return 1
        else:
            return 0
    self.checkScript(test_cast_int, (1,))
    self.checkScript(test_cast_int, (0,))
    self.checkScript(test_cast_int, (-1,))

    def test_cast_float(x):
        # type: (float) -> int
        if x:
            return 1
        else:
            return 0
    self.checkScript(test_cast_float, (1.,))
    self.checkScript(test_cast_float, (0.,))
    self.checkScript(test_cast_float, (-1.,))

    with self.assertRaisesRegex(RuntimeError, r"Could not cast value of type Tuple\[int, int\] to bool"):  # noqa: W605

        @torch.jit.script
        def test_bad_conditional(x):
            if (1, 2):  # noqa: F634
                return
            else:
                return 0

def test_while_nonexistent_value(self):
    # Using an undefined name inside a while body is a compile error.
    with self.assertRaisesRegex(RuntimeError, "undefined value x"):
        torch.jit.CompilationUnit('''
        def test_while(a, b):
            while bool(a < 10):
                a = a + x
                b = b + 1
            return a + b
        ''')

def test_while_nonexistent_cond_value(self):
    # Using an undefined name in a while condition is a compile error.
    with self.assertRaisesRegex(RuntimeError, "undefined value x"):
        torch.jit.CompilationUnit('''
        def test_while(a, b):
            while a < x:
                a = a + 1
                b = b + 1
            return a + b
        ''')

    # NOTE(review): the Optional-refinement checks below are unrelated to the
    # while-loop case above; consider moving them into their own test.
    @torch.jit.script
    def test_ternary(x):
        # type: (Optional[int]) -> int
        x = x if x is not None else 2
        return x

    @torch.jit.script
    def test_not_none(x):
        # type: (Optional[int]) -> None
        if x is not None:
            print(x + 1)

    @torch.jit.script
    def test_and(x, y):
        # type: (Optional[int], Optional[int]) -> None
        if x is not None and y is not None:
            print(x + y)

    @torch.jit.script
    def test_not(x, y):
        # type: (Optional[int], Optional[int]) -> None
        if not (x is not None and y is not None):
            pass
        else:
            print(x + y)

    @torch.jit.script
    def test_bool_expression(x):
        # type: (Optional[int]) -> None
        if x is not None and x < 2:
            print(x + 1)

    @torch.jit.script
    def test_nested_bool_expression(x, y):
        # type: (Optional[int], Optional[int]) -> int
        if x is not None and x < 2 and y is not None:
            x = x + y
        else:
            x = 5
        return x + 2

    @torch.jit.script
    def test_or(x, y):
        # type: (Optional[int], Optional[int]) -> None
        if y is None or x is None:
            pass
        else:
            print(x + y)

    # backwards compatibility
    @torch.jit.script
    def test_manual_unwrap_opt(x):
        # type: (Optional[int]) -> int
        if x is None:
            x = 1
        else:
            x = torch.jit._unwrap_optional(x)
        return x  # noqa: T484

    with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
        @torch.jit.script
        def or_error(x, y):
            # type: (Optional[int], Optional[int]) -> None
            if x is None or y is None:
                print(x + y)  # noqa: T484

    with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
        @torch.jit.script
        def and_error(x, y):
            # type: (Optional[int], Optional[int]) -> None
            if x is None and y is None:
                pass
            else:
                print(x + y)  # noqa: T484

    with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
        @torch.jit.script
        def named_var(x):
            # type: (Optional[int]) -> None
            x_none = x is not None
            if x_none:
                print(x + 1)  # noqa: T484

    with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
        @torch.jit.script
        def named_var_and(x, y):
            # type: (Optional[int], Optional[int]) -> None
            x_none = x is not None
            if y is not None and x_none:
                print(x + y)  # noqa: T484

def test_assertion_optional_refinement(self):
    # An `assert ... is not None` refines both Optionals for the following
    # statement. Passing None must fail at runtime.
    @torch.jit.script
    def test(x, y):
        # type: (Optional[int], Optional[int]) -> int
        assert x is not None and y is not None
        return x + y

    self.assertEqual(test(2, 2), 4)
    with self.assertRaisesRegex(Exception, ""):
        test(1, None)

@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "the current version of Profiler doesn't profile/specialize Optionals")
def test_optional_tensor(self):
    # The legacy executor specializes Optional[Tensor] inputs on the value seen.
    @torch.jit.script
    def fn(x, y):
        # type: (Optional[Tensor], int) -> int
        if x is None:
            return y
        else:
            return 0

    res = fn(None, 1)
    self.assertEqual(res, 1)
    g = torch.jit.last_executed_optimized_graph()
    first_input = next(g.inputs())
    # check if input is disconnected
    self.assertEqual(first_input.type().kind(), 'OptionalType')
    self.assertEqual(first_input.uses(), [])
    t = torch.ones(1)
    res = fn(t, 1)
    self.assertEqual(res, 0)
    g = torch.jit.last_executed_optimized_graph()
    self.assertEqual(next(g.inputs()).type().kind(), 'TensorType')

    @torch.jit.script
    def fn(x, y, b):
        # type: (Optional[Tensor], Tensor, bool) -> Tensor
        if b:
            res = y
        else:
            res = torch.jit._unwrap_optional(x)
        return res

    t2 = torch.zeros(1)
    res = fn(t, t2, True)
    self.assertEqual(res, t2)
    with self.assertRaisesRegex(RuntimeError, "Unwrapping null optional"):
        res = fn(None, t2, False)
    res = fn(None, t2, True)
    g = torch.jit.last_executed_optimized_graph()
    self.assertIn(next(g.outputs()).type().str(), ("Tensor", "Tensor(requires_grad=1)"))

# Head of test_optional_list; the rest of this test continues past this chunk.
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "the current version of Profiler doesn't profile/specialize Optionals")
def test_optional_list(self):
    @torch.jit.script
    def fn(x, y):
        # type: (Optional[List[int]], int) -> int
        if x is None:
            return y
        else:
            res = 0
            for d in x:
                res += d
            return res
6313 6314 res = fn(None, 1) 6315 self.assertEqual(res, 1) 6316 g = torch.jit.last_executed_optimized_graph() 6317 first_input = next(g.inputs()) 6318 # check if input is disconnected 6319 self.assertEqual(first_input.type().kind(), 'OptionalType') 6320 self.assertEqual(first_input.uses(), []) 6321 l = [2, 3] 6322 res = fn(l, 1) 6323 self.assertEqual(res, 5) 6324 g = torch.jit.last_executed_optimized_graph() 6325 self.assertEqual(next(g.inputs()).type().kind(), 'ListType') 6326 6327 @torch.jit.script 6328 def fn(x, y, b): 6329 # type: (Optional[List[int]], List[int], bool) -> List[int] 6330 if b: 6331 l = torch.jit._unwrap_optional(x) 6332 else: 6333 l = y 6334 return l 6335 6336 l2 = [0, 1] 6337 res = fn(l, l2, True) 6338 self.assertEqual(res, l) 6339 with self.assertRaisesRegex(RuntimeError, "Unwrapping null optional"): 6340 res = fn(None, l2, True) 6341 res = fn(None, l2, False) 6342 g = torch.jit.last_executed_optimized_graph() 6343 self.assertEqual(next(g.outputs()).type().str(), "int[]") 6344 6345 def test_alias_covariant_type_containers(self): 6346 @torch.jit.script 6347 def foo(x): 6348 # type: (bool) 6349 if x: 6350 a = (None,) 6351 else: 6352 a = ([],) 6353 return a 6354 6355 @torch.jit.script 6356 def foo2(x, li): 6357 # type: (bool, Tuple[Optional[List[Tensor]]]) 6358 if x: 6359 li = (None,) 6360 return li 6361 6362 def test_while_write_outer_then_read(self): 6363 def func(a, b): 6364 while bool(a < 10): 6365 a = a + 1 6366 b = a + 1 6367 return a + b 6368 6369 inputs = self._make_scalar_vars([42, 1337], torch.int64) 6370 self.checkScript(func, inputs, optimize=True) 6371 6372 @skipIfTorchDynamo("TorchDynamo fails with unknown reason") 6373 def test_while_nest_if(self): 6374 def func(a, b): 6375 # type: (int, int) -> int 6376 c = 0 6377 while a < 10: 6378 a = a + 1 6379 b = b + 1 6380 if a > b: 6381 c = -a 6382 else: 6383 c = -b 6384 return c + 1 6385 6386 inputs = self._make_scalar_vars([-1234, 4321], torch.int64) 6387 self.checkScript(func, inputs, 
optimize=True) 6388 6389 def test_divmod(self): 6390 def func_int(a, b): 6391 # type: (int, int) -> Tuple[int, int] 6392 return divmod(a, b) 6393 6394 def func_float(a, b): 6395 # type: (float, float) -> Tuple[float, float] 6396 return divmod(a, b) 6397 6398 def func_int_float(a, b): 6399 # type: (int, float) -> Tuple[float, float] 6400 return divmod(a, b) 6401 6402 def func_float_int(a, b): 6403 # type: (float, int) -> Tuple[float, float] 6404 return divmod(a, b) 6405 6406 def divmod_test_iterator(func, num, den): 6407 for i in num: 6408 for j in den: 6409 self.checkScript(func, (i, j), frames_up=2) 6410 6411 num_int = [1024, -1024] 6412 den_int = [10, -10] 6413 num_float = [5.3, -5.3] 6414 den_float = [2.0, -2.0] 6415 divmod_test_iterator(func_int, num_int, den_int) 6416 divmod_test_iterator(func_float, num_float, den_float) 6417 divmod_test_iterator(func_int_float, num_int, den_float) 6418 divmod_test_iterator(func_float_int, num_float, den_int) 6419 6420 with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError: integer division or modulo by zero"): 6421 cu = torch.jit.CompilationUnit(dedent(inspect.getsource(func_int))) 6422 cu.func_int(1024, 0) 6423 with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError: float divmod()"): 6424 cu = torch.jit.CompilationUnit(dedent(inspect.getsource(func_float))) 6425 cu.func_float(5.3, 0.0) 6426 with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError: float divmod()"): 6427 cu = torch.jit.CompilationUnit(dedent(inspect.getsource(func_int_float))) 6428 cu.func_int_float(1024, 0.0) 6429 with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError: float divmod()"): 6430 cu = torch.jit.CompilationUnit(dedent(inspect.getsource(func_float_int))) 6431 cu.func_float_int(5.3, 0) 6432 6433 @skipIfTorchDynamo("Not a TorchDynamo suitable test") 6434 def test_math_ops(self): 6435 def checkMathWrap(func_name, num_args=1, is_float=True, **args): 6436 if is_float: 6437 checkMath(func_name, num_args, True, **args) 6438 
checkMath(func_name, num_args, False, **args) 6439 else: 6440 checkMath(func_name, num_args, is_float, **args) 6441 6442 inf = float("inf") 6443 NaN = float("nan") 6444 mx_int = 2**31 - 1 6445 mn_int = -2**31 6446 float_vals = ([inf, NaN, 0.0, 1.0, 2.2, -1.0, -0.0, -2.2, -inf, 1, 0, 2] + 6447 [10.0 ** i for i in range(5)] + [-(10.0 ** i) for i in range(5)]) 6448 int_vals = list(range(-5, 5, 1)) + [mx_int + 5, mx_int * 2, mn_int - 5, mn_int * 2] 6449 6450 def checkMath(func_name, num_args, is_float=True, ret_type="float", debug=False, vals=None, args_type=None): 6451 funcs_template = dedent(''' 6452 def func(a, b): 6453 # type: {args_type} -> {ret_type} 6454 return math.{func}({args}) 6455 ''') 6456 if num_args == 1: 6457 args = "a" 6458 elif num_args == 2: 6459 args = "a, b" 6460 else: 6461 raise RuntimeError("Test doesn't support more than 2 arguments") 6462 if args_type is None: 6463 args_type = "(float, float)" if is_float else "(int, int)" 6464 funcs_str = funcs_template.format(func=func_name, args=args, args_type=args_type, ret_type=ret_type) 6465 scope = {} 6466 execWrapper(funcs_str, globals(), scope) 6467 cu = torch.jit.CompilationUnit(funcs_str) 6468 f_script = cu.func 6469 f = scope['func'] 6470 6471 if vals is None: 6472 vals = float_vals if is_float else int_vals 6473 vals = [(i, j) for i in vals for j in vals] 6474 6475 for a, b in vals: 6476 res_python = None 6477 res_script = None 6478 try: 6479 res_python = f(a, b) 6480 except Exception as e: 6481 res_python = e 6482 try: 6483 res_script = f_script(a, b) 6484 except Exception as e: 6485 res_script = e 6486 if debug: 6487 print("in: ", a, b) 6488 print("out: ", res_python, res_script) 6489 # We can't use assertEqual because of a couple of differences: 6490 # 1. nan == nan should return true 6491 # 2. When python functions throw an exception, we usually want to silently ignore them. 
6492 # (ie: We want to return `nan` for math.sqrt(-5)) 6493 if res_python != res_script: 6494 if isinstance(res_python, Exception): 6495 continue 6496 6497 if type(res_python) == type(res_script): 6498 if isinstance(res_python, tuple) and (math.isnan(res_python[0]) == math.isnan(res_script[0])): 6499 continue 6500 if isinstance(res_python, float) and math.isnan(res_python) and math.isnan(res_script): 6501 continue 6502 msg = (f"Failed on {func_name} with inputs {a} {b}. Python: {res_python}, Script: {res_script}") 6503 # math.pow() behavior has changed in 3.11, see https://docs.python.org/3/library/math.html#math.pow 6504 if sys.version_info >= (3, 11) and func_name == "pow" and a == 0.0 and b == -math.inf: 6505 self.assertTrue(res_python == math.inf and type(res_script) is RuntimeError) 6506 else: 6507 self.assertEqual(res_python, res_script, msg=msg, atol=(1e-4) * max(abs(res_python), res_script), rtol=0) 6508 6509 unary_float_ops = ["log", "log1p", "log10", "exp", "sqrt", "gamma", "lgamma", "erf", 6510 "erfc", "expm1", "fabs", "acos", "asin", "atan", "cos", "sin", "tan", 6511 "asinh", "atanh", "acosh", "sinh", "cosh", "tanh", "degrees", "radians"] 6512 binary_float_ops = ["atan2", "fmod", "copysign"] 6513 for op in unary_float_ops: 6514 checkMathWrap(op, 1) 6515 for op in binary_float_ops: 6516 checkMathWrap(op, 2) 6517 6518 checkMath("modf", 1, ret_type="Tuple[float, float]") 6519 checkMath("frexp", 1, ret_type="Tuple[float, int]") 6520 checkMath("isnan", 1, ret_type="bool") 6521 checkMath("isinf", 1, ret_type="bool") 6522 checkMath("ldexp", 2, is_float=False, ret_type="float", args_type="(float, int)", 6523 vals=[(i, j) for i in float_vals for j in range(-10, 10)]) 6524 checkMath("pow", 2, is_float=False, ret_type="float") 6525 checkMath("pow", 2, is_float=True, ret_type="float") 6526 checkMathWrap("floor", ret_type="int") 6527 checkMathWrap("ceil", ret_type="int") 6528 checkMathWrap("gcd", 2, is_float=False, ret_type="int") 6529 checkMath("isfinite", 1, 
ret_type="bool") 6530 checkMathWrap("remainder", 2) 6531 checkMathWrap("factorial", 1, is_float=False, ret_type="int", vals=[(i, 0) for i in range(-2, 10)]) 6532 6533 @skipIfTorchDynamo("TorchDynamo fails with unknown reason") 6534 def test_if_nest_while(self): 6535 def func(a, b): 6536 # type: (int, int) -> int 6537 c = 0 6538 if a > b: 6539 while a > b: 6540 b = b + 1 6541 c = -b 6542 return c 6543 6544 inputs = self._make_scalar_vars([4321, 1234], torch.int64) 6545 self.checkScript(func, inputs) 6546 6547 def test_script_optional_none(self): 6548 def none_stmt(x): 6549 output = None 6550 output = x 6551 return output 6552 6553 def none_args(x): 6554 # type: (Optional[Tensor]) -> Optional[Tensor] 6555 return None 6556 6557 self.checkScript(none_stmt, [torch.arange(0, 2)], optimize=True) 6558 self.checkScript(none_args, [None], optimize=True) 6559 6560 # test undefined tensor None as default param 6561 def test_script_optional_tensor_none(x=None): 6562 # type: (Optional[Tensor]) -> Tensor 6563 res = torch.zeros(1, dtype=torch.int8) 6564 if x is None: 6565 res = res + 1 6566 else: 6567 res = x 6568 return res 6569 6570 fn = test_script_optional_tensor_none 6571 scripted_fn = torch.jit.script(fn) 6572 self.assertEqual(fn(), scripted_fn()) 6573 self.assertEqual(fn(torch.zeros(1)), scripted_fn(torch.zeros(1))) 6574 6575 # test typical None as default param 6576 def test_script_optional_other_none(x=None): 6577 # type: (Optional[float]) -> float 6578 res = 2.0 6579 if x is None: 6580 res = res + 1.0 6581 else: 6582 res = x 6583 return res 6584 6585 fn = test_script_optional_other_none 6586 scripted_fn = torch.jit.script(fn) 6587 self.assertEqual(fn(), scripted_fn()) 6588 self.assertEqual(fn(1.0), scripted_fn(1.0)) 6589 6590 def test_script_clamp_none(self): 6591 def test_script_clamp_max_none(x): 6592 return torch.clamp(x, min=2, max=None) 6593 6594 def test_script_clamp_max(x): 6595 return torch.clamp(x, max=2) 6596 6597 def test_script_clamp_min_none(x): 6598 return 
torch.clamp(x, min=None, max=2) 6599 6600 def test_script_clamp_min(x): 6601 return torch.clamp(x, min=2) 6602 6603 input = [torch.arange(0, 3)] 6604 self.checkScript(test_script_clamp_max_none, input, optimize=True) 6605 self.checkScript(test_script_clamp_max, input, optimize=True) 6606 self.checkScript(test_script_clamp_min_none, input, optimize=True) 6607 self.checkScript(test_script_clamp_min, input, optimize=True) 6608 6609 def test_script_bool_constant(self): 6610 def test_script_bool_constant(): 6611 a = True 6612 return a 6613 self.checkScript(test_script_bool_constant, []) 6614 6615 def test_ternary(self): 6616 def func(a, b): 6617 c = 3 6618 c = a + b if bool(a > 3) else b 6619 return c 6620 6621 inputs_true = self._make_scalar_vars([5, 2], torch.int64) 6622 inputs_false = self._make_scalar_vars([1, 0], torch.int64) 6623 self.checkScript(func, inputs_true, optimize=True) 6624 self.checkScript(func, inputs_false, optimize=True) 6625 6626 def test_ternary_module_type_hint(self): 6627 class M1(torch.nn.Module): 6628 def forward(self) -> Any: 6629 return 'out' if self.training else {} 6630 6631 class M2(torch.nn.Module): 6632 def forward(self) -> Any: 6633 out: Any = 'out' if self.training else {} 6634 return out 6635 6636 class M3(torch.nn.Module): 6637 def forward(self) -> Optional[int]: 6638 return None if self.training else 1 6639 6640 for module in [M1, M2, M3]: 6641 self.checkModule(module().train(), ()) 6642 self.checkModule(module().eval(), ()) 6643 6644 def test_ternary_static_if(self): 6645 # Test for True branch when condition variable 6646 # is annotated as Final 6647 class M1(torch.nn.Module): 6648 flag: torch.jit.Final[bool] 6649 6650 def __init__(self) -> None: 6651 super().__init__() 6652 self.flag = True 6653 6654 def forward(self) -> torch.Tensor: 6655 return torch.ones(3) if self.flag else {} 6656 6657 # Test for True branch when condition variable 6658 # is annotated as Final 6659 class M2(torch.nn.Module): 6660 flag: torch.jit.Final[bool] 
6661 6662 def __init__(self) -> None: 6663 super().__init__() 6664 self.flag = False 6665 6666 def forward(self) -> torch.Tensor: 6667 return {} if self.flag else torch.ones(3) 6668 6669 model1 = M1() 6670 model2 = M2() 6671 script_model_1 = torch.jit.script(model1) 6672 script_model_2 = torch.jit.script(model2) 6673 self.assertEqual(model1.forward(), script_model_1.forward()) 6674 self.assertEqual(model2.forward(), script_model_2.forward()) 6675 6676 def test_ternary_right_associative(self): 6677 def plus_123(x: int): 6678 return x + 1 if x == 1 else x + 2 if x == 2 else x + 3 6679 self.checkScript(plus_123, (1,)) 6680 self.checkScript(plus_123, (2,)) 6681 self.checkScript(plus_123, (3,)) 6682 6683 @skipIfTorchDynamo("TorchDynamo fails with unknown reason") 6684 def test_print(self): 6685 def func(x, y): 6686 q = (x + y).sigmoid() 6687 print(q, 1, 2, [1, 2], [1.0, 2.0]) 6688 w = -q 6689 return w * w 6690 6691 x = torch.arange(4., requires_grad=True) 6692 y = torch.arange(0., 8, 2, requires_grad=True) 6693 self.checkScript(func, [x, y], optimize=True, capture_output=True) 6694 6695 def test_format(self): 6696 def func(x): 6697 print("{}, I'm a {}".format("Hello", "test")) 6698 print("format blank".format()) 6699 print("stuff before {}".format("hi")) 6700 print("{} stuff after".format("hi")) 6701 return x + 1 6702 6703 x = torch.arange(4., requires_grad=True) 6704 self.checkScript(func, [x], optimize=True, capture_output=True) 6705 6706 def test_logical_short_circuit(self): 6707 @torch.jit.script 6708 def testNoThrows(t): 6709 c1 = 1 6710 if (False and bool(t[1])) or (True or bool(t[1])): 6711 c1 = 0 6712 return c1 6713 6714 FileCheck().check_not("prim::If").run(testNoThrows.graph) 6715 self.assertEqual(0, testNoThrows(torch.randn(0))) 6716 self.assertEqual(0, testNoThrows(torch.randn([2, 3]))) 6717 6718 @torch.jit.script 6719 def throwsOr(t): 6720 c0 = False or bool(t[1]) 6721 print(c0) 6722 6723 @torch.jit.script 6724 def throwsAnd(t): 6725 c0 = True and 
bool(t[1]) 6726 print(c0) 6727 6728 t = torch.randn(0) 6729 with self.assertRaisesRegex(RuntimeError, "index 1 out of range for tensor of size"): 6730 throwsOr(t) 6731 with self.assertRaisesRegex(RuntimeError, "index 1 out of range for tensor of size"): 6732 throwsAnd(t) 6733 6734 def test_type_cast(self): 6735 template = dedent(''' 6736 def func(v): 6737 # type: ({from_type}) -> {to_type} 6738 return {to_type}(v) 6739 ''') 6740 6741 def check_cast(from_type, to_type, value, raises=False): 6742 code = template.format(from_type=from_type, to_type=to_type) 6743 self.checkScript(code, (value,)) 6744 6745 check_cast('int', 'float', 1) 6746 check_cast('int', 'bool', 1) 6747 check_cast('int', 'bool', 0) 6748 6749 check_cast('float', 'int', 1.) 6750 check_cast('float', 'bool', 1.) 6751 check_cast('float', 'bool', 0.) 6752 6753 check_cast('bool', 'int', True) 6754 check_cast('bool', 'float', True) 6755 6756 def test_multiple_assignment(self): 6757 def outer_func(x): 6758 return x * 2, x + 2 6759 6760 @torch.jit.script 6761 def func(x): 6762 y, z = outer_func(x) 6763 return y + z 6764 6765 x = torch.arange(4) 6766 self.assertEqual(func(x), x * 2 + x + 2) 6767 6768 def test_literals(self): 6769 def func(a): 6770 return a.view(size=[1, 2, 3]) 6771 6772 a = torch.randn(6) 6773 self.checkScript(func, [a], optimize=True) 6774 6775 def test_return(self): 6776 def no_return(a): 6777 a + 1 6778 6779 def void_return(a): 6780 return 6781 6782 def one_return(a): 6783 return a + 1. 6784 6785 def multiple_returns(a): 6786 return a * 1., a * 2., a * 3. 
6787 6788 a = torch.randn(1, dtype=torch.float) 6789 self.checkScript(no_return, [a], optimize=True) 6790 self.checkScript(void_return, [a], optimize=True) 6791 self.checkScript(one_return, [a], optimize=True) 6792 self.checkScript(multiple_returns, [a], optimize=True) 6793 6794 with self.assertRaisesRegex(RuntimeError, "does not return along all paths"): 6795 torch.jit.CompilationUnit(''' 6796 def no_return_bad_annotation(a): 6797 # type: (Tensor) -> Tensor 6798 a + 1 6799 ''') 6800 6801 def test_error(self): 6802 @torch.jit.script 6803 def foo(a): 6804 return a.t() 6805 s = Variable(torch.rand(5, 5, 5)) 6806 # XXX: this should stay quiet in stay propagation and only fail in the interpreter 6807 with self.assertRaisesRegex(RuntimeError, "failed in the TorchScript interpreter"): 6808 foo(s) 6809 6810 @torch.jit.script 6811 def bar(c, b): 6812 return c + b 6813 6814 with self.assertRaisesRegex(RuntimeError, "failed in the TorchScript interpreter"): 6815 bar(Variable(torch.rand(10), requires_grad=True), Variable(torch.rand(9), requires_grad=True)) 6816 6817 def test_error_stacktrace(self): 6818 @torch.jit.script 6819 def baz(c, b): 6820 return c + b 6821 6822 @torch.jit.script 6823 def foo(c, b): 6824 return baz(c, b) 6825 6826 @torch.jit.script 6827 def bar(c, b): 6828 return foo(c, b) 6829 6830 with self.assertRaises(RuntimeError) as cm: 6831 bar(torch.rand(10), torch.rand(9)) 6832 FileCheck().check("The following operation failed in the TorchScript interpreter") \ 6833 .check("Traceback") \ 6834 .check("in foo").check("in baz").run(str(cm.exception)) 6835 6836 def test_error_stacktrace_interface(self): 6837 @torch.jit.script 6838 def baz(c, b): 6839 return c + b 6840 6841 @torch.jit.script 6842 def foo(c, b): 6843 return baz(c, b) 6844 6845 @torch.jit.script 6846 def bar(c, b): 6847 return foo(c, b) 6848 6849 @torch.jit.script 6850 class Bar: 6851 def one(self, x, y): 6852 return bar(x, y) 6853 6854 @torch.jit.interface 6855 class IFace: 6856 def one(self, x, y): 
6857 # type: (Tensor, Tensor) -> Tensor 6858 pass 6859 6860 make_global(IFace) 6861 6862 @torch.jit.script 6863 def as_interface(x): 6864 # type: (IFace) -> IFace 6865 return x 6866 6867 f = as_interface(Bar()) 6868 6869 with self.assertRaises(RuntimeError) as cm: 6870 x = f.one(torch.rand(10), torch.rand(9)) 6871 bar(torch.rand(10), torch.rand(9)) 6872 FileCheck().check("The following operation failed in the TorchScript interpreter") \ 6873 .check("Traceback") \ 6874 .check("in foo").check("in baz").run(str(cm.exception)) 6875 6876 def test_operator_precedence(self): 6877 def double(x): 6878 # type: (int) -> int 6879 return 2 * x 6880 6881 def complicated_arithmetic_operation(): 6882 # TODO we need to test exponent operator '**' and bitwise not 6883 # operator '~' once they are properly supported. 6884 list = [0, 1, 2, 3] 6885 result = list[1:3][0] + double(4) + (-3 + 8) * 6 // 2 % 4 << 2 + 1 >> 1 | 23 & 16 + 3 ^ 4 6886 return result 6887 6888 self.checkScript(complicated_arithmetic_operation, ()) 6889 6890 def test_in_operator_with_two_strings(self): 6891 def fn() -> bool: 6892 return "a" in "abcd" 6893 self.checkScript(fn, ()) 6894 6895 def test_bitwise_ops(self): 6896 6897 def int_test(): 6898 return 2 & 3, 2 ^ 3, 2 | 3, 2 << 3, 2 >> 3 6899 6900 self.checkScript(int_test, ()) 6901 6902 def bool_test(x, y): 6903 # type: (bool, bool) -> Tuple[bool, bool, bool] 6904 return x & y, x ^ y, x | y 6905 6906 self.checkScript(bool_test, (True, False)) 6907 self.checkScript(bool_test, (True, True)) 6908 6909 def tensor_test(x, y): 6910 return x & y, x ^ y, x | y 6911 6912 def tensor_with_int_test(x, y): 6913 # type: (Tensor, int) -> Tuple[Tensor, Tensor] 6914 return x << y, x >> y 6915 6916 x = torch.tensor(2) 6917 y = torch.tensor(3) 6918 6919 self.checkScript(tensor_test, (x, y)) 6920 self.checkScript(tensor_with_int_test, (x, 2)) 6921 6922 def not_test(x): 6923 return ~x 6924 6925 self.checkScript(not_test, (torch.tensor([2, 4]), )) 6926 6927 def test_all(self): 6928 
@torch.jit.script 6929 def test_all_tensor(x): 6930 return all(x) 6931 self.assertFalse(test_all_tensor(torch.tensor([1, 0, 3], dtype=torch.uint8))) 6932 self.assertTrue(test_all_tensor(torch.tensor([3.14, 3, 99], dtype=torch.uint8))) 6933 self.assertTrue(test_all_tensor(torch.tensor([True, True], dtype=torch.uint8))) 6934 self.assertFalse(test_all_tensor(torch.tensor([True, False], dtype=torch.uint8))) 6935 6936 @torch.jit.script 6937 def test_all_bool_list(x): 6938 # type: (List[bool]) -> bool 6939 return all(x) 6940 self.assertTrue(test_all_bool_list([True, True])) 6941 self.assertTrue(test_all_bool_list([True, 1])) 6942 self.assertFalse(test_all_bool_list([True, False])) 6943 self.assertFalse(test_all_bool_list([True, 0])) 6944 self.assertFalse(test_all_bool_list([False, 0])) 6945 self.assertTrue(test_all_bool_list([])) 6946 6947 @torch.jit.script 6948 def test_all_int_list(x): 6949 # type: (List[int]) -> bool 6950 return all(x) 6951 self.assertTrue(test_all_int_list([3, 6])) 6952 self.assertFalse(test_all_int_list([2, 0])) 6953 6954 @torch.jit.script 6955 def test_all_float_list(x): 6956 # type: (List[float]) -> bool 6957 return all(x) 6958 self.assertTrue(test_all_float_list([3.14, 8.1])) 6959 self.assertFalse(test_all_float_list([3.14, 0, 8.9])) 6960 6961 6962 @skipIfTorchDynamo("Not a TorchDynamo suitable test") 6963 def test_number_math(self): 6964 ops_template = dedent(''' 6965 def func(): 6966 return {scalar1} {op} {scalar2} 6967 ''') 6968 ops = ['+', '-', '*', '%', '<', '<=', '>', '>=', '==', '!=', '//'] 6969 funcs_template = dedent(''' 6970 def func(): 6971 return {func}({scalar1}, {scalar2}) 6972 ''') 6973 funcs = ['min', 'max'] 6974 scalars = ['7', '2', '3', '-3', '3.14', '0.125', '-0.5', '2.0', '-2.0'] 6975 scalar_pairs = [(scalar1, scalar2) for scalar1 in scalars for scalar2 in scalars] 6976 6977 def run_test(code): 6978 scope = {} 6979 execWrapper(code, globals(), scope) 6980 cu = torch.jit.CompilationUnit(code) 6981 6982 
self.assertEqual(cu.func(), scope['func']()) 6983 6984 for scalar1, scalar2 in scalar_pairs: 6985 for op in ops: 6986 code = ops_template.format(op=op, scalar1=scalar1, scalar2=scalar2) 6987 run_test(code) 6988 for func in funcs: 6989 code = funcs_template.format(func=func, scalar1=scalar1, scalar2=scalar2) 6990 run_test(code) 6991 6992 # test Scalar overloads 6993 for scalar1, scalar2 in scalar_pairs: 6994 item1 = 'torch.tensor(' + scalar1 + ').item()' 6995 item2 = 'torch.tensor(' + scalar2 + ').item()' 6996 for op in ops: 6997 code = ops_template.format(op=op, scalar1=item1, scalar2=scalar2) 6998 run_test(code) 6999 code = ops_template.format(op=op, scalar1=scalar1, scalar2=item2) 7000 run_test(code) 7001 code = ops_template.format(op=op, scalar1=item1, scalar2=item2) 7002 run_test(code) 7003 for func in funcs: 7004 code = funcs_template.format(func=func, scalar1=item1, scalar2=scalar2) 7005 run_test(code) 7006 code = funcs_template.format(func=func, scalar1=scalar1, scalar2=item2) 7007 run_test(code) 7008 code = funcs_template.format(func=func, scalar1=item1, scalar2=item2) 7009 run_test(code) 7010 7011 def test_number_abs(self): 7012 def func1(x): 7013 # type: (float) -> float 7014 return abs(x) 7015 7016 def func2(x): 7017 # type: (int) -> int 7018 return abs(x) 7019 7020 def func3(x): 7021 return abs(x) 7022 7023 self.checkScript(func1, (-3.14,)) 7024 self.checkScript(func1, (3.14,)) 7025 self.checkScript(func2, (-10,)) 7026 self.checkScript(func2, (10,)) 7027 self.checkScript(func3, (torch.tensor([-5, -10, -20]),)) 7028 self.checkScript(func3, (torch.tensor([5, 10, 20]),)) 7029 self.checkScript(func3, (torch.tensor([-5, 10, -20]),)) 7030 7031 def test_number_div(self): 7032 self.assertEqual(div_int_future(), torch.jit.script(div_int_future)()) 7033 self.checkScript(div_float_future, ()) 7034 7035 self.checkScript(div_int_nofuture, ()) 7036 self.checkScript(div_float_nofuture, ()) 7037 7038 # Testing bitwise shorthand aug assignment 7039 def 
test_bool_augassign_bitwise_or(self): 7040 def func(a: bool, b: bool) -> bool: 7041 a |= b 7042 return a 7043 7044 self.checkScript(func, (True, False), optimize=True) 7045 self.checkScript(func, (True, True), optimize=True) 7046 self.checkScript(func, (False, False), optimize=True) 7047 self.checkScript(func, (False, True), optimize=True) 7048 7049 def test_bool_augassign_bitwise_and(self): 7050 def func(a: bool, b: bool) -> bool: 7051 a &= b 7052 return a 7053 7054 self.checkScript(func, (True, False), optimize=True) 7055 self.checkScript(func, (True, True), optimize=True) 7056 self.checkScript(func, (False, False), optimize=True) 7057 self.checkScript(func, (False, True), optimize=True) 7058 7059 def test_bool_augassign_bitwise_xor(self): 7060 def func(a: bool, b: bool) -> bool: 7061 a ^= b 7062 return a 7063 7064 self.checkScript(func, (True, False), optimize=True) 7065 self.checkScript(func, (True, True), optimize=True) 7066 self.checkScript(func, (False, False), optimize=True) 7067 self.checkScript(func, (False, True), optimize=True) 7068 7069 def test_number_augassign_bitwise_lshift(self): 7070 def func() -> int: 7071 z = 8 7072 z <<= 2 7073 return z 7074 7075 self.checkScript(func, (), optimize=True) 7076 7077 def test_number_augassign_bitwise_rshift(self): 7078 def func() -> int: 7079 z = 8 7080 z >>= 2 7081 return z 7082 7083 self.checkScript(func, (), optimize=True) 7084 7085 def test_number_augassign_bitwise_pow(self): 7086 def func() -> float: 7087 z = 8 7088 z **= 2 7089 return z 7090 7091 self.checkScript(func, (), optimize=True) 7092 7093 def test_number_augassign(self): 7094 def func(): 7095 z = 1 7096 z += 2 7097 return z 7098 7099 self.checkScript(func, (), optimize=True) 7100 7101 def test_nested_select_assign(self): 7102 class SubSubModule(torch.nn.Module): 7103 def __init__(self) -> None: 7104 super().__init__() 7105 self.abc = 11 7106 7107 def forward(self, x): 7108 return self.abc 7109 7110 class SubModule(torch.nn.Module): 7111 def 
__init__(self) -> None: 7112 super().__init__() 7113 self.a = 11 7114 self.nested = SubSubModule() 7115 7116 def forward(self, x): 7117 return self.a 7118 7119 class TestModule(torch.nn.Module): 7120 def __init__(self) -> None: 7121 super().__init__() 7122 self.sub = SubModule() 7123 self.hi = 1 7124 7125 def forward(self): 7126 self.hi = 5 7127 self.sub.a = 1 7128 self.sub.nested.abc = 5 7129 return self.sub.a * 20 + self.sub.nested.abc * 3 + self.hi 7130 7131 self.checkModule(TestModule(), ()) 7132 7133 def test_number_neg(self): 7134 # int -> int 7135 def func1(): 7136 return -8 7137 7138 # float -> float 7139 def func2(): 7140 return -3.14 7141 7142 self.checkScript(func1, (), optimize=True) 7143 self.checkScript(func2, (), optimize=True) 7144 7145 def test_compare_two_bool_inputs(self): 7146 def compare_eq(a: bool, b: bool): 7147 return a == b 7148 7149 def compare_ne(a: bool, b: bool): 7150 return a != b 7151 7152 scripted_fn_eq = torch.jit.script(compare_eq) 7153 scripted_fn_ne = torch.jit.script(compare_ne) 7154 self.assertEqual(scripted_fn_eq(True, False), compare_eq(True, False)) 7155 self.assertEqual(scripted_fn_eq(False, True), compare_eq(False, True)) 7156 self.assertEqual(scripted_fn_eq(True, True), compare_eq(True, True)) 7157 self.assertEqual(scripted_fn_eq(False, False), compare_eq(False, False)) 7158 7159 self.assertEqual(scripted_fn_ne(True, False), compare_ne(True, False)) 7160 self.assertEqual(scripted_fn_ne(False, True), compare_ne(False, True)) 7161 self.assertEqual(scripted_fn_ne(True, True), compare_ne(True, True)) 7162 self.assertEqual(scripted_fn_ne(False, False), compare_ne(False, False)) 7163 7164 7165 def _test_tensor_number_math(self, device='cpu'): 7166 template = dedent(''' 7167 def func(t): 7168 return {lhs} {op} {rhs} 7169 ''') 7170 7171 def test(op, tensor, const, swap_args, template=template): 7172 args = ('t', const) 7173 if swap_args: 7174 args = (const, 't') 7175 7176 code = template.format(lhs=args[0], rhs=args[1], op=op) 
7177 scope = {} 7178 execWrapper(code, globals(), scope) 7179 cu = torch.jit.CompilationUnit(code) 7180 message = f'with code `{args[0]} {op} {args[1]}` and t={tensor}' 7181 res1 = cu.func(tensor) 7182 res2 = scope['func'](tensor) 7183 self.assertEqual(res1, res2, msg=message + "\nres1=" + str(res1) + "\nres2=" + str(res2)) 7184 self.assertEqual(res1.dtype, res2.dtype, msg=message + "\nres1=" + str(res1) + "\nres2=" + str(res2)) 7185 7186 var_int = [2, -2] 7187 var_float = [1.4321, -1.2] 7188 7189 ops = ['+', '-', '*', '%', '<', '<=', '>', '>=', '==', '!=', '/'] 7190 7191 float_tensor = torch.randn(5, 5, device=device) 7192 double_tensor = torch.randn(5, 5, dtype=torch.double, device=device) 7193 long_tensor = torch.randint(-5, 5, (5, 5), dtype=torch.long, device=device) 7194 long_tensor[long_tensor == 0] = 2 7195 7196 tensors = [float_tensor, double_tensor, long_tensor] 7197 consts = var_int + var_float 7198 7199 for op, tensor, const, swap_args in product(ops, tensors, consts, [True, False]): 7200 # FIXME: things like 2 / long_tensor are not implemented correctly 7201 # Look in torch/_tensor.py to see how pytorch implements it. 
7202 if op == '/' and tensor.data_ptr() == long_tensor.data_ptr(): 7203 continue 7204 7205 # % operator does not take: const % tensor 7206 if op == '%' and swap_args is True: 7207 continue 7208 7209 test(op, tensor, const, swap_args) 7210 7211 @skipIfTorchDynamo("Not a TorchDynamo suitable test") 7212 def test_tensor_number_math(self): 7213 self._test_tensor_number_math() 7214 7215 def test_torch_tensor_bad_input(self): 7216 with self.assertRaisesRegex(RuntimeError, "must be of ints, floats, " 7217 "or bools, got None"): 7218 @torch.jit.script 7219 def test(): 7220 return torch.tensor([None]) 7221 test() 7222 7223 with self.assertRaisesRegex(RuntimeError, r"Empty lists default to List\[Tensor\]"): 7224 @torch.jit.script 7225 def tmp(): 7226 return torch.tensor([]) 7227 tmp() 7228 7229 @torch.jit.script 7230 def foo(): 7231 return torch.tensor([[2, 2], [1]]) 7232 with self.assertRaisesRegex(RuntimeError, "Expected sequence of length"): 7233 foo() 7234 7235 @suppress_warnings 7236 def test_torch_tensor_as_tensor_empty_list(self): 7237 tensor_template = dedent(''' 7238 def func(): 7239 empty_list = torch.jit.annotate(List[int], []) 7240 ten1 = torch.{tensor_op}({input}) 7241 return ten1 7242 ''') 7243 ops = ['tensor', 'as_tensor'] 7244 inputs = ['empty_list', '[empty_list, empty_list]', '[[[empty_list]]]'] 7245 7246 for op in ops: 7247 for inp in inputs: 7248 code = tensor_template.format(tensor_op=op, input=inp) 7249 scope = {} 7250 exec(code, globals(), scope) 7251 cu = torch.jit.CompilationUnit(code) 7252 t1 = cu.func() 7253 t2 = scope['func']() 7254 if inp == 'empty_list': 7255 # torchscript returns int tensor, python returns float tensor 7256 self.assertNotEqual(t1.dtype, t2.dtype) 7257 self.assertEqual(t1, t2, exact_dtype=False) 7258 self.assertEqual(t1.device, t2.device) 7259 7260 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Simple Executor doesn't have any shapes to propagate") 7261 def test_tensor_as_tensor_shape_prop(self): 7262 tensor_template 
= dedent(''' 7263 def func(): 7264 return torch.{tensor_op}({input}) 7265 ''') 7266 ops = ['tensor', 'as_tensor'] 7267 inputs = ['[1]', '[False]', '[2.5]', '0.5', '1', 'False', '[[1]]', 'torch.jit.annotate(List[List[int]], [])'] 7268 expected_shape = ["Long(*, device=cpu)", "Bool(*, device=cpu)", 7269 "Float(*, device=cpu)", "Float(device=cpu)", 7270 "Long(device=cpu)", "Bool(device=cpu)", "Long(*, *, device=cpu)"] 7271 7272 for op in ops: 7273 for inp, expect in zip(inputs, expected_shape): 7274 code = tensor_template.format(tensor_op=op, input=inp) 7275 scope = {} 7276 exec(code, globals(), scope) 7277 cu = torch.jit.CompilationUnit(code) 7278 torch._C._jit_pass_complete_shape_analysis(cu.func.graph, (), False) 7279 FileCheck().check(expect).check(f"aten::{op}").run(cu.func.graph) 7280 7281 @torch.jit.script 7282 def test_dtype(inp_dtype: torch.dtype): 7283 a = torch.tensor(1.0, dtype=torch.float, requires_grad=True) 7284 return a, torch.tensor(1.0, dtype=inp_dtype) 7285 7286 if GRAPH_EXECUTOR == ProfilingMode.PROFILING: 7287 g = test_dtype.graph_for(5, profile_and_replay=True) 7288 # both should have completed shapes 7289 FileCheck().check("Tensor = aten::tensor").check("Float(device=cpu) = prim::BailOut") \ 7290 .check("Tensor = aten::tensor").check("Half(device=cpu) = prim::BailOut").run(g) 7291 else: 7292 g = test_dtype.graph_for(5) 7293 # first should have type set second should not 7294 FileCheck().check("Float(requires_grad=1, device=cpu) = aten::tensor") \ 7295 .check("Tensor(requires_grad=0) = aten::tensor").run(g) 7296 7297 @torch.jit.script 7298 def test_as_tensor_tensor_input(input): 7299 a = torch.as_tensor(input, dtype=input.dtype) 7300 return a, torch.as_tensor(input, dtype=torch.float) 7301 7302 if GRAPH_EXECUTOR == ProfilingMode.PROFILING: 7303 g = test_as_tensor_tensor_input.graph_for(torch.ones(3, 4), profile_and_replay=True) 7304 FileCheck().check("Tensor = aten::as_tensor").check("Float(3, 4) = prim::BailOut") \ 7305 .check("Tensor = 
aten::as_tensor").check("Float(3, 4) = prim::BailOut").run(g) 7306 else: 7307 g = test_as_tensor_tensor_input.graph_for(torch.ones(3, 4)) 7308 FileCheck().check("Tensor = aten::as_tensor").check("Float(*, *, requires_grad=0, device=cpu) = aten::as_tensor").run(g) 7309 7310 @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "testing legacy behavior") 7311 def test_tensor_requires_grad(self): 7312 @torch.jit.script 7313 def test(b): 7314 # type: (bool) -> Tuple[Tensor, Tensor, Tensor] 7315 a = torch.tensor(1., requires_grad=b) 7316 b = torch.tensor(1., requires_grad=True) 7317 c = torch.tensor(1., requires_grad=False) 7318 return a, b, c 7319 7320 g = test.graph_for(True) 7321 out = next(g.outputs()) 7322 out_inp = list(out.node().inputs()) 7323 7324 self.assertTrue(out_inp[0].requires_grad()) 7325 self.assertTrue(out_inp[1].requires_grad()) 7326 self.assertFalse(out_inp[2].requires_grad()) 7327 7328 def test_grad_from_script(self): 7329 def test(): 7330 a = torch.tensor(2.5, requires_grad=True) 7331 b = a * 2 7332 return a, b 7333 7334 a, b = test() 7335 b.backward() 7336 7337 a_script, b_script = torch.jit.script(test)() 7338 b_script.backward() 7339 self.assertEqual(a.grad, a_script.grad) 7340 7341 def test_torch_tensor_as_tensor(self): 7342 tensor_template = dedent(''' 7343 def func(): 7344 li = {list_create} 7345 ten1 = torch.{tensor_op}(li {options}) 7346 return ten1 7347 ''') 7348 7349 lists = ["2.5", "4", "True", "False", "[2]", "[-.5]", "[False, True, False]", "[2, 2]", "(1, 1)", 7350 "torch.jit.annotate(List[List[int]], [])", 7351 "torch.jit.annotate(List[int], [])", "[2.5, 2.5]", "[[2], [2]]", "[[-.5], [2.2]]", "[[False], [True]]"] 7352 7353 dtypes = ["", ", dtype=torch.float", ", dtype=torch.double", ", dtype=torch.half", 7354 ", dtype=torch.uint8", ", dtype=torch.int8", ", dtype=torch.short", 7355 ", dtype=torch.int", ", dtype=torch.long", ", dtype=torch.cfloat", 7356 ", dtype=torch.cdouble"] 7357 7358 ops = ['tensor', 'as_tensor'] 7359 devices = 
['', ", device='cpu'"] 7360 if RUN_CUDA: 7361 devices.append(", device='cuda'") 7362 7363 option_pairs = [dtype + device for dtype in dtypes for device in devices] 7364 for op in ops: 7365 for li in lists: 7366 for option in option_pairs: 7367 # tensor from empty list is type float in python and annotated type in torchscript 7368 if "annotate" in li and "dtype" not in option: 7369 continue 7370 # Skip unsigned tensor initializaton for signed values on 3.10 7371 if sys.version_info[:2] >= (3, 10) and "torch.uint8" in option and "-" in li: 7372 continue 7373 code = tensor_template.format(list_create=li, tensor_op=op, options=option) 7374 scope = {} 7375 exec(code, globals(), scope) 7376 cu = torch.jit.CompilationUnit(code) 7377 t1 = cu.func() 7378 t2 = scope['func']() 7379 if t1.dtype == torch.float16: # equality NYI for half tensor 7380 self.assertTrue(str(t1) == str(t2)) 7381 else: 7382 self.assertEqual(t1, t2) 7383 self.assertEqual(t1.dtype, t2.dtype) 7384 self.assertEqual(t1.device, t2.device) 7385 7386 def test_as_tensor_tensor_input(input): 7387 # type: (Tensor) -> Tuple[Tensor, Tensor, Tensor] 7388 return torch.as_tensor(input, dtype=torch.cfloat), torch.as_tensor(input, dtype=torch.float), \ 7389 torch.as_tensor(input, dtype=torch.int32) 7390 7391 inp = torch.randn(3, 4, dtype=torch.cfloat) 7392 self.checkScript(test_as_tensor_tensor_input, (inp,)) 7393 7394 def test_torch_tensor_dtype(self): 7395 def foo(s: float): 7396 return torch.tensor(s), torch.tensor([s, s]) 7397 7398 # need to clear function cache so we re run shape analysis 7399 with set_default_dtype(torch.double): 7400 self.assertEqual(torch.jit.script(foo)(1.), foo(1.), exact_dtype=True) 7401 if GRAPH_EXECUTOR == ProfilingMode.LEGACY: 7402 FileCheck().check("Double").check_same("aten::tensor").run(torch.jit.last_executed_optimized_graph()) 7403 with set_default_dtype(torch.float): 7404 del torch.jit._state._jit_caching_layer[foo] 7405 self.assertEqual(torch.jit.script(foo)(1.), foo(1.), 
exact_dtype=True) 7406 if GRAPH_EXECUTOR == ProfilingMode.LEGACY: 7407 FileCheck().check("Float").check_same("aten::tensor").run(torch.jit.last_executed_optimized_graph()) 7408 with set_default_dtype(torch.half): 7409 del torch.jit._state._jit_caching_layer[foo] 7410 self.assertEqual(torch.jit.script(foo)(1.), foo(1.), exact_dtype=True) 7411 if GRAPH_EXECUTOR == ProfilingMode.LEGACY: 7412 FileCheck().check("Half").check_same("aten::tensor").run(torch.jit.last_executed_optimized_graph()) 7413 7414 def test_shape_analysis_grad_property(self): 7415 @torch.jit.script 7416 def foo(x): 7417 return torch.sub(x, torch.tanh(x)) 7418 7419 torch._C._jit_pass_complete_shape_analysis(foo.graph, (torch.tensor([0.39]),), False) 7420 7421 # requires_grad property shouldn't be accidentally set by shape analysis 7422 self.assertTrue(foo.graph.findNode("aten::sub").output().requiresGrad() is None) 7423 7424 def test_empty_like_memory_format_bc(self): 7425 def f(x): 7426 # type: (Tensor) -> Tensor 7427 return torch.zeros_like(x, memory_format=None) 7428 7429 scripted_f = torch.jit.script(f) 7430 x = torch.rand(3, 4) 7431 self.assertEqual(scripted_f(x), f(x)) 7432 7433 def test_multiline_string_dedents(self): 7434 def foo() -> None: 7435 multiline_string_dedent_1 = """ 7436This is a string dedent """ 7437 multiline_string_dedent_2 = """ This is a 7438 string dedent """ 7439 multiline_string_dedent_3 = """ 7440 This is a string 7441dedent """ 7442 multiline_string_dedent_4 = """ This is a string dedent """ 7443 7444 scripted_foo = torch.jit.script(foo) 7445 self.assertEqual(scripted_foo(), foo()) 7446 7447 def test_class_with_comment_at_lower_indentation(self): 7448 class Foo(torch.nn.Module): 7449 def forward(self, x): 7450 x = torch.neg(x) 7451 # This comment is at the wrong indent 7452 return x 7453 7454 torch.jit.script(Foo()) 7455 7456 # adapted from test in test_torch 7457 def test_tensor_to(self): 7458 template = dedent(''' 7459 def func(t): 7460 cuda = "{cuda}" 7461 device = 
"{device}" 7462 non_blocking = {non_blocking} 7463 return {to_str} 7464 ''') 7465 7466 def s(t, to_str, non_blocking=None, device=None, cuda=None): 7467 device = device if device is not None else str(t.device) 7468 non_blocking = non_blocking if non_blocking is not None else False 7469 cuda = "cuda" if cuda is None else cuda 7470 code = template.format(to_str=to_str, device=device, non_blocking=non_blocking, cuda=cuda) 7471 scope = {} 7472 cu = torch.jit.CompilationUnit(code) 7473 return cu.func(t, profile_and_replay=True) 7474 7475 def test_copy_behavior(t, non_blocking=False): 7476 self.assertIs(t, s(t, 't.to(t, non_blocking=non_blocking)', non_blocking)) 7477 self.assertIs(t, s(t, 't.to(t.dtype, non_blocking=non_blocking)', non_blocking)) 7478 self.assertIs(t, s(t, 't.to(torch.empty_like(t), non_blocking=non_blocking)', non_blocking)) 7479 self.assertIsNot(t, s(t, 't.to(t, non_blocking=non_blocking, copy=True)', non_blocking)) 7480 self.assertIsNot(t, s(t, 't.to(t.dtype, non_blocking=non_blocking, copy=True)', non_blocking)) 7481 self.assertIsNot(t, s(t, 't.to(torch.empty_like(t), non_blocking=non_blocking, copy=True)', non_blocking)) 7482 7483 devices = [t.device] 7484 if t.device.type == 'cuda': 7485 if t.device.index == -1: 7486 devices.append(f'cuda:{torch.cuda.current_device()}') 7487 elif t.device.index == torch.cuda.current_device(): 7488 devices.append('cuda') 7489 for device in devices: 7490 self.assertIs(t, s(t, 't.to(device, non_blocking=non_blocking)', non_blocking, device)) 7491 self.assertIs(t, s(t, 't.to(device, t.dtype, non_blocking=non_blocking)', non_blocking, device)) 7492 self.assertIsNot(t, s(t, 't.to(device, non_blocking=non_blocking, copy=True)', non_blocking, device)) 7493 self.assertIsNot(t, s(t, 't.to(device, t.dtype, non_blocking=non_blocking, copy=True)', 7494 non_blocking, device)) 7495 7496 t = torch.tensor(5) 7497 test_copy_behavior(t) 7498 7499 self.assertEqual(t.device, s(t, "t.to('cpu')").device) 7500 self.assertEqual(t.device, 
s(t, "t.to('cpu', dtype=torch.float32)").device) 7501 self.assertIs(torch.float32, s(t, "t.to('cpu', dtype=torch.float32)").dtype) 7502 self.assertEqual(t.device, s(t, "t.to(torch.float32)").device) 7503 self.assertIs(torch.float32, s(t, "t.to(dtype=torch.float32)").dtype) 7504 self.assertEqual(t.data_ptr(), s(t, "t.to('cpu')").data_ptr()) 7505 self.assertEqual(t.data_ptr(), s(t, "t.to(dtype=t.dtype, device=t.device, copy=False)").data_ptr()) 7506 self.assertEqual(t.data_ptr(), s(t, "t.to('cpu', copy=False)").data_ptr()) 7507 self.assertNotEqual(t.data_ptr(), s(t, "t.to('cpu', copy=True)").data_ptr()) 7508 7509 a = torch.tensor(5) 7510 if torch.cuda.is_available(): 7511 for non_blocking in [True, False]: 7512 for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']: 7513 b = torch.tensor(5., device=cuda) 7514 test_copy_behavior(b, non_blocking) 7515 self.assertEqual(b.device, s(b, "t.to(cuda, non_blocking=non_blocking).device", cuda=cuda)) 7516 self.assertEqual(a.device, s(b, "t.to('cpu', non_blocking=non_blocking).device")) 7517 self.assertEqual(b.device, s(b, "t.to(cuda, non_blocking=non_blocking).device", cuda=cuda)) 7518 self.assertIs(torch.int32, s(b, "t.to('cpu', dtype=torch.int32, non_blocking=non_blocking)").dtype) 7519 self.assertEqual(a.device, s(b, "t.to('cpu', dtype=torch.int32, non_blocking=non_blocking)").device) 7520 self.assertIs(torch.int32, s(b, "t.to(dtype=torch.int32)").dtype) 7521 self.assertEqual(b.device, s(b, "t.to(dtype=torch.int32)").device) 7522 7523 # Test AD: aten::to(Tensor self, int dtype, bool non_blocking, bool copy) -> Tensor 7524 t = torch.tensor(5).float().requires_grad_() 7525 out_ref = t.to(torch.float32) 7526 out = s(t, "t.to(torch.float32)") 7527 self.assertEqual(out_ref, out) 7528 7529 grad_ref = torch.autograd.grad(out_ref.sum(), t) 7530 grad = torch.autograd.grad(out.sum(), t) 7531 self.assertEqual(grad_ref, grad) 7532 7533 # Test AD: aten::to(Tensor self, Device? device, int? 
dtype, bool non_blocking, bool copy) -> Tensor 7534 out_ref = t.to('cpu') 7535 out = s(t, "t.to('cpu')") 7536 self.assertEqual(out_ref, out) 7537 7538 grad_ref = torch.autograd.grad(out_ref.sum(), t) 7539 grad = torch.autograd.grad(out.sum(), t) 7540 self.assertEqual(grad_ref, grad) 7541 7542 # Test AD: aten::to(Tensor self, Tensor other, bool non_blocking, bool copy) -> Tensor 7543 @torch.jit.script 7544 def func2(t, t_ref): 7545 return t.to(t_ref) 7546 7547 with disable_autodiff_subgraph_inlining(): 7548 t_ref = torch.tensor(4).double() 7549 out_ref = t.to(t_ref) 7550 out = func2(t, t_ref) 7551 grad_ref = torch.autograd.grad(out_ref.sum(), t) 7552 grad = torch.autograd.grad(out.sum(), t) 7553 self.assertEqual(grad_ref, grad) 7554 7555 @unittest.skipIf(not RUN_CUDA, "No CUDA") 7556 def test_tensor_number_math_cuda(self): 7557 self._test_tensor_number_math(device='cuda') 7558 7559 def test_not(self): 7560 # test not operator in python 7561 # TODO: add more tests when bool conversions ready 7562 def test_not_op(a): 7563 return not bool(a > 1) 7564 7565 self.checkScript(test_not_op, (torch.tensor(2), ), optimize=True) 7566 7567 def test_is_isnot(self): 7568 # test is and is not operator in python 7569 template = dedent(''' 7570 def func(): 7571 # type: () -> bool 7572 return {lhs} {op} {rhs} 7573 ''') 7574 7575 def test(op, args): 7576 code = template.format(lhs=args[0], rhs=args[1], op=op) 7577 scope = {} 7578 execWrapper(code, globals(), scope) 7579 cu = torch.jit.CompilationUnit(code) 7580 self.assertEqual( 7581 cu.func(), 7582 scope['func'](), 7583 msg=f"Failed with op: {op}, lhs: {args[0]}, rhs: {args[1]}" 7584 ) 7585 7586 ops = ['is', 'is not'] 7587 type_literals = [True, False, None, [1, 1], 1, 2, .5, 1.5] 7588 7589 # do literals product to try any types combinations 7590 for op, lhs, rhs in product(ops, type_literals, type_literals): 7591 test(op, [lhs, rhs]) 7592 7593 def test_isinstance_refinement(self): 7594 @torch.jit.script 7595 def foo(a): 7596 # type: 
(Optional[int]) -> int 7597 if isinstance(a, int): 7598 return a + 3 7599 else: 7600 return 4 7601 self.assertEqual(foo(4), 7) 7602 self.assertEqual(foo(None), 4) 7603 7604 @torch.jit.script 7605 def foo2(a, b): 7606 # type: (Optional[int], Optional[int]) -> int 7607 if not isinstance(a, int) or not isinstance(b, int): 7608 return 0 7609 else: 7610 return a + b 7611 self.assertEqual(foo2(3, 4), 7) 7612 self.assertEqual(foo2(None, 4), 0) 7613 self.assertEqual(foo2(4, None), 0) 7614 7615 @torch.jit.script 7616 def any_refinement(a, b): 7617 # type: (Any, Any) -> int 7618 if isinstance(a, int) and isinstance(b, int): 7619 return a + b 7620 return 0 7621 7622 self.assertEqual(any_refinement(3, 4), 7) 7623 self.assertEqual(any_refinement(3, "hi"), 0) 7624 7625 @torch.jit.script 7626 def any_refinement2(a): 7627 # type: (Any) -> Tensor 7628 if isinstance(a, Tensor): 7629 return a 7630 return torch.tensor(3) 7631 7632 self.assertEqual(any_refinement2(3), torch.tensor(3)) 7633 self.assertEqual(any_refinement2(torch.tensor(5)), torch.tensor(5)) 7634 7635 @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "bug persists in deprecated executor") 7636 def test_unspecialized_any_binding(self): 7637 # any binding will infer the type, if it infers 7638 # a specialized tensor type `x` Dict type will fail isinstance check 7639 7640 @torch.jit.script 7641 def foo(x: Any): 7642 assert isinstance(x, Dict[str, torch.Tensor]) 7643 7644 foo({"1": torch.tensor(3)}) 7645 with self.assertRaises(Exception): 7646 foo(2) 7647 7648 @skipIfTorchDynamo("Not a TorchDynamo suitable test") 7649 def test_isinstance(self): 7650 # test isinstance operator for static type checking 7651 template = dedent(''' 7652 def func(x): 7653 # type: ({type_hint}) -> bool 7654 return isinstance(x, {typ}) 7655 ''') 7656 7657 def test(inp, typ, type_hint): 7658 code = template.format(typ=typ, type_hint=type_hint) 7659 scope = {} 7660 execWrapper(code, globals(), scope) 7661 cu = torch.jit.CompilationUnit(code) 
7662 self.assertEqual( 7663 cu.func(inp), 7664 scope['func'](inp), 7665 msg=f"Failed with typ: {typ}" 7666 ) 7667 7668 inputs = [True, 1, 1.0, torch.tensor(1), [1, 2], (1.0,), [1, 2], 1] 7669 type_literals = ['bool', 'int', 'float', 'torch.Tensor', 'list', 'tuple', 7670 '(list, tuple)', '(int, float, bool)'] 7671 type_annotations = ['bool', 'int', 'float', 'Tensor', 'List[int]', 'Tuple[float]', 7672 'List[int]', 'int'] 7673 7674 # do zipping to try different types 7675 for inp, typ, type_hint in zip(inputs, type_literals, type_annotations): 7676 test(inp, typ, type_hint) 7677 7678 # test optional isinstance check 7679 @torch.jit.script 7680 def opt_func(x): 7681 # type: (Optional[int]) -> bool 7682 return isinstance(x, int) 7683 self.assertTrue(opt_func(3)) 7684 self.assertFalse(opt_func(None)) 7685 7686 def test_dropout_eval(self): 7687 class ScriptedConv2d(torch.jit.ScriptModule): 7688 def __init__(self, in_channels, out_channels, **kwargs): 7689 super().__init__() 7690 self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 7691 self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 7692 7693 @torch.jit.script_method 7694 def forward(self, x): 7695 x = self.conv(x) 7696 x = self.bn(x) 7697 return F.relu(x, inplace=True) 7698 7699 class ScriptMod(torch.jit.ScriptModule): 7700 def __init__(self) -> None: 7701 super().__init__() 7702 self.Conv2d_1a_3x3 = ScriptedConv2d(3, 32, kernel_size=3, stride=2) 7703 7704 @torch.jit.script_method 7705 def forward(self, x): 7706 x = self.Conv2d_1a_3x3(x) 7707 return F.dropout(x, training=self.training) 7708 7709 class EagerConv2d(torch.nn.Module): 7710 def __init__(self, in_channels, out_channels, **kwargs): 7711 super().__init__() 7712 self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 7713 self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 7714 7715 def forward(self, x): 7716 x = self.conv(x) 7717 x = self.bn(x) 7718 return F.relu(x, inplace=True) 7719 7720 class EagerMod(torch.nn.Module): 7721 
def __init__(self) -> None: 7722 super().__init__() 7723 self.Conv2d_1a_3x3 = EagerConv2d(3, 32, kernel_size=3, stride=2) 7724 7725 def forward(self, x): 7726 x = self.Conv2d_1a_3x3(x) 7727 return F.dropout(x, training=self.training) 7728 7729 script_input = torch.rand(4, 3, 299, 299) 7730 eager_input = script_input.clone() 7731 7732 with freeze_rng_state(): 7733 script_mod = ScriptMod() 7734 script_mod.eval() 7735 script_output = script_mod(script_input) 7736 7737 with freeze_rng_state(): 7738 eager_mod = EagerMod() 7739 eager_mod.eval() 7740 eager_output = eager_mod(eager_input) 7741 7742 self.assertEqual(script_output, eager_output) 7743 7744 with freeze_rng_state(): 7745 script_mod = ScriptMod() 7746 script_mod.train() 7747 script_output = script_mod(script_input) 7748 7749 with freeze_rng_state(): 7750 eager_mod = EagerMod() 7751 eager_mod.train() 7752 eager_output = eager_mod(eager_input) 7753 7754 self.assertEqual(script_output, eager_output) 7755 7756 def test_nested_breaks(self): 7757 def no_bool_loop_outputs(g): 7758 # testing that the "did exit" transform values are not loop block 7759 # outputs (and thus not affecting one loop from another) 7760 loops = g.findAllNodes("prim::Loop") 7761 for loop in loops: 7762 for out in loop.outputs(): 7763 self.assertTrue(out.type() != BoolType.get()) 7764 7765 def test(y): 7766 # type: (int) 7767 ret = 0 7768 tensor = torch.tensor(0) 7769 while int(tensor.add_(1)) < 4: 7770 if y == 1: 7771 continue 7772 for i in range(y): 7773 continue 7774 ret += 1 7775 ret += 1 7776 return ret, int(tensor) 7777 7778 self.assertEqual(torch.jit.script(test)(1), test(1)) 7779 self.assertEqual(torch.jit.script(test)(2), test(2)) 7780 no_bool_loop_outputs(torch.jit.script(test).graph) 7781 7782 def foo(): 7783 y = torch.tensor(0) 7784 z = 0 7785 while int(y.add_(1)) < 20: 7786 if int(y) < 10: 7787 for i in range(6): 7788 if i == 3: 7789 continue 7790 else: 7791 if i > 3: 7792 break 7793 z += 2 7794 if int(y) == 18: 7795 break 7796 if 
int(y) == 15: 7797 continue 7798 z += 1 7799 return int(y), z 7800 7801 no_bool_loop_outputs(torch.jit.script(foo).graph) 7802 self.checkScript(foo, ()) 7803 7804 def test_nested_two(): 7805 i = 0 7806 k = 0 7807 while i < 5: 7808 for j in range(5): 7809 k += 1 7810 if j == 3: 7811 continue 7812 i += 1 7813 k += 1 7814 if i == 4: 7815 break 7816 return i, k 7817 7818 self.checkScript(test_nested_two, ()) 7819 no_bool_loop_outputs(torch.jit.script(test_nested_two).graph) 7820 7821 def test_breaks_continues(self): 7822 def foo_continue(cond): 7823 # type: (int) 7824 j = 1 7825 for i in range(5): 7826 if i == cond: 7827 continue 7828 j += 1 7829 return j 7830 7831 def foo_break(cond): 7832 # type: (int) 7833 j = 1 7834 for i in range(5): 7835 if i == cond: 7836 break 7837 j += 1 7838 return j 7839 7840 for i in range(1, 4): 7841 self.checkScript(foo_continue, (i,)) 7842 self.checkScript(foo_break, (i,)) 7843 7844 def test_refine_outside_loop(): 7845 if 1 == 1: 7846 x = None 7847 else: 7848 x = 1 7849 i = 0 7850 j = 0 7851 while (x is None or torch.jit._unwrap_optional(x) > 3): 7852 if i < 3: 7853 if i < 3: 7854 x = torch.jit.annotate(Optional[int], None) 7855 i += 1 7856 continue 7857 x = 1 7858 else: 7859 x = 1 if x is None else x 7860 x = x + 1 7861 j = x + x 7862 7863 return x, j 7864 7865 self.checkScript(test_refine_outside_loop, ()) 7866 7867 def assign_after_break(y): 7868 # type: (int) 7869 x = 0 7870 for i in range(y): 7871 x = y * 2 + i 7872 break 7873 x = 4 7874 return x 7875 7876 self.checkScript(assign_after_break, (1,)) 7877 self.checkScript(assign_after_break, (2,)) 7878 self.checkScript(assign_after_break, (3,)) 7879 7880 def assign_after_break_nested(y): 7881 # type: (int) 7882 x = 0 7883 for i in range(y): 7884 if y == 1: 7885 x = 5 7886 break 7887 assert 1 == 2 7888 else: 7889 x = x + 1 7890 break 7891 assert 1 == 2 7892 x = -30 7893 assert 1 == 2 7894 return x 7895 7896 self.checkScript(assign_after_break_nested, (1,)) 7897 
self.checkScript(assign_after_break_nested, (2,)) 7898 self.checkScript(assign_after_break_nested, (3,)) 7899 7900 def may_break(y): 7901 # type: (int) 7902 x = 0 7903 for i in range(y): 7904 if y == 1: 7905 x = 5 7906 else: 7907 x = x + 1 7908 break 7909 x = -30 7910 return x 7911 7912 self.checkScript(may_break, (1,)) 7913 self.checkScript(may_break, (2,)) 7914 self.checkScript(may_break, (3,)) 7915 7916 def test(x, y): 7917 # type: (int, int) 7918 a = 1 7919 while (x > 0): 7920 if y == 3: 7921 for i in range(y): 7922 a += (1 % (i + 1)) 7923 x -= 1 7924 if x == 3: 7925 a = x * 3 7926 break 7927 if x < 3: 7928 if x == 1: 7929 a -= 2 7930 x -= 1 7931 break 7932 a -= 1 7933 x -= 3 7934 return a, x 7935 7936 self.checkScript(test, (10, 3)) 7937 self.checkScript(test, (10, 2)) 7938 self.checkScript(test, (3, 2)) 7939 self.checkScript(test, (5, 3)) 7940 self.checkScript(test, (2, 3)) 7941 7942 def test_delete_after_break(x): 7943 # type: (int) 7944 a = 1 7945 b = 1 7946 for i in range(x): 7947 a = i * 3 7948 break 7949 b = i * 5 7950 return a, b 7951 7952 self.checkScript(test_delete_after_break, (0,)) 7953 self.checkScript(test_delete_after_break, (1,)) 7954 7955 def test_will_break_after_guard(x): 7956 # type: (int) 7957 a = 1 7958 for i in range(x): 7959 if i == 4: 7960 a = 3 7961 break 7962 a -= 1 7963 break 7964 assert 1 == 2 7965 a -= -100 7966 return a 7967 7968 self.checkScript(test_will_break_after_guard, (0,)) 7969 self.checkScript(test_will_break_after_guard, (2,)) 7970 self.checkScript(test_will_break_after_guard, (4,)) 7971 7972 def test_varexit(cond): 7973 # type: (int) 7974 m = 0 7975 for i in range(3): 7976 if cond == 2: 7977 if cond == 2: 7978 m = 2 7979 break 7980 k = 1 7981 else: 7982 k = 2 7983 m += k 7984 return m 7985 7986 # use of k tests the pathway where we have to insert unitialized 7987 self.checkScript(test_varexit, (3,)) 7988 self.checkScript(test_varexit, (2,)) 7989 7990 def test_break_true(): 7991 i = 0 7992 while True: 7993 i += 1 7994 
if i == 3: 7995 break 7996 while False: 7997 i += 1 7998 return i 7999 8000 self.checkScript(test_break_true, ()) 8001 8002 def test_break_continue_error(self): 8003 with self.assertRaisesRegex(RuntimeError, "Syntax"): 8004 cu = torch.jit.CompilationUnit(''' 8005 def other_func(a): 8006 break 8007 ''') 8008 8009 with self.assertRaisesRegex(RuntimeError, "Syntax"): 8010 cu = torch.jit.CompilationUnit(''' 8011 def other_func(a): 8012 for i in range(5): 8013 def foo(): 8014 break 8015 ''') 8016 8017 with self.assertRaisesRegex(RuntimeError, "do not support break or continue inside"): 8018 @torch.jit.script 8019 def foo(x): 8020 i = 0 8021 for a in (1, "2", 1.5): 8022 b = a 8023 if x: 8024 break 8025 return b 8026 8027 def test_python_call(self): 8028 def pyfunc(a): 8029 return a * 3.0 8030 8031 cu = torch.jit.CompilationUnit(''' 8032 def other_func(a): 8033 return a + a 8034 8035 def test_call_python(a): 8036 b = pyfunc(a) 8037 b = other_func(b) 8038 i = 0 8039 step = 1 8040 while i < 10: 8041 b = pyfunc(b) 8042 if bool(b > 3.0): 8043 b = pyfunc(b) 8044 i = 11 8045 return b 8046 ''') 8047 inputs = self._make_scalar_vars([1], torch.float) 8048 outputs = self._make_scalar_vars([54], torch.float) 8049 8050 self.assertEqual(cu.test_call_python(*inputs), outputs[0]) 8051 8052 def test_python_call_failure(self): 8053 with self.assertRaisesRegex(RuntimeError, "undefined value pyfunc2"): 8054 def pyfunc(a): 8055 return a * 3.0 8056 8057 cu = torch.jit.CompilationUnit(''' 8058 def other_func(a): 8059 return a + a 8060 8061 def test_call_python(a): 8062 b = pyfunc(a) 8063 b = other_func(b) 8064 i = 0 8065 step = 1 8066 while i < 10: 8067 b = pyfunc2(b) 8068 if b > 3.0: 8069 b = pyfunc(b) 8070 i = 11 8071 return b 8072 ''') 8073 inputs = self._make_scalar_vars([1], torch.float) 8074 outputs = self._make_scalar_vars([54], torch.float) 8075 8076 self.assertEqual(cu.test_call_python(*inputs), outputs) 8077 8078 def test_type_call_in_script(self): 8079 @torch.jit.script 8080 def 
fn(x): 8081 return type(x) 8082 8083 with self.assertRaisesRegex(RuntimeError, "value of type _TensorMeta"): 8084 fn(torch.tensor(.5)) 8085 8086 def test_python_call_annotation(self): 8087 def pyfunc(a): 8088 return a * 3.0 8089 8090 @torch.jit.script 8091 def foo(a): 8092 return pyfunc(a) + pyfunc(a) 8093 8094 inputs = self._make_scalar_vars([1], torch.float) 8095 outputs = self._make_scalar_vars([6], torch.float) 8096 self.assertEqual(foo(*inputs), outputs[0]) 8097 8098 def test_python_call_annoytation_failure(self): 8099 with self.assertRaisesRegex(RuntimeError, "undefined value pyfunc2"): 8100 def pyfunc(a): 8101 return a * 3.0 8102 8103 @torch.jit.script 8104 def foo(a): 8105 return pyfunc2(a) + pyfunc(a) # noqa: F821 8106 8107 inputs = self._make_scalar_vars([1], torch.float) 8108 outputs = self._make_scalar_vars([6], torch.float) 8109 8110 self.assertEqual(foo(*inputs), outputs[0]) 8111 8112 def test_desugar_module(self): 8113 import torch.nn.functional as F 8114 8115 def fn(x, slope): 8116 a = torch.abs(x) 8117 b = torch.nn.functional.prelu(x, slope) 8118 c = F.prelu(x, slope) 8119 return a, b, c 8120 8121 x = torch.arange(-3., 4) 8122 slope = torch.tensor([0.5]) 8123 self.checkScript(fn, [x, slope], optimize=True) 8124 8125 def test_script_docstring(self): 8126 @torch.jit.script 8127 def with_docstring(x): 8128 """test str""" 8129 y = x 8130 """y is the same as x""" 8131 return y 8132 self.assertEqual(with_docstring.__doc__, 'test str') 8133 8134 def test_script_method_docstring(self): 8135 class A(torch.jit.ScriptModule): 8136 @torch.jit.script_method 8137 def with_docstring(self, x): 8138 """test str""" 8139 y = x 8140 """y is the same as x""" 8141 return y 8142 a = A() 8143 self.assertEqual(a.with_docstring.__doc__, 'test str') 8144 8145 def test_script_module(self): 8146 class M1(torch.jit.ScriptModule): 8147 def __init__(self) -> None: 8148 super().__init__() 8149 self.weight = nn.Parameter(torch.randn(2)) 8150 8151 @torch.jit.script_method 8152 def 
forward(self, thing): 8153 return self.weight + thing 8154 8155 class PModule(nn.Module): 8156 def __init__(self) -> None: 8157 super().__init__() 8158 self.a = nn.Parameter(torch.randn(2, 3)) 8159 8160 def forward(self, a): 8161 return self.a.mm(a) 8162 8163 class M2(torch.jit.ScriptModule): 8164 def __init__(self) -> None: 8165 super().__init__() 8166 # test submodule 8167 self.sub = M1() 8168 self.sub2 = PModule() 8169 # test parameters 8170 self.weight = nn.Parameter(torch.randn(2, 3)) 8171 self.bias = nn.Parameter(torch.randn(2)) 8172 # test defining a method from a string 8173 self.define(""" 8174 def hi(self, a): 8175 return self.weight.mm(a) 8176 """) 8177 # test script methods 8178 8179 @torch.jit.script_method 8180 def doit(self, input): 8181 # test use of parameter 8182 return self.weight.mm(input) 8183 8184 @torch.jit.script_method 8185 def doit2(self, input): 8186 return self.weight.mm(input) 8187 8188 @torch.jit.script_method 8189 def forward(self, input): 8190 a = self.doit(input) 8191 b = self.doit2(input) 8192 c = self.hi(input) 8193 d = self.sub2(input) 8194 return a + b + self.bias + self.sub(a) + c + d 8195 with torch.jit.optimized_execution(False): 8196 m2 = M2() 8197 input = torch.randn(3, 2) 8198 a = m2.weight.mm(input) 8199 b = m2.weight.mm(input) 8200 c = m2.weight.mm(input) 8201 d = m2.sub2.a.mm(input) 8202 ref = a + b + m2.bias + m2.sub.weight + a + c + d 8203 self.assertEqual(ref, m2.forward(input)) 8204 m2.weight = nn.Parameter(torch.zeros_like(m2.weight)) 8205 m2.bias = nn.Parameter(torch.zeros_like(m2.bias)) 8206 m2.sub.weight = nn.Parameter(torch.zeros_like(m2.sub.weight)) 8207 m2.sub2.a.data.zero_() 8208 self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2))) 8209 8210 def test_irparser(self): 8211 graph_str = """graph(%0 : Double(5, 5)): 8212 # CHECK: aten::relu 8213 %1 : Double(5, 5) = aten::relu(%0) 8214 return (%1) 8215 """ 8216 FileCheck().run(graph_str, parse_ir(graph_str)) 8217 8218 def 
test_parse_tensor_constants(self): 8219 def foo(): 8220 return torch.zeros([4, 4]) 8221 8222 foo_s = torch.jit.script(foo) 8223 torch._C._jit_pass_constant_propagation(foo_s.graph) 8224 8225 g = str(foo_s.graph) 8226 g_parsed = parse_ir(g, parse_tensor_constants=True) 8227 self.assertEqual(str(canonical(g_parsed)), str(canonical(foo_s.graph))) 8228 func = torch._C._create_function_from_graph("forward", g_parsed) 8229 8230 out_parsed = func() 8231 out_func = foo() 8232 # not checking data, just dtype, size etc 8233 out_parsed[:] = 0 8234 out_func[:] = 0 8235 self.assertEqual(out_func, out_parsed) 8236 8237 with self.assertRaises(RuntimeError): 8238 parse_ir(g, parse_tensor_constants=False) 8239 8240 def test_parse_nested_names(self): 8241 g_str = """ 8242 graph(%x.1 : Tensor): 8243 %3 : int = prim::Constant[value=1]() 8244 %2 : int = prim::Constant[value=2]() 8245 %hi.submod.value.5 : Tensor = aten::add(%x.1, %2, %3) 8246 return (%hi.submod.value.5) 8247 """ 8248 g = parse_ir(g_str) 8249 round_trip_g = parse_ir(str(g)) 8250 self.assertEqual(canonical(g), canonical(round_trip_g)) 8251 8252 func1 = torch._C._create_function_from_graph("forward", g) 8253 func2 = torch._C._create_function_from_graph("forward", round_trip_g) 8254 self.assertEqual(func1(torch.ones([2])), func2(torch.ones([2]))) 8255 8256 def test_is_after_use(self): 8257 def sorted_input_use(g): 8258 uses = list(next(g.inputs()).uses()) 8259 return sorted(uses, key=functools.cmp_to_key(type(uses[0]).isAfter)) 8260 8261 @torch.jit.script 8262 def foo(x): 8263 a = x + 1 8264 return (x, x, a) 8265 8266 uses_sorted = sorted_input_use(foo.graph) 8267 # sorts last use to the end 8268 self.assertFalse(uses_sorted[0].isAfter(uses_sorted[1])) 8269 self.assertTrue(uses_sorted[0].user.kind() == "aten::add") 8270 self.assertEqual(uses_sorted[1].offset, 0) 8271 8272 @torch.jit.script 8273 def foo(x, cond: bool): 8274 if cond: 8275 return x + 3 8276 else: 8277 return x - 3 8278 8279 uses_sorted = 
sorted_input_use(foo.graph) 8280 self.assertTrue(uses_sorted[0].user.kind() == "aten::add") 8281 self.assertTrue(uses_sorted[1].user.kind() == "aten::sub") 8282 8283 @torch.jit.script 8284 def foo(x, cond: bool, cond2: bool): 8285 if cond: 8286 return x + 3 8287 elif cond2 : 8288 return x - 3 8289 8290 return x / 3 8291 8292 graph1 = foo.graph 8293 8294 @torch.jit.script 8295 def foo(x, cond: bool, cond2: bool): 8296 if cond: 8297 return x + 3 8298 else: 8299 if cond2 : 8300 return x - 3 8301 return x / 3 8302 8303 graph2 = foo.graph 8304 8305 for graph in [graph1, graph2]: 8306 uses_sorted = sorted_input_use(graph) 8307 self.assertTrue(uses_sorted[0].user.kind() == "aten::add") 8308 self.assertTrue(uses_sorted[1].user.kind() == "aten::sub") 8309 self.assertTrue(uses_sorted[2].user.kind() == "aten::div") 8310 8311 def test_canonicalize_control_outputs(self): 8312 def test_all_outputs(g): 8313 ifs = g.findAllNodes("prim::If") 8314 loops = g.findAllNodes("prim::Loop") 8315 8316 def contained_blocks(node): 8317 return len(node.findAllNodes("prim::If")) * 2 + len(node.findAllNodes("prim::Loop")) 8318 for node in ifs + loops: 8319 outs = list(node.outputs()) 8320 out_name = [x.debugName() for x in outs] 8321 if len(out_name) == 0: 8322 continue 8323 fc = FileCheck() 8324 # find the last output, then all subsequent uses 8325 fc.check(out_name[-1] + " : ") 8326 # skip past node body 8327 for i in range(contained_blocks(node)): 8328 fc.check("->") 8329 if (node.kind() == "prim::If"): 8330 fc.check("->").check("->").check("\n") 8331 else: 8332 fc.check("->").check("\n") 8333 # the canonical order is the same order as the first use 8334 # appears in text 8335 for name in out_name: 8336 fc.check(name) 8337 fc.run(g) 8338 8339 @torch.jit.script 8340 def test(x): 8341 # type: (bool) -> Tuple[int, int] 8342 b = 2 8343 a = 1 8344 if x: 8345 a = 1 8346 b = 2 8347 x = False 8348 if x: 8349 b = a 8350 else: 8351 a = b 8352 8353 return a, b 8354 test_all_outputs(test.graph) 8355 8356 
@torch.jit.script 8357 def test2(x): 8358 # type: (bool) -> Tuple[int, int] 8359 b = 2 8360 a = 1 8361 if x: 8362 a = 1 8363 b = 2 8364 x = False 8365 if x: 8366 print(a) 8367 else: 8368 if x: 8369 print(b) 8370 8371 return a, b 8372 test_all_outputs(test2.graph) 8373 8374 @torch.jit.script 8375 def test_loop(x, iter): 8376 # type: (bool, int) -> (None) 8377 a = 1 8378 b = 2 8379 c = 3 8380 for i in range(iter): 8381 a = 4 8382 b = 5 8383 c = 6 8384 x = True 8385 print(c) 8386 if x: 8387 print(a, b) 8388 test_all_outputs(test_loop.graph) 8389 8390 @torch.jit.script 8391 def loop_unused(iter): 8392 # type: (int) -> (None) 8393 a = 1 8394 b = 2 8395 c = 3 8396 for i in range(iter): 8397 c = c + 1 8398 b = b + 1 8399 a = a + 1 8400 print(a, b) 8401 print(c) 8402 8403 # c is used, then unused should be ordered by alphabetical 8404 FileCheck().check(r"%c : int, %a : int, %b : int").run(loop_unused.graph) 8405 8406 def test_filecheck(self): 8407 def test_check(): 8408 file = "232" 8409 FileCheck().check("2").check("3").check("2").run(file) 8410 FileCheck().check("232").run(file) 8411 8412 with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'): 8413 FileCheck().check("22").run(file) 8414 with self.assertRaisesRegex(RuntimeError, "CHECK: 3"): 8415 FileCheck().check("3").check("3").run(file) 8416 8417 test_check() 8418 8419 def test_check_count(): 8420 file = "22222" 8421 FileCheck().check_count("2", 5).run(file) 8422 FileCheck().check_count("22", 2).run(file) 8423 FileCheck().check_count("222", 1).run(file) 8424 8425 with self.assertRaisesRegex(RuntimeError, 'Expected to not find'): 8426 FileCheck().check_count("2", 4, exactly=True).run(file) 8427 8428 with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'): 8429 FileCheck().check_count("22", 3).run(file) 8430 8431 with self.assertRaisesRegex(RuntimeError, "CHECK-COUNT-6: 2"): 8432 FileCheck().check_count("2", 6).run(file) 8433 8434 test_check_count() 8435 8436 def test_check_same(): 8437 file = 
"22\n33" 8438 FileCheck().check_same("22").run(file) 8439 8440 with self.assertRaisesRegex(RuntimeError, "Expected to not find"): 8441 FileCheck().check_same("33").run(file) 8442 8443 file = "22 1 3" 8444 8445 FileCheck().check("2").check_same("3").run(file) 8446 FileCheck().check_count("2", 2).check_same("3").run(file) 8447 8448 test_check_same() 8449 8450 def test_check_next(): 8451 file = "\n1\n2\n3" 8452 FileCheck().check("1").check_next("2").check_next("3").run(file) 8453 FileCheck().check_next("1").check_next("2").check_next("3").run(file) 8454 8455 with self.assertRaisesRegex(RuntimeError, "Expected to find"): 8456 FileCheck().check("1").check_next("2").run("12") 8457 8458 with self.assertRaisesRegex(RuntimeError, "Expected to not find"): 8459 FileCheck().check("1").check_next("2").run("1\n\n2") 8460 8461 test_check_next() 8462 8463 def test_check_dag(): 8464 fc = FileCheck().check_dag("1").check_dag("2").check_not("2") 8465 fc.run("12") 8466 fc.run("21") 8467 8468 fc = FileCheck() 8469 fc.check_not("3").check_dag("1").check_dag("2").check_not("3") 8470 fc.run("1 3 2") 8471 fc.run("2 3 1") 8472 8473 fc = FileCheck().check_dag("1").check_dag("2").check("3") 8474 with self.assertRaisesRegex(RuntimeError, 'Expected to find "3" but did not find it'): 8475 fc.run("1 3 2") 8476 8477 test_check_dag() 8478 8479 def test_check_not(): 8480 FileCheck().check_not("2").check("1").run("12") 8481 FileCheck().check("2").check_not("2").run("12") 8482 8483 with self.assertRaisesRegex(RuntimeError, 'Expected to not find "2"'): 8484 FileCheck().check_not("2").check("1").run("21") 8485 8486 with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'): 8487 FileCheck().check("2").check_not("1").run("21") 8488 8489 # checks with distinct range matchings 8490 fb = FileCheck().check_count("2", 2).check_count("2", 2).check_not("2") 8491 with self.assertRaisesRegex(RuntimeError, 'Expected to not find "2"'): 8492 fb.run("22 2 22") 8493 8494 fb = FileCheck().check_count("2", 
2).check_not("1").check_count("2", 2) 8495 with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'): 8496 fb.run("22 1 22") 8497 8498 def _dtype_to_jit_name(self, dtype): 8499 if dtype == torch.float32: 8500 return "Float" 8501 if dtype == torch.float64: 8502 return "Double" 8503 if dtype == torch.int64: 8504 return "Long" 8505 if dtype == torch.int32: 8506 return "Int" 8507 if dtype == torch.bool: 8508 return "Bool" 8509 raise RuntimeError('dtype not handled') 8510 8511 def _dtype_to_expect(self, dtype, dim=0): 8512 param = ', '.join(['*'] * dim + ['device=cpu']) 8513 param = '(' + param + ')' 8514 jit_type = self._dtype_to_jit_name(dtype) 8515 if dim >= 0: 8516 return jit_type + param 8517 # special case representing wrapped number 8518 else: 8519 return jit_type.lower() 8520 8521 8522 def _test_dtype_op_shape(self, ops, args, input_dims=1): 8523 if input_dims < 1: 8524 raise RuntimeError("input dims must be at least 1") 8525 dtypes = [torch.float32, torch.float64, torch.int64, torch.int32] 8526 str_args = ', '.join([str(arg) for arg in args]) + (', ' if len(args) else '') 8527 tensor_data = ('[' * input_dims) + '1, 2, 3' + (input_dims * ']') 8528 template = dedent(''' 8529 def func(): 8530 return {return_line} 8531 ''') 8532 8533 for op in ops: 8534 for dtype in (dtypes + [None]): 8535 for tensor_type in dtypes: 8536 # a couple of ops aren't implemented for non-floating types 8537 if not tensor_type.is_floating_point or (dtype is not None and not dtype.is_floating_point): 8538 if op in ['mean', 'softmax', 'log_softmax']: 8539 continue 8540 return_line = f"torch.tensor({tensor_data}, dtype={tensor_type}).{op}({str_args}dtype={dtype})" 8541 # uncomment for debugging a failed test: 8542 # print("testing {}".format(return_line)) 8543 code = template.format(return_line=return_line) 8544 scope = {} 8545 exec(code, globals(), scope) 8546 cu = torch.jit.CompilationUnit(code) 8547 graph = cu.func.graph 8548 torch._C._jit_pass_complete_shape_analysis(graph, 
(), False) 8549 input_array = [1, 2, 3] 8550 for _ in range(1, input_dims): 8551 input_array = [input_array] 8552 t = torch.tensor(input_array, dtype=tensor_type) 8553 attr = getattr(t, op) 8554 kwargs = {'dtype': dtype} 8555 result = attr(*args, **kwargs) 8556 expect = self._dtype_to_expect(result.dtype, result.dim()) 8557 FileCheck().check("aten::tensor").check(expect).run(graph) 8558 8559 def test_dtype_op_shape(self): 8560 ops = ['prod'] 8561 self._test_dtype_op_shape(ops, args=[]) 8562 self._test_dtype_op_shape(ops, args=[0, False]) 8563 self._test_dtype_op_shape(ops, args=[0, False]) 8564 self._test_dtype_op_shape(ops, args=[0, True]) 8565 8566 def test_dtype_op_shape2(self): 8567 ops = ['cumprod', 'cumsum', 'softmax', 'log_softmax'] 8568 self._test_dtype_op_shape(ops, args=[0]) 8569 8570 self._test_dtype_op_shape(ops, args=[1], input_dims=4) 8571 8572 8573 def _test_binary_op_shape(self, ops, input_dims=1): 8574 8575 dtypes = [torch.float32, torch.float64, torch.int64, torch.int32, torch.bool] 8576 8577 if input_dims == 0: 8578 shape = '1' 8579 else: 8580 shape = '[' + ('1,' * 4) + ']' 8581 for _ in range(1, input_dims): 8582 shape = '[' + ",".join([shape] * 4) + ']' 8583 8584 template = dedent(''' 8585 def func(): 8586 arg1 = {} 8587 arg2 = {} 8588 return torch.{}(arg1, arg2) 8589 ''') 8590 8591 args = [] 8592 for dtype in dtypes: 8593 args = args + [f"torch.tensor({shape}, dtype={dtype})"] 8594 args = args + [1, 1.5] 8595 8596 def isBool(arg): 8597 return type(arg) == bool or (type(arg) == str and "torch.bool" in arg) 8598 8599 for op in ops: 8600 for first_arg in args: 8601 for second_arg in args: 8602 # subtract not supported for bool 8603 if (op == 'sub' or op == 'div') and (isBool(first_arg) or isBool(second_arg)): 8604 continue 8605 # div is not implemented correctly for mixed-type or int params 8606 if (op == 'div' and (type(first_arg) != type(second_arg) or 8607 isinstance(first_arg, int) or 8608 (isinstance(first_arg, str) and 'int' in 
first_arg))): 8609 continue 8610 return_line = f"torch.{op}({first_arg}, {second_arg})" 8611 # uncomment for debugging a failed test: 8612 # print("testing {}".format(return_line)) 8613 code = template.format(first_arg, second_arg, op) 8614 scope = {} 8615 exec(code, globals(), scope) 8616 non_jit_result = scope['func']() 8617 8618 cu = torch.jit.CompilationUnit(code) 8619 graph = cu.func.graph 8620 torch._C._jit_pass_complete_shape_analysis(graph, (), False) 8621 # use dim=-1 to represent a python/jit scalar. 8622 dim = -1 if type(first_arg) != str and type(second_arg) != str else non_jit_result.dim() 8623 dtype = non_jit_result.dtype 8624 # jit only supports int/float scalars. 8625 if dim < 0: 8626 if dtype == torch.int64: 8627 dtype = torch.int32 8628 if dtype == torch.float64: 8629 dtype = torch.float32 8630 expect = self._dtype_to_expect(dtype, dim) 8631 jit_output = next(graph.outputs()) 8632 8633 check = FileCheck() 8634 check.check(expect).run(str(jit_output)) 8635 8636 def test_binary_op_shape(self): 8637 self._test_binary_op_shape(['mul', 'div', 'add', 'sub'], 0) 8638 self._test_binary_op_shape(['mul', 'div', 'add', 'sub'], 3) 8639 8640 def test_no_dtype_shape(self): 8641 8642 @torch.jit.script 8643 def foo(x): 8644 scalar_number = x.item() 8645 return x.add(scalar_number) 8646 8647 @torch.jit.script 8648 def foo2(x): 8649 scalar_number = x.item() 8650 return torch.tensor(1).add(scalar_number) 8651 8652 t = torch.tensor(5) 8653 g = foo.graph_for(t) 8654 type = next(g.outputs()) 8655 self.assertTrue(type.type() == torch._C.TensorType.get()) 8656 g2 = foo2.graph_for(t) 8657 type = next(g.outputs()) 8658 self.assertTrue(type.type() == torch._C.TensorType.get()) 8659 8660 8661 def test_filecheck_parse(self): 8662 def test_check(): 8663 file = """ 8664 # CHECK: 2 8665 # CHECK: 3 8666 # CHECK: 2 8667 232 8668 """ 8669 FileCheck().run(checks_file=file, test_file=file) 8670 file = """ 8671 # CHECK: 232 8672 232 8673 """ 8674 FileCheck().run(file, "232") 8675 with 
self.assertRaisesRegex(RuntimeError, 'Expected to find "232"'): 8676 FileCheck().run(file, "22") 8677 with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'): 8678 FileCheck().run("# CHECK: 22", "23") 8679 test_check() 8680 8681 def test_check_count(): 8682 file = "22222" 8683 FileCheck().run("# CHECK-COUNT-5: 2", file) 8684 FileCheck().run("# CHECK-COUNT-EXACTLY-5: 2", file) 8685 FileCheck().run("# CHECK-COUNT-2: 22", file) 8686 FileCheck().run("# CHECK-COUNT-1: 222", file) 8687 8688 with self.assertRaisesRegex(RuntimeError, 'Expected to not find'): 8689 FileCheck().run("# CHECK-COUNT-EXACTLY-2: 2", file) 8690 test_check_count() 8691 8692 def test_check_same(): 8693 file = "22\n33" 8694 FileCheck().run("# CHECK-SAME: 22", file) 8695 8696 with self.assertRaisesRegex(RuntimeError, "Expected to not find"): 8697 FileCheck().run("# CHECK-SAME: 33", file) 8698 8699 file = "22 1 3" 8700 8701 FileCheck().run("# CHECK: 2\n # CHECK-SAME: 3", file) 8702 FileCheck().run("# CHECK-COUNT-2: 2\n # CHECK-SAME: 3", file) 8703 test_check_same() 8704 8705 def test_bad_input(): 8706 with self.assertRaisesRegex(RuntimeError, "Check for bad input"): 8707 FileCheck().run("", "1") 8708 8709 with self.assertRaisesRegex(RuntimeError, "Could not parse check"): 8710 FileCheck().run("# CHECK1", "") 8711 8712 test_bad_input() 8713 8714 def test_script_module_call_noscript(self): 8715 class M(torch.jit.ScriptModule): 8716 def __init__(self) -> None: 8717 super().__init__() 8718 self.value = 1 8719 8720 @torch.jit.ignore 8721 def foo(self): 8722 return torch.ones(2, 2) + self.value 8723 8724 @torch.jit.script_method 8725 def forward(self, input): 8726 return input + self.foo() 8727 8728 with torch.jit.optimized_execution(False): 8729 m = M() 8730 input = torch.randn(2, 2) 8731 o = m(input) 8732 self.assertEqual(o, input + torch.ones(2, 2) + 1) 8733 # check that we can change python attributes 8734 # and that those changes are picked up in script methods 8735 m.value = 2 8736 o = 
m(input) 8737 self.assertEqual(o, input + torch.ones(2, 2) + 2) 8738 8739 def test_script_module_nochange_submodule(self): 8740 class M(torch.jit.ScriptModule): 8741 def __init__(self) -> None: 8742 super().__init__() 8743 self.sub = nn.Linear(5, 5) 8744 8745 @torch.jit.script_method 8746 def forward(self, input): 8747 return self.sub(input) 8748 with torch.jit.optimized_execution(False): 8749 m = M() 8750 input = torch.randn(1, 5, 5) 8751 o = m(input) 8752 self.assertEqual(o, m.sub(input)) 8753 with self.assertRaisesRegex(RuntimeError, "Cannot re-assign"): 8754 m.sub = nn.Linear(5, 5) 8755 8756 def test_module_apis(self): 8757 class Sub(torch.nn.Module): 8758 def forward(self, thing): 8759 return thing - 2 8760 8761 class Double(torch.nn.Module): 8762 def forward(self, thing): 8763 return thing * 2 8764 8765 class MyMod(torch.nn.Module): 8766 def __init__(self) -> None: 8767 super().__init__() 8768 self.mod = (Sub()) 8769 self.mod2 = (Sub()) 8770 self.mod3 = nn.Sequential(nn.Sequential(Sub())) 8771 self.mod4 = nn.Sequential(Sub(), Double()) 8772 8773 @torch.jit.export 8774 def method(self, x, x1, y, y1): 8775 mod_names = "" 8776 for name, mod in self.named_modules(): 8777 mod_names = mod_names + " " + name 8778 x = mod(x) 8779 8780 children_names = "" 8781 for name, mod in self.named_children(): 8782 children_names = children_names + " " + name 8783 x1 = mod(x1) 8784 8785 for mod in self.modules(): 8786 y = mod(y) 8787 8788 for mod in self.children(): 8789 y1 = mod(y1) 8790 8791 return mod_names, children_names, x, x1, y, y1 8792 8793 def forward(self, x): 8794 return x + 2 8795 8796 mod = torch.jit.script(MyMod()) 8797 inps = tuple([torch.tensor(i) for i in range(1, 5)]) 8798 self.assertEqual(mod.method(*inps), MyMod().method(*inps)) 8799 8800 def test_script_module_const(self): 8801 class M(torch.jit.ScriptModule): 8802 8803 __constants__ = ['b', 'i', 'c', 's'] 8804 8805 def __init__(self) -> None: 8806 super().__init__() 8807 self.b = False 8808 self.i = 1 8809 
self.c = 3.5 8810 self.s = ["hello"] 8811 8812 @torch.jit.script_method 8813 def forward(self): 8814 return self.b, self.i, self.c 8815 8816 with torch.jit.optimized_execution(False): 8817 m = M() 8818 o0, o1, o2 = m() 8819 self.assertEqual(o0, 0) 8820 self.assertEqual(o1, 1) 8821 self.assertEqual(o2, 3.5) 8822 8823 def test_script_module_fail_exist(self): 8824 class M(torch.jit.ScriptModule): 8825 @torch.jit.script_method 8826 def forward(self, x): 8827 return x + self.whatisgoingon 8828 with self.assertRaisesRegex(RuntimeError, "Module 'M' has no attribute"): 8829 M() 8830 8831 @unittest.skip("[module dedupe] currently NoneType refinement on optional attributes doesn't work.") 8832 def test_script_module_none_exist_fail(self): 8833 class M(torch.jit.ScriptModule): 8834 def __init__(self, my_optional): 8835 super().__init__() 8836 self.my_optional = my_optional 8837 8838 @torch.jit.script_method 8839 def forward(self, x): 8840 if self.my_optional is not None: 8841 return torch.neg(x) + self.my_optional 8842 return torch.neg(x) 8843 with self.assertRaisesRegex(RuntimeError, "has no attribute 'my_optional'"): 8844 x = torch.rand(3, 4) 8845 fb = M(None) 8846 fb(x) 8847 8848 def test_script_module_invalid_consts(self): 8849 class Foo(torch.jit.ScriptModule): 8850 __constants__ = ['invalid'] 8851 8852 def __init__(self) -> None: 8853 super().__init__() 8854 self.invalid = [nn.Linear(3, 4)] 8855 8856 with self.assertRaisesRegex( 8857 TypeError, 8858 "Linear' object in attribute 'Foo.invalid' is not a valid constant"): 8859 Foo() 8860 8861 class Foo2(torch.jit.ScriptModule): 8862 __constants__ = ['invalid'] 8863 8864 def __init__(self) -> None: 8865 super().__init__() 8866 self.invalid = int 8867 8868 with self.assertRaisesRegex(TypeError, "not a valid constant"): 8869 Foo2() 8870 8871 class Foo3(torch.jit.ScriptModule): 8872 __constants__ = ['invalid'] 8873 8874 def __init__(self) -> None: 8875 super().__init__() 8876 self.invalid = (3, 4, {}) 8877 8878 with 
self.assertRaisesRegex(TypeError, "not a valid constant"): 8879 Foo3() 8880 8881 class Foo4(torch.jit.ScriptModule): 8882 __constants__ = ['invalid'] 8883 8884 def __init__(self) -> None: 8885 super().__init__() 8886 self.invalid = np.int64(5) 8887 8888 # verify that we capture human understandable class name 8889 with self.assertRaisesRegex(TypeError, "numpy.int64"): 8890 Foo4() 8891 8892 def test_script_module_param_buffer_mutation(self): 8893 # TODO: add param mutation test case after JIT support it 8894 class ModuleBufferMutate(torch.jit.ScriptModule): 8895 def __init__(self) -> None: 8896 super().__init__() 8897 self.running_var = nn.Buffer(torch.tensor(0, dtype=torch.long)) 8898 8899 @torch.jit.script_method 8900 def forward(self): 8901 if self.training: 8902 self.running_var += 1 8903 return self.running_var 8904 8905 with torch.jit.optimized_execution(False): 8906 m = ModuleBufferMutate() 8907 self.assertEqual(m(), 1) 8908 m.eval() 8909 self.assertEqual(m(), 1) 8910 8911 def test_script_module_for(self): 8912 class M(torch.jit.ScriptModule): 8913 __constants__ = ['b'] 8914 8915 def __init__(self) -> None: 8916 super().__init__() 8917 self.b = [1, 2, 3, 4] 8918 8919 @torch.jit.script_method 8920 def forward(self): 8921 sum = 0 8922 for i in self.b: 8923 sum += i 8924 return sum 8925 8926 with torch.jit.optimized_execution(False): 8927 m = M() 8928 self.assertEqual(m(), 10) 8929 8930 def test_override_magic(self): 8931 class OverrideMagic(nn.Module): 8932 @torch.jit.export 8933 def __len__(self): 8934 return 10 8935 8936 mod = OverrideMagic() 8937 self.assertEqual(len(mod), len(torch.jit.script(mod))) 8938 8939 class OverrideMagicSeq(nn.Sequential): 8940 @torch.jit.export 8941 def __len__(self): 8942 return 10 8943 8944 mod = OverrideMagicSeq() 8945 self.assertEqual(len(mod), len(torch.jit.script(mod))) 8946 self.assertTrue(torch.jit.script(mod)) 8947 8948 def test_script_module_for2(self): 8949 class Sub(torch.jit.ScriptModule): 8950 def __init__(self) -> 
None: 8951 super().__init__() 8952 self.weight = nn.Parameter(torch.randn(2)) 8953 8954 @torch.jit.script_method 8955 def forward(self, thing): 8956 return self.weight + thing 8957 8958 class M(torch.jit.ScriptModule): 8959 def __init__(self) -> None: 8960 super().__init__() 8961 self.mods = nn.ModuleList([Sub() for i in range(10)]) 8962 8963 @torch.jit.script_method 8964 def forward(self, v): 8965 for m in self.mods: 8966 v = m(v) 8967 return v 8968 8969 with torch.jit.optimized_execution(False): 8970 i = torch.empty(2) 8971 m = M() 8972 o = m(i) 8973 v = i 8974 for sub in m.mods: 8975 v = sub(v) 8976 self.assertEqual(o, v) 8977 with self.assertRaisesRegex(Exception, "object is not iterable"): 8978 print(list(m)) 8979 8980 def test_attr_qscheme_script(self): 8981 class Foo(torch.nn.Module): 8982 def __init__(self) -> None: 8983 super().__init__() 8984 self.qscheme = torch.per_tensor_affine 8985 8986 def forward(self): 8987 if self.qscheme == torch.per_tensor_symmetric: 8988 return 3 8989 else: 8990 return 4 8991 8992 f = Foo() 8993 scripted = torch.jit.script(f) 8994 self.assertEqual(f(), scripted()) 8995 8996 def test_script_module_const_submodule_fail(self): 8997 class Sub(torch.jit.ScriptModule): 8998 def __init__(self) -> None: 8999 super().__init__() 9000 self.weight = nn.Parameter(torch.randn(2)) 9001 9002 @torch.jit.script_method 9003 def forward(self, thing): 9004 return self.weight + thing 9005 9006 class M(torch.jit.ScriptModule): 9007 def __init__(self) -> None: 9008 super().__init__() 9009 self.mods = [Sub() for _ in range(10)] 9010 9011 @torch.jit.script_method 9012 def forward(self): 9013 for _ in self.mods: 9014 print(1) 9015 return 4 9016 9017 with self.assertRaisesRegex(RuntimeError, "has no attribute 'mods'"): 9018 M() 9019 9020 class DerivedStateModule(torch.jit.ScriptModule): 9021 def __init__(self) -> None: 9022 super(TestScript.DerivedStateModule, self).__init__() 9023 self.param = torch.nn.Parameter(torch.ones(3, 4, dtype=torch.float)) 9024 
self.derived = nn.Buffer(torch.neg(self.param).detach().clone()) 9025 9026 # This is a flag so we can test that the pack method was called 9027 self.pack_called = nn.Buffer(torch.zeros(1, dtype=torch.long)) 9028 # This is a flag so we can test that the unpack method was called 9029 self.unpack_called = nn.Buffer(torch.zeros(1, dtype=torch.long)) 9030 9031 @torch.jit.script_method 9032 def _pack(self): 9033 self.pack_called.set_(torch.ones(1, dtype=torch.long)) 9034 self.derived.set_(torch.rand(1).detach()) 9035 9036 @torch.jit.script_method 9037 def _unpack(self): 9038 self.unpack_called.set_(torch.ones(1, dtype=torch.long)) 9039 self.derived.set_(torch.neg(self.param).detach()) 9040 9041 @torch.jit.script_method 9042 def forward(self, x): 9043 return x + self.derived 9044 9045 def test_pack_unpack_state(self): 9046 sm = TestScript.DerivedStateModule() 9047 x = torch.rand(3, 4) 9048 torch.testing.assert_close(sm(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float))) 9049 9050 # Test save path 9051 self.assertFalse(sm.pack_called.item()) 9052 self.assertFalse(sm.unpack_called.item()) 9053 imported = self.getExportImportCopyWithPacking(sm) 9054 # ensure pack was called before serialization 9055 self.assertTrue(sm.pack_called.item()) 9056 # ensure unpack was called after serialization so as to leave the module in an initialized state 9057 self.assertTrue(sm.unpack_called.item()) 9058 9059 torch.testing.assert_close(sm.derived, torch.neg(sm.param)) 9060 9061 # Test load paths 9062 self.assertTrue(imported.unpack_called.item()) 9063 torch.testing.assert_close(imported(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float))) 9064 9065 @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support") 9066 @unittest.skipIf(True, "Skipping while landing PR stack") 9067 def test_torch_functional(self): 9068 def stft(input, n_fft): 9069 # type: (Tensor, int) -> Tensor 9070 return torch.stft(input, n_fft, return_complex=True) 9071 9072 inps = (torch.randn(10), 7) 9073 
self.assertEqual(stft(*inps), torch.jit.script(stft)(*inps)) 9074 9075 def istft(input, n_fft): 9076 # type: (Tensor, int) -> Tensor 9077 return torch.istft(input, n_fft) 9078 9079 inps2 = (stft(*inps), inps[1]) 9080 self.assertEqual(istft(*inps2), torch.jit.script(istft)(*inps2)) 9081 9082 def lu_unpack(x): 9083 A_LU, pivots = torch.linalg.lu_factor(x) 9084 return torch.lu_unpack(A_LU, pivots) 9085 9086 for shape in ((3, 3), (5, 3, 3), (7, 3, 5, 5), (7, 5, 3, 3, 3)): 9087 a = torch.randn(*shape) 9088 self.checkScript(lu_unpack, (a,)) 9089 9090 def cdist_fn(): 9091 a = torch.tensor([[0.9041, 0.0196], [-0.3108, -2.4423], [-0.4821, 1.059]]) 9092 b = torch.tensor([[-2.1763, -0.4713], [-0.6986, 1.3702]]) 9093 return torch.cdist(a, b, compute_mode="use_mm_for_euclid_dist") 9094 9095 self.checkScript(cdist_fn, ()) 9096 9097 def norm(): 9098 c = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float) 9099 return torch.norm(c, p="fro"), torch.norm(c, p="nuc"), torch.norm(c), torch.norm(c, p=.5) 9100 9101 self.checkScript(norm, ()) 9102 9103 def torch_unique(dim: Optional[int]): 9104 ten = torch.unique(torch.tensor([[1, 3], [2, 3]], dtype=torch.long)) 9105 a = torch.unique(ten, dim=dim) 9106 b = torch.unique(ten, return_counts=True, dim=dim) 9107 c = torch.unique(ten, return_inverse=True, dim=dim) 9108 d = torch.unique(ten, return_counts=True, return_inverse=True, dim=dim) 9109 return a, b, c, d 9110 9111 self.checkScript(torch_unique, (None,)) 9112 self.checkScript(torch_unique, (0,)) 9113 9114 def torch_unique_consecutive(dim: Optional[int]): 9115 ten = torch.unique(torch.tensor([[1, 3], [3, 2], [3, 2], [2, 3]], dtype=torch.long)) 9116 a = torch.unique_consecutive(ten, dim=dim) 9117 b = torch.unique_consecutive(ten, return_counts=True, dim=dim) 9118 c = torch.unique_consecutive(ten, return_inverse=True, dim=dim) 9119 d = torch.unique_consecutive(ten, return_counts=True, return_inverse=True, dim=dim) 9120 return a, b, c, d 9121 9122 
self.checkScript(torch_unique_consecutive, (None,)) 9123 self.checkScript(torch_unique_consecutive, (0,)) 9124 9125 def test_torch_functional_tensordot_int(self): 9126 def tensordot_dims_int(a: torch.Tensor, b: torch.Tensor, dims: int): 9127 return torch.tensordot(a, b, dims=dims) 9128 9129 a = torch.arange(120.).reshape(2, 3, 4, 5) 9130 b = torch.arange(840.).reshape(4, 5, 6, 7) 9131 dims = 2 9132 self.checkScript(tensordot_dims_int, (a, b, dims)) 9133 9134 for dims in [-1, 5]: 9135 try: 9136 tensordot_dims_int(a, b, dims) 9137 except RuntimeError as error: 9138 if dims < 0: 9139 self.assertEqual(str(error), "tensordot expects dims >= 0, but got dims=" + str(dims)) 9140 if dims > min(a.dim(), b.dim()): 9141 self.assertEqual(str(error), "tensordot expects dims < ndim_a or ndim_b, but got dims=" + str(dims)) 9142 9143 def test_torch_functional_tensordot_tensor(self): 9144 def tensordot_dims_tensor(a: torch.Tensor, b: torch.Tensor, dims: torch.Tensor): 9145 return torch.tensordot(a, b, dims=dims) 9146 9147 a = torch.arange(120.).reshape(2, 3, 4, 5) 9148 b = torch.arange(840.).reshape(4, 5, 6, 7) 9149 dims = torch.tensor([2]) 9150 self.checkScript(tensordot_dims_tensor, (a, b, dims)) 9151 9152 a = torch.arange(60.).reshape(3, 4, 5) 9153 b = torch.arange(24.).reshape(4, 3, 2) 9154 dims = torch.tensor([[1, 0], [0, 1]], dtype=torch.long) 9155 self.checkScript(tensordot_dims_tensor, (a, b, dims)) 9156 9157 def test_torch_functional_tensordot_list(self): 9158 def tensordot_dims_list(a: torch.Tensor, b: torch.Tensor, dims: List[List[int]]): 9159 return torch.tensordot(a, b, dims=dims) 9160 9161 a = torch.arange(60.).reshape(3, 4, 5) 9162 b = torch.arange(24.).reshape(4, 3, 2) 9163 dims = [[1, 0], [0, 1]] 9164 self.checkScript(tensordot_dims_list, (a, b, dims)) 9165 9166 def test_torch_functional_tensordot_tuple(self): 9167 def tensordot_dims_tuple(a: torch.Tensor, b: torch.Tensor, dims: Tuple[List[int], List[int]]): 9168 return torch.tensordot(a, b, dims=dims) 9169 9170 a = 
torch.arange(60.).reshape(3, 4, 5) 9171 b = torch.arange(24.).reshape(4, 3, 2) 9172 dims = ([1, 0], [0, 1]) 9173 self.checkScript(tensordot_dims_tuple, (a, b, dims)) 9174 9175 def test_missing_getstate(self): 9176 class Foo(torch.nn.Module): 9177 def __init__(self) -> None: 9178 super().__init__() 9179 self.x = 1 9180 9181 def forward(self, x): 9182 return x * self.x 9183 9184 @torch.jit.export 9185 def __setstate__(self, state): 9186 self.x = state[0] 9187 self.training = state[1] 9188 9189 with self.assertRaisesRegex(RuntimeError, "getstate"): 9190 scripted = torch.jit.script(Foo()) 9191 9192 def test_inlining_cleanup(self): 9193 def foo(x): 9194 return F.linear(x, x) 9195 9196 @torch.jit.script 9197 def fee(x): 9198 return foo(x) 9199 9200 # inlining optimizations should have cleaned up linear if statement 9201 self.run_pass("inline", fee.graph) 9202 FileCheck().check_not("prim::If").run(fee.graph) 9203 9204 @skipIfTorchDynamo("TorchDynamo fails with unknown reason") 9205 def test_pack_unpack_nested(self): 9206 class SubSubMod(torch.jit.ScriptModule): 9207 def __init__(self) -> None: 9208 super().__init__() 9209 self.buf = nn.Buffer(torch.ones(3, 4) * 3) 9210 9211 @torch.jit.script_method 9212 def _pack(self): 9213 self.buf.set_(torch.zeros(1)) 9214 9215 @torch.jit.script_method 9216 def _unpack(self): 9217 self.buf.set_(torch.ones(3, 4) * 3) 9218 9219 @torch.jit.script_method 9220 def forward(self, x): 9221 return x + self.buf 9222 9223 class SubMod(torch.jit.ScriptModule): 9224 def __init__(self) -> None: 9225 super().__init__() 9226 self.buf = nn.Buffer(torch.ones(3, 4) * 2) 9227 self.ssm = SubSubMod() 9228 9229 @torch.jit.script_method 9230 def _pack(self): 9231 self.buf.set_(torch.zeros(1)) 9232 9233 @torch.jit.script_method 9234 def _unpack(self): 9235 self.buf.set_(torch.ones(3, 4) * 2) 9236 9237 @torch.jit.script_method 9238 def forward(self, x): 9239 return self.ssm(x + self.buf) 9240 9241 class Mod(torch.jit.ScriptModule): 9242 def __init__(self) -> 
None: 9243 super().__init__() 9244 self.submod = SubMod() 9245 self.buf = nn.Buffer(torch.ones(3, 4) * 1) 9246 9247 @torch.jit.script_method 9248 def _pack(self): 9249 self.buf.set_(torch.zeros(1)) 9250 9251 @torch.jit.script_method 9252 def _unpack(self): 9253 self.buf.set_(torch.ones(3, 4)) 9254 9255 @torch.jit.script_method 9256 def forward(self, x): 9257 return self.submod(x + self.buf) 9258 9259 m = Mod() 9260 torch.testing.assert_close(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6) 9261 m.apply(lambda s: s._pack()) 9262 torch.testing.assert_close(m(torch.zeros(3, 4)), torch.zeros(3, 4)) 9263 m.apply(lambda s: s._unpack()) 9264 torch.testing.assert_close(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6) 9265 9266 def test_torch_any(self): 9267 def fn(x): 9268 return torch.any(x) 9269 9270 def fn1(x, dim: int): 9271 return torch.any(x, dim) 9272 9273 self.checkScript(fn, (torch.randn(3, 4), )) 9274 self.checkScript(fn, (torch.empty(3), )) 9275 self.checkScript(fn, (torch.empty(1), )) 9276 self.checkScript(fn, (torch.ones(3, 4),)) 9277 self.checkScript(fn, (torch.zeros(5, 7, 1),)) 9278 self.checkScript(fn1, (torch.empty(3, 4), -2)) 9279 self.checkScript(fn1, (torch.randn(3, 8), 1)) 9280 self.checkScript(fn1, (torch.zeros(3, 6, 9), -3)) 9281 self.checkScript(fn1, (torch.empty(5), 0)) 9282 9283 def test_any(self): 9284 def fn(x: List[int]): 9285 return any(x) 9286 9287 def fn1(x: List[float]): 9288 return any(x) 9289 9290 def fn2(x: List[bool]): 9291 return any(x) 9292 9293 def fn3(x: List[str]): 9294 return any(x) 9295 9296 self.checkScript(fn, ([0, 0, 0, 0], )) 9297 self.checkScript(fn, ([0, 3, 0], )) 9298 self.checkScript(fn, ([], )) 9299 self.checkScript(fn1, ([1.0, 2.0, 3.0], )) 9300 self.checkScript(fn1, ([0.0, 0.0, 0.0], )) 9301 self.checkScript(fn1, ([0, 0, 0], )) 9302 self.checkScript(fn1, ([], )) 9303 self.checkScript(fn2, ([True, False, False], )) 9304 self.checkScript(fn2, ([False, False, False], )) 9305 self.checkScript(fn2, ([True, True, True, True], )) 9306 
self.checkScript(fn2, ([], )) 9307 self.checkScript(fn3, (["", "", ""], )) 9308 self.checkScript(fn3, (["", "", "", "-1"], )) 9309 self.checkScript(fn3, ([], )) 9310 9311 def test_script_module_not_tuple(self): 9312 class M(torch.jit.ScriptModule): 9313 __constants__ = ['mods'] 9314 9315 def __init__(self) -> None: 9316 super().__init__() 9317 self.mods = 1 9318 9319 @torch.jit.script_method 9320 def forward(self, v): 9321 for m in self.mods: 9322 print(m) 9323 return v 9324 with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"): 9325 M() 9326 9327 def test_attr_module_constants(self): 9328 class M2(torch.jit.ScriptModule): 9329 def __init__(self, mod_list): 9330 super().__init__() 9331 self.mods = mod_list 9332 9333 @torch.jit.script_method 9334 def forward(self, x): 9335 return self.mods.forward(x) 9336 9337 with torch.jit.optimized_execution(False): 9338 m = M2(nn.Sequential(nn.ReLU())) 9339 self.assertExportImportModule(m, (torch.randn(2, 2),)) 9340 9341 def test_script_sequential_for(self): 9342 class Sub(torch.jit.ScriptModule): 9343 def __init__(self) -> None: 9344 super().__init__() 9345 self.weight = nn.Parameter(torch.randn(2)) 9346 9347 @torch.jit.script_method 9348 def forward(self, thing): 9349 return self.weight + thing 9350 9351 class M(torch.jit.ScriptModule): 9352 def __init__(self) -> None: 9353 super().__init__() 9354 self.mods = nn.Sequential(Sub(), Sub(), Sub()) 9355 9356 @torch.jit.script_method 9357 def forward(self, v): 9358 for m in self.mods: 9359 v = m(v) 9360 return v 9361 9362 @torch.jit.script_method 9363 def forward2(self, v): 9364 return self.mods(v) 9365 9366 with torch.jit.optimized_execution(False): 9367 i = torch.empty(2) 9368 m = M() 9369 o = m(i) 9370 v = i 9371 for sub in m.mods._modules.values(): 9372 v = sub(v) 9373 self.assertEqual(o, v) 9374 9375 o2 = m.forward2(i) 9376 self.assertEqual(o2, v) 9377 9378 def test_script_sequential_sliced_iteration(self): 9379 class seq_mod(nn.Module): 9380 def 
__init__(self) -> None: 9381 super().__init__() 9382 self.layers = [nn.ReLU(), nn.ReLU(), nn.ReLU()] 9383 self.layers = nn.Sequential(*self.layers) 9384 9385 def forward(self, input): 9386 x = self.layers[0].forward(input) 9387 for layer in self.layers[1:3]: 9388 x = layer.forward(x) 9389 for layer in self.layers[2:]: 9390 x = layer.forward(x) 9391 return x 9392 9393 seq = seq_mod() 9394 self.checkModule(seq, [torch.tensor([-2, 1, -1, 2])]) 9395 9396 def test_script_sequential_orderdict(self): 9397 class M(torch.jit.ScriptModule): 9398 def __init__(self) -> None: 9399 super().__init__() 9400 self.mods = nn.Sequential(OrderedDict([ 9401 ("conv", nn.Conv2d(1, 20, 5)), 9402 ("relu", nn.ReLU()) 9403 ])) 9404 9405 @torch.jit.script_method 9406 def forward(self, input): 9407 return self.mods(input) 9408 9409 m = M() 9410 self.assertTrue('mods.conv.weight' in m.state_dict().keys()) 9411 9412 def test_script_sequential_multi_output_fail(self): 9413 class Sub(torch.jit.ScriptModule): 9414 def __init__(self) -> None: 9415 super().__init__() 9416 self.weight = nn.Parameter(torch.randn(2)) 9417 9418 @torch.jit.script_method 9419 def forward(self, thing): 9420 return self.weight + thing 9421 9422 class ReturnMulti(torch.jit.ScriptModule): 9423 @torch.jit.script_method 9424 def forward(self, x): 9425 return x, x, x 9426 9427 class HaveSequential(torch.jit.ScriptModule): 9428 def __init__(self) -> None: 9429 super().__init__() 9430 self.someseq = nn.Sequential( 9431 Sub(), 9432 ReturnMulti(), 9433 Sub() 9434 ) 9435 9436 @torch.jit.script_method 9437 def forward(self, x): 9438 return self.someseq(x) 9439 9440 with self.assertRaisesRegex(RuntimeError, "(Tensor, Tensor, Tensor)"): 9441 with torch.jit.optimized_execution(False): 9442 hs = HaveSequential() 9443 i = torch.empty(2) 9444 hs(i) 9445 9446 @_tmp_donotuse_dont_inline_everything 9447 def test_script_sequential_in_mod_list(self): 9448 class Sub(torch.jit.ScriptModule): 9449 def __init__(self) -> None: 9450 super().__init__() 
9451 self.weight = nn.Parameter(torch.randn(2)) 9452 9453 @torch.jit.script_method 9454 def forward(self, thing): 9455 return self.weight + thing 9456 9457 class M(torch.jit.ScriptModule): 9458 def __init__(self) -> None: 9459 super().__init__() 9460 self.mods = nn.ModuleList([Sub(), nn.Sequential(Sub(), nn.Sequential(Sub(), Sub()), Sub())]) 9461 9462 @torch.jit.script_method 9463 def forward(self, v): 9464 for mod in self.mods: 9465 v = mod(v) 9466 return v 9467 9468 m = M() 9469 graph = str(m.graph) 9470 self.assertTrue(graph.count("prim::CallMethod") == 2) 9471 self.assertTrue("python" not in graph) 9472 9473 @_tmp_donotuse_dont_inline_everything 9474 def test_script_nested_mod_list(self): 9475 class Sub(torch.jit.ScriptModule): 9476 def __init__(self) -> None: 9477 super().__init__() 9478 self.weight = nn.Parameter(torch.randn(2)) 9479 9480 @torch.jit.script_method 9481 def forward(self, thing): 9482 return self.weight + thing 9483 9484 class M(torch.jit.ScriptModule): 9485 def __init__(self) -> None: 9486 super().__init__() 9487 self.mods = nn.ModuleList([nn.ModuleList([Sub()]), nn.Sequential(Sub()), nn.ModuleList([Sub(), Sub()])]) 9488 9489 @torch.jit.script_method 9490 def forward(self, v): 9491 for mod in self.mods: 9492 for m in mod: 9493 v = m(v) 9494 return v 9495 9496 m = M() 9497 graph = str(m.graph) 9498 self.assertTrue(graph.count("prim::CallMethod") == 4) 9499 self.assertTrue("python" not in graph) 9500 9501 def test_constant_as_attr(self): 9502 class M(torch.jit.ScriptModule): 9503 __constants__ = ['dim'] 9504 9505 def __init__(self) -> None: 9506 super().__init__() 9507 self.dim = 1 9508 9509 @torch.jit.script_method 9510 def forward(self, v): 9511 return torch.cat([v, v, v], dim=self.dim) 9512 v = torch.zeros(1, 1) 9513 with torch.jit.optimized_execution(False): 9514 self.assertEqual(torch.cat([v, v, v], dim=1), M()(v)) 9515 9516 class StarTestSumStarred(torch.nn.Module): 9517 def __init__(self) -> None: 9518 super(TestScript.StarTestSumStarred, 
self).__init__() 9519 9520 def forward(self, *inputs): 9521 output = inputs[0] 9522 for i in range(1, len(inputs)): 9523 output += inputs[i] 9524 return output 9525 9526 class StarTestReturnThree(torch.nn.Module): 9527 def __init__(self) -> None: 9528 super(TestScript.StarTestReturnThree, self).__init__() 9529 9530 def forward(self, rep): 9531 return rep, rep, rep 9532 9533 def test_script_star_expr(self): 9534 9535 class M2(torch.jit.ScriptModule): 9536 def __init__(self) -> None: 9537 super().__init__() 9538 self.m = torch.jit.trace(TestScript.StarTestSumStarred(), 9539 (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3))) 9540 self.g = torch.jit.trace(TestScript.StarTestReturnThree(), torch.ones(4, 3)) 9541 9542 @torch.jit.script_method 9543 def forward(self, rep): 9544 tup = self.g(rep) 9545 return self.m(*tup) 9546 9547 m = M2() 9548 self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3)) 9549 9550 def test_script_star_expr_string(self): 9551 class M2(torch.jit.ScriptModule): 9552 def __init__(self) -> None: 9553 super().__init__() 9554 self.m = torch.jit.trace(TestScript.StarTestSumStarred(), 9555 (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3))) 9556 self.g = torch.jit.trace(TestScript.StarTestReturnThree(), torch.ones(4, 3)) 9557 9558 self.define(''' 9559 def forward(self, rep): 9560 tup = self.g(rep) 9561 return self.m(*tup) 9562 ''') 9563 9564 m = M2() 9565 self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3)) 9566 9567 class StarTestSumAndReturnThree(torch.nn.Module): 9568 def __init__(self) -> None: 9569 super(TestScript.StarTestSumAndReturnThree, self).__init__() 9570 9571 def forward(self, *inputs): 9572 output = inputs[0] 9573 for i in range(1, len(inputs)): 9574 output += inputs[i] 9575 return output, output, output 9576 9577 def test_script_star_assign(self): 9578 class M2(torch.jit.ScriptModule): 9579 def __init__(self) -> None: 9580 super().__init__() 9581 self.g = torch.jit.trace(TestScript.StarTestSumAndReturnThree(), 
torch.ones(4, 3)) 9582 self.define(''' 9583 def forward(self, rep): 9584 head, *tail = self.g(rep) 9585 return head 9586 ''') 9587 9588 m = M2() 9589 self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3)) 9590 9591 def test_script_module_star_assign2(self): 9592 class M2(torch.jit.ScriptModule): 9593 def __init__(self) -> None: 9594 super().__init__() 9595 self.g = torch.jit.trace( 9596 TestScript.StarTestSumAndReturnThree(), 9597 (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)), 9598 _force_outplace=True) 9599 self.define(''' 9600 def forward(self, rep): 9601 *head, tail = self.g(rep, rep, rep) 9602 return tail 9603 ''') 9604 9605 m = M2() 9606 self.assertEqual(m(torch.ones(4, 3)), 3 * torch.ones(4, 3)) 9607 9608 def test_script_module_star_assign2_inplace(self): 9609 class M2(torch.jit.ScriptModule): 9610 def __init__(self) -> None: 9611 super().__init__() 9612 self.g = torch.jit.trace( 9613 TestScript.StarTestSumAndReturnThree(), 9614 (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)), 9615 _force_outplace=False) 9616 self.define(''' 9617 def forward(self, rep): 9618 *head, tail = self.g(rep, rep, rep) 9619 return tail 9620 ''') 9621 9622 m = M2() 9623 # since forward() makes three aliases to the input `rep` before passing 9624 # it to StarTestSumAndReturnThree(), in-place behavior will be different 9625 # than the above out of place. 
9626 self.assertEqual(m(torch.ones(4, 3)), 4 * torch.ones(4, 3)) 9627 9628 def test_script_module_star_assign_fail_pythonop(self): 9629 9630 with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"): 9631 class M2(torch.jit.ScriptModule): 9632 def __init__(self) -> None: 9633 super().__init__() 9634 9635 @torch.jit.ignore 9636 def myfunc(): 9637 return torch.zeros(1, 2, 3), torch.zeros(1, 2, 3) 9638 9639 self.define(''' 9640 def forward(self, rep): 9641 a, *b = myfunc() 9642 return a 9643 ''') 9644 9645 m = M2() 9646 m(torch.zeros(4, 3)) 9647 9648 def test_script_module_star_assign_fail_builtin(self): 9649 with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"): 9650 class M2(torch.jit.ScriptModule): 9651 def __init__(self) -> None: 9652 super().__init__() 9653 9654 self.define(''' 9655 def forward(self, rep): 9656 a, *b = torch.neg(rep) 9657 return a 9658 ''') 9659 9660 m = M2() 9661 m(torch.zeros(4, 3)) 9662 9663 def test_script_pack_padded_sequence(self): 9664 from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 9665 9666 def pack_padded_pad_packed_script(x, seq_lens): 9667 x = pack_padded_sequence(x, seq_lens) 9668 x, lengths = pad_packed_sequence(x) 9669 return x, lengths 9670 9671 T, B, C = 3, 5, 7 9672 x = torch.ones((T, B, C)) 9673 seq_lens = torch.tensor([3, 3, 2, 2, 1]) 9674 # set padding value so we can test equivalence 9675 for b in range(B): 9676 if seq_lens[b] < T: 9677 x[seq_lens[b]:, b, :] = 0 9678 9679 eager_seq, eager_lengths = pack_padded_pad_packed_script(x, seq_lens) 9680 with torch._jit_internal._disable_emit_hooks(): 9681 scripted_pack_padded_seq = torch.jit.script(pack_padded_pad_packed_script) 9682 script_seq, script_lengths = scripted_pack_padded_seq(x, seq_lens) 9683 self.assertEqual(eager_seq, script_seq) 9684 self.assertEqual(eager_lengths, script_lengths) 9685 9686 class ExperimentalLSTM(torch.nn.Module): 9687 def __init__(self, input_dim, hidden_dim): 9688 super().__init__() 9689 9690 
def forward(self, input): 9691 # type: (Tensor) 9692 packed = pack_padded_sequence( 9693 input=input, lengths=torch.tensor([1, 2]), enforce_sorted=False 9694 ) 9695 output, lengths = pad_packed_sequence( 9696 sequence=packed, total_length=2 9697 ) 9698 # lengths is flipped, so is output 9699 return output[0] 9700 9701 lstm = ExperimentalLSTM(input_dim=2, hidden_dim=2) 9702 9703 with torch._jit_internal._disable_emit_hooks(): 9704 self.checkModule(lstm, [torch.ones(2, 2)]) 9705 9706 def test_script_pad_sequence_pack_sequence(self): 9707 from torch.nn.utils.rnn import pad_sequence, pack_sequence, pad_packed_sequence 9708 9709 def pad_sequence_func(tensor_list, batch_first=False, padding_value=0.0, padding_side="right"): 9710 # type: (List[Tensor], bool, float, str) -> Tensor 9711 return pad_sequence(tensor_list, batch_first, padding_value, padding_side) 9712 9713 def pack_sequence_func(tensor_list, enforce_sorted=True): 9714 # type: (List[Tensor], bool) -> Tensor 9715 return pad_packed_sequence(pack_sequence(tensor_list, enforce_sorted))[0] 9716 9717 ones3 = torch.ones(3, 5) 9718 ones4 = torch.ones(4, 5) 9719 ones5 = torch.ones(5, 5) 9720 tensor1 = torch.tensor([1, 2, 3]) 9721 tensor2 = torch.tensor([4, 5]) 9722 tensor3 = torch.tensor([6]) 9723 with torch._jit_internal._disable_emit_hooks(): 9724 self.checkScript(pad_sequence_func, 9725 ([ones3, ones4, ones5],)) 9726 self.checkScript(pad_sequence_func, 9727 ([ones3, ones4, ones5], True)) 9728 self.checkScript(pad_sequence_func, 9729 ([ones3, ones4, ones5], True, 2.5)) 9730 self.checkScript(pad_sequence_func, 9731 ([ones3, ones4, ones5], True, 2.5, "left")) 9732 self.checkScript(pad_sequence_func, 9733 ([ones3, ones4, ones5], False, 2.5, "left")) 9734 self.checkScript(pack_sequence_func, 9735 ([tensor1, tensor2, tensor3],)) 9736 self.checkScript(pack_sequence_func, 9737 ([tensor1, tensor2, tensor3], False)) 9738 9739 def test_script_get_tracing_state(self): 9740 def test_if_tracing(x): 9741 if 
torch._C._get_tracing_state(): 9742 return x + 1 9743 else: 9744 return x - 1 9745 9746 inp = torch.randn(3, 3) 9747 self.checkScript(test_if_tracing, (inp,)) 9748 9749 def test_script_is_tracing(self): 9750 def test_is_tracing(x): 9751 if torch.jit.is_tracing(): 9752 return x + 1 9753 else: 9754 return x - 1 9755 9756 inp = torch.randn(3, 3) 9757 self.checkScript(test_is_tracing, (inp,)) 9758 9759 def test_is_scripting(self): 9760 def foo(): 9761 return torch.jit.is_scripting() 9762 9763 self.assertFalse(foo()) 9764 scripted = torch.jit.script(foo) 9765 self.assertTrue(scripted()) 9766 9767 def test_comment_ignore_indent(self): 9768 class Model(torch.nn.Module): 9769 def __init__(self) -> None: 9770 # useless comment that is not indented correctly # noqa: E115 9771 super().__init__() 9772 9773 def forward(self): 9774 return 5 9775 9776 # should compile without an error 9777 self.checkModule(Model(), ()) 9778 9779 def test_script_outputs(self): 9780 with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"): 9781 @torch.jit.script 9782 def foo(a): 9783 c, d = a + a 9784 return c + d 9785 9786 @torch.jit.script 9787 def return3(): 9788 return 1, 2, 3 9789 9790 with self.assertRaisesRegex(RuntimeError, "too many values to unpack"): 9791 @torch.jit.script 9792 def bind2(): 9793 a, b = return3() 9794 print(a) 9795 print(b) 9796 9797 @unittest.skipIf(not RUN_CUDA, "requires CUDA") 9798 def test_script_get_device_cuda(self): 9799 @torch.jit.script 9800 def foo(a): 9801 return a.get_device() 9802 9803 v = torch.randn(1, device='cuda') 9804 self.assertEqual(foo(v), 0) 9805 9806 def test_script_chunk(self): 9807 @torch.jit.script 9808 def foo(a): 9809 b, c = torch.chunk(a, dim=0, chunks=2) 9810 return b 9811 v = torch.rand(10, 3) 9812 self.assertEqual(torch.chunk(v, dim=0, chunks=2)[0], foo(v)) 9813 9814 def test_script_copy(self): 9815 class M(torch.nn.Module): 9816 __annotations__ = { 9817 "val": Optional[torch.Tensor] 9818 } 9819 9820 def __init__(self) -> 
None: 9821 super().__init__() 9822 self.val = None 9823 9824 def some_method(self): 9825 return 3 9826 9827 def forward(self, x): 9828 # type: (Tensor) -> Tensor 9829 self.val = x + self.some_method() 9830 return x 9831 9832 m = torch.jit.script(M()) 9833 # test copy 9834 copy.copy(m) 9835 copy.deepcopy(m) 9836 9837 def test_script_forward_method_replacement(self): 9838 # We want to support the use case of attaching a different `forward` method 9839 class LowLevelModule(torch.nn.Module): 9840 def forward(self, input: torch.Tensor): 9841 # Generic forward dispatch 9842 return self.forward_pytorch(input) * 2 9843 9844 class TestModule(LowLevelModule): 9845 def __init__(self) -> None: 9846 super().__init__() 9847 # Replace the forward method 9848 self.forward = types.MethodType(LowLevelModule.forward, self) 9849 9850 def forward_pytorch(self, input: torch.Tensor): 9851 return torch.tensor(123) 9852 9853 def forward(self, input: torch.Tensor): 9854 # Should not use this forward method 9855 raise AssertionError("This method should not be used") 9856 return self.forward_pytorch(input) 9857 9858 m = TestModule() 9859 self.assertEqual(m(torch.tensor(1)), torch.tensor(246)) 9860 9861 m_scripted = torch.jit.script(m) 9862 self.assertEqual(m_scripted(torch.tensor(1)), torch.tensor(246)) 9863 9864 def test_python_call_non_tensor(self): 9865 def foo(a, b, c): 9866 # type: (Tensor, int, Tuple[Tensor, int]) -> Tuple[int, Tensor] 9867 d, e = c 9868 return b + e, a + d 9869 9870 @torch.jit.script 9871 def bar(): 9872 x = torch.ones(3, 4) 9873 a, b = foo(x, 3, (x, 3)) 9874 return a, b 9875 9876 self.assertEqual((6, torch.ones(3, 4) + 1), bar()) 9877 9878 def test_python_call_non_tensor_wrong(self): 9879 with self.assertRaisesRegex(RuntimeError, r"but instead got value of type tuple"): 9880 @torch.jit.ignore 9881 def foo(): 9882 # type: () -> Tensor 9883 return ((3, 4),) # noqa: T484 9884 9885 @torch.jit.script 9886 def bar(): 9887 return foo() 9888 9889 bar() 9890 9891 def 
test_if_different_type(self): 9892 with self.assertRaisesRegex(RuntimeError, "c0 is set to type " 9893 "int in the true branch and type " 9894 "float in the false branch"): 9895 @torch.jit.script 9896 def diff_type_used(): 9897 if 1 == 2: 9898 c0 = 1 9899 else: 9900 c0 = 1.0 9901 return c0 9902 9903 with self.assertRaisesRegex(RuntimeError, "Variable 'c0' previously had type float"): 9904 @torch.jit.script 9905 def diff_existing_type(x): 9906 c0 = 1.0 9907 if 1 == 2: 9908 c0 = 1 9909 print(x) 9910 return x 9911 9912 @torch.jit.script 9913 def diff_type_unused(): 9914 if 1 == 1: 9915 c0 = 1 9916 print(c0) 9917 else: 9918 c0 = 1.0 9919 print(c0) 9920 return 1 9921 9922 def test_if_not_defined_error(self): 9923 with self.assertRaisesRegex(RuntimeError, "c0 is not defined in the false branch"): 9924 @torch.jit.script 9925 def test(): 9926 if 1 == 1: 9927 c0 = 1 9928 return c0 9929 with self.assertRaisesRegex(RuntimeError, "c0 is not defined in the true branch"): 9930 @torch.jit.script 9931 def test2(): 9932 if 1 == 1: 9933 pass 9934 else: 9935 c0 = 1 9936 return c0 9937 9938 def test_if_list_cat(self): 9939 # testing that different length lists don't throw error on cat in shape prop 9940 @torch.jit.script 9941 def test_list(x): 9942 if bool(x.sum() < 1): 9943 c = [x, x] 9944 else: 9945 c = [x, x, x] 9946 return torch.cat(c) 9947 9948 b = torch.zeros(2, 4) 9949 _propagate_shapes(test_list.graph, (b,), False) 9950 9951 def test_if_supertype(self): 9952 @torch.jit.script 9953 def tensor_unifying(x, y, z): 9954 # testing dynamic is appropriately set for y and z 9955 if bool(x): 9956 x, y, z = x + 1, y, z 9957 else: 9958 x, y, z = x + 1, x, y 9959 9960 return x, y, z 9961 9962 a = torch.zeros(2, 2, dtype=torch.float) 9963 b = torch.zeros(2, 4, dtype=torch.long) 9964 c = torch.zeros(2, 4, dtype=torch.float) 9965 9966 graph = _propagate_shapes(tensor_unifying.graph, (a, b, c), False) 9967 if_outputs = list(graph.findNode("prim::If").outputs()) 9968 
self.assertTrue(if_outputs[0].type().str() == "Float(*, *, requires_grad=0, device=cpu)") 9969 self.assertTrue(if_outputs[1].type().str() == "Tensor(*, *, requires_grad=0, device=cpu)") 9970 self.assertTrue(if_outputs[2].type().str() == "Tensor(*, *, requires_grad=0, device=cpu)") 9971 9972 def test_list_unify(self): 9973 # allowing a unififed int?[] would cause a runtime error b/c 9974 # the index operation expects int?[] to be a generic list, 9975 # but in the true branch the IValue will be a int list 9976 with self.assertRaisesRegex(RuntimeError, "int[] in the true branch and type None[]"): 9977 @torch.jit.script 9978 def list_optional_fails(x): 9979 # type: (bool) -> Optional[int] 9980 if x: 9981 y = [1] 9982 else: 9983 y = [None] # noqa: T484 9984 return y[0] 9985 9986 @torch.jit.script 9987 def list_tensors(x): 9988 # type: (bool) -> Tuple[Tensor, List[Tensor]] 9989 if x: 9990 a = torch.zeros([1, 1]) 9991 y = [a] 9992 else: 9993 a = torch.zeros([1, 2]) 9994 y = [a] 9995 return a, y 9996 9997 self.run_pass('constant_propagation', list_tensors.graph) 9998 m = self.createFunctionFromGraph(list_tensors.graph) 9999 # testing that tensor type of lists is unified 10000 self.getExportImportCopy(m) 10001 10002 @skipIfTorchDynamo("Not a TorchDynamo suitable test") 10003 @_inline_everything 10004 def test_import_constants_not_specialized(self): 10005 class Mod(torch.nn.Module): 10006 def forward(self, x): 10007 return torch.cat(2 * [x], dim=0) 10008 10009 class ScriptMod(torch.jit.ScriptModule): 10010 def __init__(self, mod): 10011 super().__init__() 10012 x = torch.zeros(1, 3) 10013 mod_fn = lambda : mod(x) # noqa: E731 10014 self.mod = torch.jit.trace(mod_fn, ()) 10015 10016 @torch.jit.script_method 10017 def forward(self): 10018 return self.mod() 10019 10020 cm = ScriptMod(Mod()) 10021 # specialized tensor in graph 10022 FileCheck().check("Float(1, 3, strides=[3, 1], requires_grad=0, device=cpu)").run(cm.forward.graph) 10023 buffer = io.BytesIO() 10024 
torch.jit.save(cm, buffer) 10025 buffer.seek(0) 10026 # when tensor is loaded as constant it isnt specialized 10027 cm_load = torch.jit.load(buffer) 10028 FileCheck().check_not("Float(1, 3)").run(cm_load.forward.graph) 10029 10030 @skipIfTorchDynamo("TorchDynamo fails with unknown reason") 10031 def test_type_annotations_repeated_list(self): 10032 @torch.jit.script 10033 def float_fn(x, y): 10034 # type: (float, BroadcastingList3[float]) -> List[float] 10035 return y 10036 self.assertEqual(float_fn(2.0, 1.0), float_fn(2.0, [1.0, 1.0, 1.0])) 10037 self.assertEqual(float_fn(2.0, 1.0), float_fn(2.0, (1.0, 1.0, 1.0))) 10038 10039 @torch.jit.script 10040 def float_fn_call(): 10041 print(float_fn(1.0, 1.0)) 10042 print(float_fn(1.0, (1.0, 1.0, 1.0))) 10043 10044 @torch.jit.script 10045 def int_fn(x): 10046 # type: (BroadcastingList3[int]) -> List[int] 10047 return x 10048 self.assertEqual(int_fn(1), int_fn([1, 1, 1])) 10049 self.assertEqual(int_fn(1), int_fn((1, 1, 1))) 10050 10051 @torch.jit.script 10052 def int_fn_call(): 10053 print(int_fn(1)) 10054 print(int_fn((1, 1, 1))) 10055 10056 with self.assertRaisesRegex(RuntimeError, "must be a positive integer:"): 10057 @torch.jit.script # noqa: T484 10058 def fn(x): 10059 # type: (BroadcastingListx[int]) -> List[int] # noqa: T484 10060 return x 10061 10062 # using CU so that flake8 error on int[2] is not raised (noqa not working) 10063 with self.assertRaisesRegex(RuntimeError, "Unknown type constructor"): 10064 cu = torch.jit.CompilationUnit(''' 10065 def nested(x, y): 10066 # type: (int, Tuple[int, int[2]]) -> List[int] 10067 return x # noqa: T484 10068 ''') 10069 10070 @torch.jit.script 10071 def f(x: BroadcastingList2[int]): 10072 return x 10073 10074 out = f(1) 10075 self.assertTrue(isinstance(out[0], int)) 10076 self.assertEqual(out, [1, 1]) 10077 10078 def test_ntuple_builtins(self): 10079 from torch.nn.modules.utils import _single, _pair, _triple, _quadruple 10080 10081 def test_ints(): 10082 return _single(1), 
_pair(2), _triple(3), _quadruple(4) 10083 10084 def test_floats(): 10085 return _single(1), _pair(2.1), _triple(3.1), _quadruple(4.1) 10086 10087 self.checkScript(test_ints, ()) 10088 self.checkScript(test_floats, ()) 10089 10090 def test_embedding_renorm_grad_error(self): 10091 # Testing that the builtin call to embedding_renorm_ correctly throws 10092 # Error when .backward() is called on its input 10093 10094 def embedding_norm(input, embedding_matrix, max_norm): 10095 F.embedding(input, embedding_matrix, max_norm=0.01) 10096 10097 @torch.jit.script 10098 def embedding_norm_script(input, embedding_matrix, max_norm): 10099 # type: (Tensor, Tensor, float) -> None 10100 F.embedding(input, embedding_matrix, max_norm=0.01) 10101 10102 for _ in [embedding_norm, embedding_norm_script]: 10103 input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]]) 10104 embedding_matrix = torch.randn(10, 3) 10105 10106 var1 = torch.randn(10, 3, requires_grad=True) 10107 var2 = var1.detach().requires_grad_() 10108 output1 = var1 * embedding_matrix 10109 output2 = var2 * embedding_matrix 10110 10111 output1.sum().backward() 10112 10113 ignore = F.embedding(input, embedding_matrix, max_norm=0.01) 10114 with self.assertRaisesRegex(RuntimeError, "modified"): 10115 output2.sum().backward() 10116 10117 def test_type_annotations(self): 10118 def fn(x, y): 10119 # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor] 10120 return x, x * 2, x * 3 10121 10122 with self.assertRaisesRegex(RuntimeError, r"need 4 values .* found only 3"): 10123 @torch.jit.script 10124 def script_fn(x): 10125 x, y, z, w = fn(x, x) 10126 10127 with self.assertRaisesRegex(RuntimeError, r"too many values .* need 2 but found 3"): 10128 @torch.jit.script 10129 def script_fn2(x): 10130 x, y = fn(x, x) 10131 10132 def fn_unpack(x): 10133 y, z, w = fn(x, x) 10134 return y 10135 10136 def fn_index(x): 10137 q = fn(x, x) 10138 return x 10139 10140 def fn_string(str, strpair): 10141 # type: (str, Tuple[str, str]) -> Tuple[str, int, 
str, str] 10142 str1, str2 = strpair 10143 return str, 2, str1, str2 10144 10145 x = torch.ones(2, 2) 10146 self.checkScript(fn_unpack, (x,), optimize=True) 10147 self.checkScript(fn_index, (x,), optimize=True) 10148 self.checkScript(fn_string, ("1", ("3", "4")), optimize=True) 10149 10150 def test_type_annotations_varargs(self): 10151 @torch.jit.ignore 10152 def fn_varargs(x, *args): 10153 return args[0] if args else x 10154 10155 def fn1(x, y, z): 10156 return fn_varargs(x) 10157 10158 def fn2(x, y, z): 10159 return fn_varargs(x, y) 10160 10161 def fn3(x, y, z): 10162 return fn_varargs(x, y, z) 10163 10164 x, y, z = (torch.randn(2, 2) for _ in range(3)) 10165 self.checkScript(fn1, (x, y, z), optimize=True) 10166 self.checkScript(fn2, (x, y, z), optimize=True) 10167 self.checkScript(fn3, (x, y, z), optimize=True) 10168 10169 def test_type_annotation_py3(self): 10170 code = dedent(""" 10171 import torch 10172 from torch import Tensor 10173 from typing import Tuple 10174 10175 def fn(x : torch.Tensor, y : Tensor, z) -> Tuple[Tensor, Tensor, Tensor]: 10176 return (x, y + z, z) 10177 """) 10178 10179 with tempfile.TemporaryDirectory() as tmp_dir: 10180 script_path = os.path.join(tmp_dir, 'script.py') 10181 with open(script_path, 'w') as f: 10182 f.write(code) 10183 fn = get_fn('test_type_annotation_py3', script_path) 10184 fn = torch.jit.ignore(fn) 10185 10186 with self.assertRaisesRegex(RuntimeError, r"Expected a value of type 'Tensor' for argument" 10187 r" 'x' but instead found type 'Tuple\[Tensor,"): 10188 @torch.jit.script 10189 def bad_fn(x): 10190 x, y = fn((x, x), x, x) 10191 return y 10192 10193 with self.assertRaisesRegex(RuntimeError, r"too many values .* need 2 but found 3"): 10194 @torch.jit.script 10195 def bad_fn2(x): 10196 x, y = fn(x, x, x) 10197 return y 10198 10199 with self.assertRaisesRegex(RuntimeError, r"need 4 values .* found only 3"): 10200 @torch.jit.script 10201 def bad_fn3(x): 10202 x, y, z, w = fn(x, x, x) 10203 return y 10204 10205 def 
good_fn(x): 10206 y, z, w = fn(x, x, x) 10207 return y, z, w 10208 10209 self.checkScript(good_fn, (torch.ones(2, 2),), optimize=True) 10210 10211 def test_type_annotation_module(self): 10212 class BaseModule(torch.jit.ScriptModule): 10213 @torch.jit.ignore 10214 def foo(self, x): 10215 # type: (Tensor) -> Tensor 10216 return x + 1 10217 10218 @torch.jit.ignore 10219 def bar(self, x, y): 10220 # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor] 10221 return x + y, y 10222 10223 @torch.jit.ignore 10224 def baz(self, x, y): 10225 return x 10226 10227 class ModuleTooMany(BaseModule): 10228 @torch.jit.script_method 10229 def method(self, x): 10230 return self.foo(x, x) 10231 10232 class ModuleTooFew(BaseModule): 10233 @torch.jit.script_method 10234 def method(self, x): 10235 return self.bar(x) 10236 10237 class ModuleTooManyAssign(BaseModule): 10238 @torch.jit.script_method 10239 def method(self, x): 10240 y, z, w = self.bar(x, x) 10241 return x 10242 10243 class ModuleDefault(BaseModule): 10244 @torch.jit.script_method 10245 def method(self, x): 10246 y = self.baz(x) 10247 return x 10248 10249 with self.assertRaisesRegex(RuntimeError, "Expected at most 2 arguments but found 3"): 10250 ModuleTooMany() 10251 with self.assertRaisesRegex(RuntimeError, "Argument y not provided"): 10252 ModuleTooFew() 10253 with self.assertRaisesRegex(RuntimeError, "need 3 values .* found only 2"): 10254 ModuleTooManyAssign() 10255 with self.assertRaisesRegex(RuntimeError, "Argument y not provided."): 10256 ModuleDefault() 10257 10258 def test_type_inferred_from_empty_annotation(self): 10259 """ 10260 Test that the type inferred from an empty or missing annotation is Torch.Tensor wtih `inferred=true` 10261 """ 10262 @torch.jit.script 10263 def fn(x): 10264 return x 10265 10266 graph = fn.graph 10267 n = next(graph.inputs()) 10268 self.assertTrue(n.type() == torch._C.TensorType.getInferred()) 10269 10270 with self.assertRaisesRegex(RuntimeError, "Inferred 'x' to be of type 'Tensor"): 10271 
fn("1") 10272 10273 def test_script_define_order(self): 10274 class M(torch.jit.ScriptModule): 10275 10276 @torch.jit.script_method 10277 def call_foo(self, input): 10278 return self.foo(input) 10279 10280 @torch.jit.script_method 10281 def foo(self, input): 10282 return input + 1 10283 m = M() 10284 self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64))) 10285 10286 def test_script_define_order_recursive_fail(self): 10287 class M(torch.jit.ScriptModule): 10288 10289 @torch.jit.script_method 10290 def call_foo(self, input): 10291 return self.foo(input) 10292 10293 @torch.jit.script_method 10294 def foo(self, input): 10295 self.call_foo(input) 10296 10297 with self.assertRaisesRegex(RuntimeError, 'called recursively'): 10298 M() 10299 10300 def test_script_kwargs_fn_call(self): 10301 class M(torch.jit.ScriptModule): 10302 10303 @torch.jit.script_method 10304 def call_foo(self, input): 10305 return self.foo(input=input, bar=1) 10306 10307 @torch.jit.script_method 10308 def foo(self, bar, input): 10309 # type: (int, Tensor) -> Tensor 10310 return input + bar 10311 m = M() 10312 self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64))) 10313 10314 def test_if_define(self): 10315 @torch.jit.script 10316 def foo(a): 10317 if bool(a == 0): 10318 b = 1 10319 else: 10320 b = 0 10321 return b + 1 10322 10323 @torch.jit.script 10324 def foo2(a): 10325 b = 0 10326 if bool(a == 0): 10327 b = 1 10328 return b + 1 10329 10330 @torch.jit.script 10331 def foo3(a): 10332 b = 1 10333 if bool(a == 0): 10334 c = 4 10335 else: 10336 b = 0 10337 return b + 1 10338 10339 a = torch.ones(1, dtype=torch.long) 10340 b = torch.zeros(1, dtype=torch.long) 10341 self.assertEqual(1, foo(a)) 10342 self.assertEqual(2, foo(b)) 10343 self.assertEqual(1, foo2(a)) 10344 self.assertEqual(2, foo2(b)) 10345 self.assertEqual(1, foo3(a)) 10346 self.assertEqual(2, foo3(b)) 10347 10348 def test_script_module_export_submodule(self): 10349 class M1(torch.jit.ScriptModule): 10350 def 
__init__(self) -> None: 10351 super().__init__() 10352 self.weight = nn.Parameter(torch.randn(2)) 10353 10354 @torch.jit.script_method 10355 def forward(self, thing): 10356 return self.weight + thing 10357 10358 class M2(torch.jit.ScriptModule): 10359 def __init__(self) -> None: 10360 super().__init__() 10361 # test submodule 10362 self.sub = M1() 10363 self.weight = nn.Parameter(torch.randn(2, 3)) 10364 self.bias = nn.Parameter(torch.randn(2)) 10365 self.define(""" 10366 def hi(self, a): 10367 return self.weight.mm(a) 10368 """) 10369 10370 @torch.jit.script_method 10371 def doit(self, input): 10372 return self.weight.mm(input) 10373 10374 @torch.jit.script_method 10375 def doit2(self, input): 10376 return self.weight.mm(input) 10377 10378 @torch.jit.script_method 10379 def doit3(self, input): 10380 return input + torch.ones([1], dtype=torch.double) 10381 10382 @torch.jit.script_method 10383 def forward(self, input): 10384 a = self.doit(input) 10385 b = self.doit2(input) 10386 c = self.hi(input) 10387 return a + b + self.bias + c 10388 10389 with torch.jit.optimized_execution(False): 10390 m_orig = M2() 10391 m_import = self.getExportImportCopy(m_orig) 10392 10393 input = torch.randn(3, 2) 10394 self.assertEqual(m_orig.doit(input), m_import.doit(input)) 10395 self.assertEqual(m_orig.hi(input), m_import.hi(input)) 10396 self.assertEqual(m_orig.doit3(input), m_import.doit3(input)) 10397 self.assertEqual(m_orig.forward(input), m_import.forward(input)) 10398 10399 @slowTest 10400 def test_compile_module_with_constant(self): 10401 class Double(nn.Module): 10402 def __init__(self, downsample=None): 10403 super().__init__() 10404 10405 def forward(self, input): 10406 return input * 2 10407 10408 class Mod(nn.Module): 10409 __constants__ = ['downsample'] 10410 10411 def __init__(self, downsample=None): 10412 super().__init__() 10413 self.downsample = downsample 10414 10415 def forward(self, input): 10416 if self.downsample is not None: 10417 return self.downsample(input) 
                return input

        none_mod = torch.jit.script(Mod(None))
        double_mod = torch.jit.script(Mod(Double()))
        self.assertEqual(none_mod(torch.tensor(1)), torch.tensor(1))
        self.assertEqual(double_mod(torch.tensor(1)), torch.tensor(1) * 2)

    # torch.device can be constructed with a `type=` keyword in script, whether
    # it is referenced as the bare `device` import or as `torch.device`.
    def test_device_kwarg(self):
        from torch import device

        def f():
            return device(type='cuda'), torch.device(type='cpu')
        self.checkScript(f, ())

    # Export/import round trip (via self.getExportImportCopy) of a ScriptModule
    # parameter, once per floating dtype: values and dtype must survive, and the
    # original parameter's storage must still hold exactly 25 elements.
    def test_script_module_export_tensor_type(self):
        class M(torch.jit.ScriptModule):
            def __init__(self, type):
                super().__init__()
                self.param = torch.nn.Parameter(torch.zeros((5, 5), dtype=type).random_())

            @torch.jit.script_method
            def foo(self):
                return self.param

        with torch.jit.optimized_execution(False):
            for type in [torch.float, torch.double]:
                m_orig = M(type)
                m_import = self.getExportImportCopy(m_orig)
                # check to make sure the storage wasn't resized
                self.assertTrue(m_orig.param.storage().size() == 25)
                self.assertEqual(m_orig.foo(), m_import.foo())
                self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype)

    # CUDA variant of the export round trip: the imported parameter must still
    # be on cuda:0 and match the original in value and dtype.
    @unittest.skipIf(not RUN_CUDA, "testing cuda tensors require CUDA")
    def test_script_module_export_tensor_cuda(self):
        class M(torch.jit.ScriptModule):

            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.zeros((5, 5), device='cuda:0').random_())

            @torch.jit.script_method
            def foo(self):
                return self.param

        m_orig = M()
        m_import = self.getExportImportCopy(m_orig)
        # check to make sure the storage wasn't resized
        self.assertTrue(m_orig.param.storage().size() == 25)
        self.assertTrue(m_import.foo().device == torch.device('cuda:0'))
        self.assertEqual(m_orig.foo(), m_import.foo())
        self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype)

    # Round trip of a module whose forward contains an if/else; the original
    # and imported modules must produce the same output for the same input.
    def test_script_module_export_blocks(self):
        class M(torch.jit.ScriptModule):
            def __init__(self, n, m):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(n, m))

            @torch.jit.script_method
            def forward(self, input):
                if bool(input.sum() > 0):
                    output = self.weight.mv(input)
                else:
                    output = self.weight + input
                return output

        m_orig = M(200, 200)
        m_import = self.getExportImportCopy(m_orig)

        t = torch.rand(200)
        self.assertEqual(m_orig(t), m_import(t))

    # param2 is built from a view of param1 (row 3) and param4 from a slice of a
    # larger tensor. After the round trip, the imported param1/param2 must still
    # share a storage, while param3 keeps its own.
    def test_script_module_export_shared_storage(self):
        class M(torch.jit.ScriptModule):

            def __init__(self) -> None:
                super().__init__()
                self.param1 = torch.nn.Parameter(torch.rand(5, 5))
                self.param2 = torch.nn.Parameter(self.param1[3])
                self.param3 = torch.nn.Parameter(torch.rand(5, 5))
                self.param4 = torch.nn.Parameter(torch.rand(11, 5)[1:6])

            @torch.jit.script_method
            def foo(self):
                return self.param1 + self.param2 + self.param3 + self.param4

        with torch.jit.optimized_execution(False):
            m_orig = M()
            m_import = self.getExportImportCopy(m_orig)

            self.assertEqual(m_orig.foo(), m_import.foo())

            self.assertTrue(m_import.param1.storage().data_ptr() == m_import.param2.storage().data_ptr())
            self.assertTrue(m_import.param1.storage().data_ptr() != m_import.param3.storage().data_ptr())

    # An nn.Sequential whose last module returns a Dict instead of a Tensor
    # must still be scriptable and match eager output.
    def test_sequential_intermediary_types(self):
        class A(torch.nn.Module):
            def forward(self, x):
                return x + 3

        class B(torch.nn.Module):
            def forward(self, x):
                return {"1": x}

        class C(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Sequential(A(), B())

            def forward(self, x):
                return self.foo(x)

        self.checkModule(C(), (torch.tensor(1),))

    # The ellipsis tests index with `Ellipsis` (spelled as a name) or `...` in
    # the middle, start, or end of a subscript. checkScript compares the
    # resulting size lists between eager and script.
    def test_ellipsis_const_mid(self):
        def ellipsize(x):
            # type: (Tensor) -> List[int]
            return x[2, Ellipsis, 0:4, 4:8].size()

        dummy = torch.zeros(8, 8, 8, 8, 8)
        self.checkScript(ellipsize, (dummy,), optimize=True)

    # Ellipsis in the middle, mixed with integer selects and a slice.
    def test_ellipsis_const_mid_select(self):
        def ellipsize(x):
            # type: (Tensor) -> List[int]
            return x[2, Ellipsis, 4, 4, 4:8, 2].size()

        dummy = torch.zeros(8, 8, 8, 8, 8, 8, 8)
        self.checkScript(ellipsize, (dummy,), optimize=True)

    # Ellipsis as the leading subscript element.
    def test_ellipsis_const_start(self):
        def ellipsize(x):
            # type: (Tensor) -> List[int]
            return x[Ellipsis, 0:4, 4:8].size()
        dummy = torch.zeros(8, 8, 8, 8, 8)
        self.checkScript(ellipsize, (dummy,), optimize=True)

    # Ellipsis as the trailing subscript element.
    def test_ellipsis_const_end(self):
        def ellipsize(x):
            # type: (Tensor) -> List[int]
            return x[0:4, 2, Ellipsis].size()
        dummy = torch.zeros(8, 8, 8, 8, 8)
        self.checkScript(ellipsize, (dummy,), optimize=True)

    # Same as test_ellipsis_const_mid, spelled with `...`.
    def test_ellipsis_mid(self):
        def ellipsize(x):
            # type: (Tensor) -> List[int]
            return x[2, ..., 0:4, 4:8].size()

        dummy = torch.zeros(8, 8, 8, 8, 8)
        self.checkScript(ellipsize, (dummy,), optimize=True)

    # Same as test_ellipsis_const_mid_select, spelled with `...`.
    def test_ellipsis_mid_select(self):
        def ellipsize(x):
            # type: (Tensor) -> List[int]
            return x[2, ..., 4, 4, 4:8, 2].size()

        dummy = torch.zeros(8, 8, 8, 8, 8, 8, 8)
        self.checkScript(ellipsize, (dummy,), optimize=True)

    # Same as test_ellipsis_const_start, spelled with `...`.
    def test_ellipsis_start(self):
        def ellipsize(x):
            # type: (Tensor) -> List[int]
            return x[..., 0:4, 4:8].size()
        dummy = torch.zeros(8, 8, 8, 8, 8)
        self.checkScript(ellipsize, (dummy,), optimize=True)

    # Same as test_ellipsis_const_end, spelled with `...`.
    def test_ellipsis_end(self):
        def ellipsize(x):
            # type: (Tensor) -> List[int]
            return x[0:4, 2, ...].size()
        dummy = torch.zeros(8, 8, 8, 8, 8)
        self.checkScript(ellipsize, (dummy,), optimize=True)

    # Seeding inside a scripted function must reproduce the eager result and
    # leave aten::manual_seed in the graph. freeze_rng_state presumably
    # restores the global RNG state afterwards (TODO confirm).
    def test_torch_manual_seed(self):
        with freeze_rng_state():
            def test():
                torch.manual_seed(2)
                return torch.rand(1)

            script = torch.jit.script(test)
            self.assertEqual(test(), script())
            graph = script.graph_for()
            FileCheck().check("aten::manual_seed").run(graph)

    # Complete shape analysis must infer the concrete (2, 4) Float output type
    # of index_select called with keyword arguments.
    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_index_select_shape_prop(self):

        @torch.jit.script
        def foo(x, y):
            return torch.index_select(x, index=y, dim=1)

        a = torch.zeros(2, 2)
        b = torch.zeros(4, dtype=torch.long)
        torch._C._jit_pass_complete_shape_analysis(foo.graph, (a, b), False)
        FileCheck().check("Float(2, 4, strides=[4, 1], requires_grad=0, device=cpu)").run(str(foo.graph))

    def test_shape_analysis_loop(self):
        def foo(a, b, x):
            c = a
            # on the first iteration of the loop it appears that
            # c should have an expand to the size of b
            # but on the second+ iterations, there is no broadcast and the
            # sizes are different.
            # previously this would cause the compiler to (1) enter an infinite
            # loop trying to compute the shape, and (2) insert invalid
            # broadcasts.
            # this test ensures we don't regress on these issues
            for _ in range(2):
                a = c + b
                c = x
                b = x
            return a

        self.checkScript(foo, (torch.zeros(1), torch.zeros(4), torch.zeros(5)), optimize=False)

    # An int-list parameter (output_size) must accept a positional int, a
    # keyword int, and a keyword list.
    def test_intlist_args(self):
        def func_1(x):
            return torch.nn.functional.adaptive_avg_pool1d(x, 1)

        def func_2(x):
            return torch.nn.functional.adaptive_avg_pool1d(x, output_size=1)

        def func_3(x):
            return torch.nn.functional.adaptive_avg_pool1d(x, output_size=[1])

        x = torch.randn(8, 8, 8)
        self.checkScript(func_1, [x], optimize=True)
        self.checkScript(func_2, [x], optimize=True)
        self.checkScript(func_3, [x], optimize=True)

    # The function is built by _trace with broadcasting example shapes (3) + (1).
    # It must still compute correctly for same-shaped (4) + (4) inputs, i.e. no
    # stale expand is baked in. _trace is assumed to trace with the given
    # example inputs.
    def test_wrong_implicit_expand(self):

        @_trace(torch.zeros(3), torch.zeros(1))
        def foo(a, b):
            return a + b

        a = torch.rand(4)
        b = torch.rand(4)
        self.assertEqual(a + b, foo(a, b))

    # Each malformed builtin call must be rejected at script time with the
    # matching argument/schema error message.
    def test_builtin_args_fails(self):

        with self.assertRaisesRegex(RuntimeError, 'Argument self not provided'):
            @torch.jit.script
            def f1(a):
                torch.sum(foo=4)

        with self.assertRaisesRegex(RuntimeError, 'specified twice'):
            @torch.jit.script
            def f2(a):
                torch.sum(a, self=a)

        with self.assertRaisesRegex(RuntimeError, 'not provided'):
            @torch.jit.script
            def f3(a):
                torch.sum(dim=4)

        with self.assertRaisesRegex(RuntimeError, 'for argument \'tensors\' but instead found type \'Tensor'):
            @torch.jit.script
            def f4(a):
                torch.cat(a)

        with self.assertRaisesRegex(RuntimeError, r'argument \'tensors\' but instead found type \'List\[int\]'):
            @torch.jit.script
            def f5(a):
                torch.cat([3])

        with self.assertRaisesRegex(RuntimeError, r'Expected a value of'
                                                  r' type \'List\[int\]\' for argument'
                                                  r' \'size\' but instead found type '
                                                  r'\'List\[Union\[List\[int\], int\]\]'):
            @torch.jit.script
            def f6(a):
                a.expand(size=[3, [4]])

    # Valid builtin calls: a defaulted argument, keywords out of order, and a
    # dim argument that is not a literal constant.
    def test_builtin_args(self):

        def t0(a):
            # default arg dim
            return torch.cat([a, a])

        self.checkScript(t0, (torch.zeros(1, 1),))

        def t1(a):
            # keywords out of order
            return torch.cat(dim=1, tensors=[a, a])

        self.checkScript(t1, (torch.zeros(1, 1, 2),))

        def t2(a):
            # mix const/non-const attributes
            if 1 == 1:
                b = 1
            else:
                b = 0
            return torch.sum(a, dim=b, keepdim=False)

        self.checkScript(t2, (torch.zeros(1, 1, 2),))

    # Schema of a function declared with Python 3 style annotations, compared
    # against the stored expected output.
    def test_parser_type_annotations(self):
        cu = torch.jit.CompilationUnit('''
        def foo(x : Tensor, y : Tuple[Tuple[Tensor, Tensor], Tensor]) -> Tuple[Tensor, Tensor]:
            return x, x
        ''')

        self.assertExpected(str(cu.foo.schema))

    # Same schema check, with the signature given as a `# type:` comment.
    def test_parser_type_annotations_comment(self):
        cu = torch.jit.CompilationUnit('''
        def foo(x, y):
            # type: (Tensor, Tuple[Tuple[Tensor, Tensor], Tensor]) -> Tuple[Tensor, Tensor]
            return x, x
        ''')

        self.assertExpected(str(cu.foo.schema))

    # The type-annotation parser rejects unknown names, subscripts of
    # non-identifiers, subscripted Tensor, and arithmetic expressions.
    def test_parser_type_annotations_unknown_type(self):
        with self.assertRaisesRegex(RuntimeError, "Unknown type name 'Foo'"):
            cu = torch.jit.CompilationUnit('''
            def foo(x : Tensor, y : Tuple[Tuple[Foo, Tensor], Tensor]) -> Tuple[Tensor, Tensor]:
                return x, x
            ''')

    def test_parser_type_annotations_subscript_non_ident(self):
        with self.assertRaisesRegex(RuntimeError, r'Subscripted type must be a type identifier'):
            cu = torch.jit.CompilationUnit('''
            def foo(x : Tensor, y : Tuple[Tensor, Tensor][Tensor]) -> Tuple[Tensor, Tensor]:
                return x, x
            ''')

    def test_parser_type_annotations_subscript_tensor(self):
        with self.assertRaisesRegex(RuntimeError, r'Unknown type constructor Tensor'):
            cu = torch.jit.CompilationUnit('''
            def foo(x : Tensor, y : Tensor[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
                return x, x
            ''')

    def test_parser_type_annotations_incompatible_expression(self):
        with self.assertRaisesRegex(RuntimeError, r'Expression of type \+ cannot be used in a type expression'):
            cu = torch.jit.CompilationUnit('''
            def foo(x : Tensor, y : Tuple[3 + 4, Tensor]) -> Tuple[Tensor, Tensor]:
                return x, x
            ''')

    # Indexing with a literal index and with a computed (0 + 1) index.
    def test_gather_dynamic_index(self):
        def t(x):
            gather1 = x[0]
            idx = 0 + 1
            gather2 = x[idx]
            return gather1 + gather2

        self.checkScript(t, (torch.zeros(3, 2, 3),))

    # @torch.jit.ignore'd methods whose bodies use unscriptable code and return
    # None, with (A) and without (B) an explicit `-> None`, must be callable
    # from a scripted forward.
    def test_torch_ignore_conversion_to_none(self):
        class A(torch.nn.Module):
            @torch.jit.ignore
            def ignored(self, a: int) -> None:
                l: int = len([2 for i in range(a) if i > 2])
                return

            def forward(self) -> int:
                a: int = 4
                b: int = 5
                self.ignored(a)
                return a + b

        class B(torch.nn.Module):
            @torch.jit.ignore
            def ignored(self, a: int):
                l: int = len([2 for i in range(a) if i > 2])
                return

            def forward(self) -> int:
                a: int = 4
                b: int = 5
                self.ignored(a)
                return a + b

        modelA = torch.jit.script(A())
        self.assertEqual(modelA(), 9)

        modelB = torch.jit.script(B())
        self.assertEqual(modelB(), 9)

    def test_addmm_grad(self):
        """ This test checks several things:
            1. An expand node was inserted before the addmm operating on the
               bias term.
            2. The fused form of addmm appears in the ultimate graph that's
               executed.
            3. A sum op was emitted for accumulating gradients along the 0th
               (expanded) dimension of the bias term.
            4. The correct symbolic representation for the backward pass of the
               mm operator was emitted (x.t() -> mm)

            TODO: we should actually check these conditions once we have a way
            to dump the GraphExecutor state. Namely the processed forward graph
            and the backward graph.
        """
        @torch.jit.script
        def addmm_grad_test(b, x, w):
            return torch.addmm(b, x, w)

        # Initialize param and input values
        w_init = torch.rand(2, 5)
        b_init = torch.rand(5)
        x = torch.rand(3, 2)

        # Clone trainable params
        b = b_init.clone()
        b.requires_grad_()
        w = w_init.clone()
        w.requires_grad_()

        # Test symbolic differentiation
        y = addmm_grad_test(b, x, w)
        y.sum().backward()

        # clone params for autograd reference
        b_ref = b_init.clone()
        b_ref.requires_grad_()
        w_ref = w_init.clone()
        w_ref.requires_grad_()
        y_ref = torch.addmm(b_ref, x, w_ref)
        y_ref.sum().backward()

        self.assertEqual(w.grad, w_ref.grad)
        self.assertEqual(b.grad, b_ref.grad)

    # Forward and backward through a scripted eval-mode BatchNorm2d + relu on
    # CUDA must match the eager module: outputs, input grad, and running stats.
    @unittest.skipIf(not RUN_CUDA, "running tests on cuda to verify cudnn fix")
    def test_batch_norm_inference_backward_cuda(self):
        with enable_profiling_mode_for_profiling_tests():
            class MyBatchNorm(torch.nn.Module):
                def __init__(self, num_features, affine, track_running_stats):
                    super().__init__()
                    self.bn = torch.nn.BatchNorm2d(
                        num_features, 1e-5, affine=affine, track_running_stats=track_running_stats).float()

                def forward(self, x: torch.Tensor):
                    o = self.bn(x)
                    o = torch.nn.functional.relu(o)
                    return o

            batch = 4
            c = 2
            hw = 3
            # Initialize param and input values
            x_init = torch.randn(batch, c, hw, hw, dtype=torch.float).cuda()
            grad = torch.randn(batch, c, hw, hw, dtype=torch.float).cuda()

            training = False
            affine = True
            track_running_stats = True

            module = torch.jit.script(MyBatchNorm(c, affine, track_running_stats)).cuda()
            ref_module = MyBatchNorm(c, affine, track_running_stats).cuda()
            module.eval()
            ref_module.eval()

            jit_module = torch.jit.script(module)
            # Give the eager reference the scripted module's parameters/buffers.
            ref_module.load_state_dict(module.state_dict())

            x = x_init.detach().clone()
            x.requires_grad_()
            x_ref = x_init.detach().clone()
            x_ref.requires_grad_()

            # Test symbolic differentiation
            # Run Forward and Backward thrice to trigger autodiff graph
            for i in range(0, 3):
                y = jit_module(x)
                y.backward(grad)
            x.grad.zero_()

            # Reset running stats on both modules so the final comparison starts
            # from identical buffers.
            module.bn.running_mean.zero_()
            module.bn.running_var.fill_(1.0)
            ref_module.bn.running_mean.zero_()
            ref_module.bn.running_var.fill_(1.0)

            # run jitted module
            y = jit_module(x)
            y.backward(grad)
            # reference computation
            y_ref = ref_module(x_ref)
            y_ref.backward(grad)

            self.assertEqual(y_ref, y)
            self.assertEqual(x.grad, x_ref.grad)
            self.assertEqual(module.bn.running_mean, ref_module.bn.running_mean)
            self.assertEqual(module.bn.running_var, ref_module.bn.running_var)

    # torch.zeros with dtype/device/layout keywords, where device comes from a
    # __constants__ attribute, plus a plain tuple-size call.
    def test_zeros(self):
        class M(torch.jit.ScriptModule):
            __constants__ = ['d']

            def __init__(self) -> None:
                super().__init__()
                self.d = torch.device('cpu')

            @torch.jit.script_method
            def create(self):
                return torch.zeros([1, 1, 2], dtype=torch.float, device=self.d, layout=torch.strided)

        r = M().create()
        self.assertEqual(r.dtype, torch.float)
        self.assertEqual(torch.zeros([1, 1, 2], dtype=torch.float), r)

        def fn():
            return torch.zeros((1, 2, 3))

        self.checkScript(fn, ())

    # torch.zeros with sizes passed as varargs.
    def test_vararg_zeros(self):
        def foo():
            return torch.zeros(3, 4, 5, dtype=torch.int)

        self.checkScript(foo, ())

    # Legacy executor only: rand output must have the default dtype and randint
    # output int64, and graph_for must carry the matching tensor types.
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "the original version of test_rand")
    def test_rand(self):
        def test_rand():
            a = torch.rand([3, 4])
            return a + 1.0 - a

        self.checkScript(test_rand, ())
        fn = torch.jit.script(test_rand)
        out = fn()
        self.assertEqual(out.dtype, torch.get_default_dtype())
        g = fn.graph_for()
        # Testing shape analysis correctly setting type
        if GRAPH_EXECUTOR != ProfilingMode.SIMPLE:
            FileCheck().check("Double(*, *, requires_grad=0, device=cpu)") \
                       .check_not("Float(*, *, requires_grad=0, device=cpu)").run(g)

        @torch.jit.script
        def randint():
            return torch.randint(0, 5, [1, 2])
        out = randint()
        self.assertEqual(out.dtype, torch.int64)
        if GRAPH_EXECUTOR != ProfilingMode.SIMPLE:
            FileCheck().check("Long(*, *, requires_grad=0, device=cpu)") \
                       .check_not("Float(*, *, requires_grad=0, device=cpu)") \
                       .check_not("Double(*, *, requires_grad=0, device=cpu)") \
                       .run(randint.graph_for())

    # Over several profiling runs, the scripted function must agree with eager
    # on whether a complex-valued output tracks gradients (grad_fn is None or not).
    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "skip if profiling isn't enabled")
    def test_autodiff_complex(self):
        def foo(x: torch.Tensor, y: torch.Tensor, W: torch.Tensor):
            return torch.exp(torch.mm(torch.complex(x, y), W.cfloat()))

        @torch.jit.script
        def jitted_foo(x: torch.Tensor, y: torch.Tensor, W: torch.Tensor):
            return torch.exp(torch.mm(torch.complex(x, y), W.cfloat()))

        x = torch.randn(128, 16, dtype=torch.float32, device='cuda:0')
        y = torch.randn(128, 16, dtype=torch.float32, device='cuda:0')
        W = torch.randn(16, 1, dtype=torch.float32, device='cuda:0', requires_grad=True)
        W.data /= 4

        with enable_profiling_mode_for_profiling_tests():
            for i in range(4):
                self.assertTrue((foo(x, y, W).grad_fn is None) == (jitted_foo(x, y, W).grad_fn is None))

    # F.linear gradients through the scripted function must match eager, first
    # with a bias and then with bias=None.
    def test_linear_grad(self):
        with enable_profiling_mode_for_profiling_tests():
            def t(x: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor]):
                return torch.nn.functional.linear(x, w, b)

            x_init = torch.randn(4, 2)
            w_init = torch.randn(3, 2)
            b_init = torch.randn(3)
            grad = torch.randn(4, 3)

            with disable_autodiff_subgraph_inlining():
                # script module
                jit_t = torch.jit.script(t)

                x = x_init.detach().requires_grad_()
                w = w_init.detach().requires_grad_()
                b = b_init.detach().requires_grad_()
                x_ref = x_init.detach().requires_grad_()
                w_ref = w_init.detach().requires_grad_()
                b_ref = b_init.detach().requires_grad_()

                # profiling/optimization runs
                jit_o = jit_t(x, w, b)
                jit_o.backward(grad)
                jit_o = jit_t(x, w, b)
                jit_o.backward(grad)

                # Discard grads accumulated by the warm-up runs before comparing.
                x.grad.zero_()
                w.grad.zero_()
                b.grad.zero_()
                jit_o = jit_t(x, w, b)
                jit_o.backward(grad)
                o = t(x_ref, w_ref, b_ref)
                o.backward(grad)

                self.assertEqual(jit_o, o)
                self.assertEqual(x.grad, x_ref.grad)
                self.assertEqual(w.grad, w_ref.grad)
                self.assertEqual(b.grad, b_ref.grad)

                x.grad.zero_()
                w.grad.zero_()
                x_ref.grad.zero_()
                w_ref.grad.zero_()
                jit_o = jit_t(x, w, None)
                jit_o.backward(grad)
                o = t(x_ref, w_ref, None)
                o.backward(grad)

                self.assertEqual(jit_o, o)
                self.assertEqual(x.grad, x_ref.grad)
                self.assertEqual(w.grad, w_ref.grad)

    # Profiling executor: after one profiled run, the last optimized graph must
    # carry the concrete Float (rand) and Long (randint) types and shapes.
    @skipIfTorchDynamo("TorchDynamo doesn't support profile")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "the profiling version of test_rand")
    def test_rand_profiling(self):
        def test_rand():
            a = torch.rand([3, 4])
            return a + 1.0 - a

        # Testing shape analysis correctly setting type
        with enable_profiling_mode_for_profiling_tests():
            with num_profiled_runs(1):
                fn = torch.jit.script(test_rand)
                out = fn()
                graph_str = torch.jit.last_executed_optimized_graph()
                self.assertEqual(out.dtype, torch.float)
                FileCheck().check("Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)") \
                           .check_not("Double(3, 4, strides=[4, 1], requires_grad=0, device=cpu)").run(graph_str)

            # fn = self.checkScript(test_rand, ())
            # out = fn()
            # self.assertEqual(out.dtype, torch.float)

        @torch.jit.script
        def randint():
            return torch.randint(0, 5, [1, 2])

        with enable_profiling_mode_for_profiling_tests():
            with num_profiled_runs(1):
                out = randint()
                graph_str = torch.jit.last_executed_optimized_graph()
                self.assertEqual(out.dtype, torch.int64)
                FileCheck().check("profiled_type=Long(1, 2, strides=[2, 1], requires_grad=0, device=cpu)").run(graph_str)

    # The erase_number_types pass must remove int-typed prim::Constant nodes
    # that exist in the freshly scripted graph.
    def test_erase_number_types(self):
        def func(a):
            b = 7 + 1 + 3
            c = a + b
            c += b
            return c

        graph = torch.jit.script(func).graph
        FileCheck().check("int = prim::Constant").check("aten::add_").run(str(graph))
        self.run_pass("erase_number_types", graph)
        FileCheck().check_not("int = prim::Constant").run(str(graph))

    def test_refine_tuple_types(self):
        # TupleConstruct output type is not correct here.
        graph_str = """
        graph(%a : Float(123), %b : Float(4, 5, 6)):
          %c : (Tensor, Tensor) = prim::TupleConstruct(%a, %b)
          return (%c)
        """
        graph = parse_ir(graph_str)
        torch._C._jit_pass_refine_tuple_types(graph)

        # After the pass, the output type should've been updated.
        self.assertTrue('(Float(123), Float(4, 5, 6))' in str(graph.findNode('prim::TupleConstruct').output()))

    # TODO(henrytu): Add test for RefineTypes for NamedTuple when it's supported by IR parser.
11097 11098 def test_remove_dropout(self): 11099 weight_0_shape = (20, 5) 11100 weight_1_shape = (20, 20) 11101 input_shape = (10, 5) 11102 11103 class M(torch.nn.Module): 11104 def __init__(self) -> None: 11105 super().__init__() 11106 self.weight_0 = torch.nn.Parameter(torch.rand(weight_0_shape)) 11107 self.weight_1 = torch.nn.Parameter(torch.rand(weight_1_shape)) 11108 11109 def forward(self, x): 11110 o = F.linear(x, self.weight_0) 11111 o = F.dropout(o, training=self.training) 11112 o = F.linear(o, self.weight_1) 11113 return o 11114 11115 data = torch.rand(input_shape) 11116 m = M() 11117 m = torch.jit.script(m) 11118 with self.assertRaisesRegex(RuntimeError, r'Dropout removal module in training mode is not yet supported'): 11119 torch._C._jit_pass_remove_dropout(m._c) 11120 m.eval() 11121 ref_res = m(data) 11122 # Need to inline otherwise we see instances of Function. 11123 # We would have to use torch.linear/dropout to get around it otherwise. 11124 from torch.jit._recursive import wrap_cpp_module 11125 m = wrap_cpp_module(torch._C._freeze_module(m._c)) 11126 torch._C._jit_pass_remove_dropout(m._c) 11127 res = m(data) 11128 FileCheck().check_not("aten::dropout").run(str(m.graph)) 11129 torch.testing.assert_close(ref_res, res, rtol=1e-2, atol=1e-3) 11130 11131 def test_unfold_zero_dim(self): 11132 def fn(x): 11133 return x.unfold(0, 1, 1) 11134 11135 graph = torch.jit.script(fn).graph 11136 torch._C._jit_pass_complete_shape_analysis(graph, (torch.tensor(0.39),), False) 11137 out_dims = fn(torch.tensor(0.3923)).ndim 11138 self.assertEqual(graph.findNode("aten::unfold").output().type().dim(), out_dims) 11139 11140 def test_mm_batching(self): 11141 11142 with enable_profiling_mode_for_profiling_tests(): 11143 lstm_cell = torch.jit.script(LSTMCellS) 11144 11145 def lstm(x, hx, cx, w_ih, w_hh, b_ih, b_hh): 11146 for i in range(x.size(0)): 11147 hx, cx = lstm_cell(x[i], hx, cx, w_ih, w_hh, b_ih, b_hh) 11148 return hx 11149 11150 slstm = torch.jit.script(lstm) 
11151 11152 inputs = get_lstm_inputs('cpu', training=True, seq_length=10) 11153 slstm(*inputs, profile_and_replay=True).sum().backward(retain_graph=True) 11154 if GRAPH_EXECUTOR == ProfilingMode.PROFILING: 11155 slstm(*inputs, profile_and_replay=True).sum().backward() 11156 11157 fw_graph = slstm.graph_for(*inputs) 11158 if GRAPH_EXECUTOR == ProfilingMode.LEGACY: 11159 bw_graph = backward_graph(slstm, diff_graph_idx=0) 11160 self.assertTrue('prim::MMBatchSide' in str(fw_graph)) 11161 self.assertTrue('prim::MMTreeReduce' in str(bw_graph)) 11162 11163 sout = slstm(*inputs) 11164 out = lstm(*inputs) 11165 self.assertEqual(sout, out) 11166 self.assertEqual(torch.autograd.grad(sout.sum(), inputs), 11167 torch.autograd.grad(out.sum(), inputs)) 11168 11169 def test_loop_unrolling(self): 11170 def fn(x): 11171 y = 0 11172 for i in range(int(x)): 11173 y -= i 11174 return y 11175 11176 graph = torch.jit.script(fn).graph 11177 self.run_pass('loop_unrolling', graph) 11178 unroll_factor = 8 11179 FileCheck().check("prim::Loop").check_count("aten::sub", unroll_factor) \ 11180 .check("prim::Loop").check("aten::sub").run(str(graph)) 11181 self.checkScript(fn, (torch.tensor(10),)) 11182 11183 def test_loop_unrolling_const(self): 11184 def fn(): 11185 y = 0 11186 for _ in range(10): 11187 y -= 1 11188 return y 11189 11190 def fn2(): 11191 y = 0 11192 for i in range(10): 11193 y -= i 11194 return y 11195 11196 def check(fn, name): 11197 graph = torch.jit.script(fn).graph 11198 self.run_pass('loop_unrolling', graph) 11199 # entirely unrolled 11200 FileCheck().check_not("prim::Loop'").run(str(graph)) 11201 self.checkScript(fn, ()) 11202 11203 check(fn, 'add_const') 11204 check(fn2, 'add_iter') 11205 11206 def test_loop_unrolling_nested(self): 11207 def fn(x): 11208 y = 0 11209 for _ in range(10): 11210 for j in range(int(x)): 11211 y -= j 11212 return y 11213 11214 graph = torch.jit.script(fn).graph 11215 self.run_pass('loop_unrolling', graph) 11216 # inner loop with 8 subs followed 
by loop epilogue 11217 unroll_factor = 8 11218 FileCheck().check("prim::Loop").check("prim::Loop").check_count('aten::sub', unroll_factor) \ 11219 .check("prim::Loop").check("aten::sub").run(str(graph)) 11220 self.checkScript(fn, (torch.tensor(10),)) 11221 11222 def test_loop_unroll_unused_counter(self): 11223 def fn(x): 11224 y = 0 11225 for _ in range(int(x)): 11226 y -= 1 11227 return y 11228 11229 graph = torch.jit.script(fn).graph 11230 self.run_pass('loop_unrolling', graph) 11231 FileCheck().check("prim::Loop").check_not("aten::add").check("return") \ 11232 .run(str(graph)) 11233 11234 def test_loop_unroll_negative(self): 11235 def fn(x): 11236 y = 0 11237 for _ in range(int(x)): 11238 y += 1 11239 return y 11240 11241 self.checkScript(fn, (torch.tensor(-20),)) 11242 self.checkScript(fn, (torch.tensor(-2),)) 11243 self.checkScript(fn, (torch.tensor(-1),)) 11244 self.checkScript(fn, (torch.tensor(0),)) 11245 self.checkScript(fn, (torch.tensor(1),)) 11246 self.checkScript(fn, (torch.tensor(2),)) 11247 11248 def test_where(self): 11249 def fn(x, y): 11250 return torch.where(x > 0.0, x, y) 11251 11252 self.checkScript(fn, (torch.randn(3, 2, dtype=torch.float), torch.ones(3, 2, dtype=torch.float))) 11253 11254 def test_where_method(self): 11255 def fn(x, y): 11256 return x.where(x > 0.0, y) 11257 11258 self.checkScript(fn, (torch.randn(3, 2, dtype=torch.float), torch.ones(3, 2, dtype=torch.float))) 11259 11260 def test_union_to_number(self): 11261 @torch.jit.script 11262 def fn(x: Union[int, complex, float], y: Union[int, complex, float]): 11263 return x + y 11264 FileCheck().check(": Scalar):").run(fn.graph) 11265 11266 def test_reassign_module_lhs(self): 11267 with self.assertRaisesRegex(RuntimeError, 'Cannot re-assign \'self\''): 11268 class ReassignSelfLHS(torch.jit.ScriptModule): 11269 @torch.jit.script_method 11270 def forward(self, x): 11271 for _ in range(20): 11272 self = x 11273 return self 11274 11275 ReassignSelfLHS() 11276 11277 def 
test_reassign_module_rhs(self): 11278 with self.assertRaisesRegex(RuntimeError, 'Cannot re-assign \'x\' to a value of type module'): 11279 class ReassignSelfRHS(torch.jit.ScriptModule): 11280 @torch.jit.script_method 11281 def forward(self, x): 11282 for _ in range(20): 11283 x = self 11284 return self 11285 11286 ReassignSelfRHS() 11287 11288 def test_unknown_builtin(self): 11289 with self.assertRaisesRegex(RuntimeError, 'object has no attribute or method'): 11290 @torch.jit.script 11291 def unknown_builtin(x): 11292 return x.splork(3) 11293 11294 def test_return_tuple(self): 11295 def return_tuple(x): 11296 a = (x, x) 11297 return a, x 11298 self.checkScript(return_tuple, (torch.rand(4),)) 11299 11300 def test_add_tuple_optional(self): 11301 def foo(input: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]) -> Optional[torch.Tensor]: 11302 changed_input = input[0] + 1 11303 value: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]] = (changed_input,) + input[1:] 11304 return value[2] 11305 inp: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]] = (torch.rand(4), None, None) 11306 self.checkScript(foo, (inp,)) 11307 11308 def test_add_tuple_non_optional(self): 11309 def foo(input: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor: 11310 changed_input = input[0] + 1 11311 value: Tuple[torch.Tensor, torch.Tensor, torch.Tensor] = (changed_input,) + input[1:] 11312 return torch.sum(value[2]) + 4 11313 inp: Tuple[torch.Tensor, torch.Tensor, torch.Tensor] = (torch.rand(4), torch.rand(4), torch.rand(4)) 11314 self.checkScript(foo, (inp,)) 11315 11316 def test_add_tuple_different_types(self): 11317 def foo(a: Tuple[int, float], b: Tuple[int]) -> int: 11318 c: Tuple[int, float, int] = a + b 11319 d: Tuple[int, float, int, int] = c + b 11320 return d[3] + 1 11321 a = (1, 2.0) 11322 b = (3,) 11323 self.checkScript(foo, (a, b)) 11324 11325 def test_add_tuple_same_types(self): 11326 def foo(a: Tuple[int, int], b: 
Tuple[int, int, int]) -> int: 11327 c: Tuple[int, int, int, int, int] = a + b 11328 d: Tuple[int, int, int, int, int, int, int, int] = c + b 11329 return d[6] - 2 11330 a = (1, 2) 11331 b = (3, 4, 5) 11332 self.checkScript(foo, (a, b)) 11333 11334 def test_method_no_self(self): 11335 with self.assertRaisesRegex(RuntimeError, 'methods must have a self argument'): 11336 class MethodNoSelf(torch.jit.ScriptModule): 11337 @torch.jit.script_method # noqa: B902 11338 def forward(): # noqa: B902 11339 return torch.zeros(3, 4) 11340 11341 MethodNoSelf() 11342 11343 def test_return_stmt_not_at_end(self): 11344 def return_stmt(x): 11345 if bool(x > 3): 11346 return x + 3 11347 else: 11348 return x 11349 self.checkScript(return_stmt, (torch.rand(1),)) 11350 11351 def test_for_in_range(self): 11352 def fn(): 11353 c = 0 11354 for i in range(100): 11355 c += i 11356 return c 11357 self.checkScript(fn, ()) 11358 11359 def test_for_in_range_dynamic(self): 11360 def fn(): 11361 c = 0 11362 for i in range(100): 11363 acc = 0 11364 for j in range(i): 11365 acc += j 11366 c += acc 11367 return c 11368 self.checkScript(fn, (), optimize=False) 11369 11370 def test_for_in_range_ast(self): 11371 def test_script_for_in_range_ast(): 11372 c = 0 11373 for i in range(100): 11374 acc = 0 11375 for j in range(i): 11376 acc += j 11377 c += acc 11378 return c 11379 11380 self.checkScript(test_script_for_in_range_ast, ()) 11381 11382 def test_for_in_range_if_ast(self): 11383 @torch.jit.script 11384 def test_script_for_in_range_if_ast(x): 11385 output = x 11386 for i in range(20): 11387 if i == 0: 11388 output = x.unsqueeze(0) 11389 else: 11390 output = torch.cat((output, x.unsqueeze(0)), dim=0) 11391 return output 11392 inputs = self._make_scalar_vars([0], torch.int64) 11393 11394 self.assertEqual(test_script_for_in_range_if_ast(*inputs).shape[0], 20) 11395 11396 def test_for_in_range_start_end(self): 11397 def fn(): 11398 x = 0 11399 for i in range(7, 100): 11400 x += i 11401 return x 11402 
self.checkScript(fn, ()) 11403 11404 def test_for_in_range_start_end_step(self): 11405 def fn(start, end, step): 11406 # type: (int, int, int) -> int 11407 x = 0 11408 for i in range(start, end, step): 11409 x += i 11410 return x 11411 11412 self.checkScript(fn, (7, 100, 7)) 11413 self.checkScript(fn, (7, 100, -7)) 11414 self.checkScript(fn, (2, -11, -3)) 11415 self.checkScript(fn, (2, -11, 3)) 11416 self.checkScript(fn, (2, 10, 3)) 11417 self.checkScript(fn, (-2, -10, -10)) 11418 11419 def test_for_in_range_zero_step(self): 11420 @torch.jit.script 11421 def fn(): 11422 x = 0 11423 for i in range(2, -11, 0): 11424 x += i 11425 return x 11426 11427 with self.assertRaisesRegex(RuntimeError, "must not be zero"): 11428 fn() 11429 11430 def test_range_args(self): 11431 with self.assertRaisesRegex(RuntimeError, r'range expected at least 1 arguments, got 0'): 11432 @torch.jit.script 11433 def range_no_arg(x): 11434 for _ in range(): 11435 x += 1 11436 return x 11437 with self.assertRaisesRegex(RuntimeError, r'found float'): 11438 @torch.jit.script 11439 def range_non_float(): 11440 for i in range(.5): 11441 print(i) 11442 11443 def test_parse_empty_tuple_annotation(self): 11444 cu = torch.jit.CompilationUnit(''' 11445 def foo(x : Tuple[()]) -> Tuple[()]: 11446 return x 11447 ''') 11448 11449 foo_code = cu.find_function('foo').code 11450 FileCheck().check("Tuple[()]").check("Tuple[()]").run(foo_code) 11451 11452 def test_parse_empty_tuple_annotation_element_error(self): 11453 with self.assertRaisesRegex( 11454 RuntimeError, 'Tuple literal in Tuple type annotation must not have any elements'): 11455 cu = torch.jit.CompilationUnit(''' 11456 def foo(x : Tuple[(int,)]) -> Tuple[(int,)]: 11457 return x 11458 ''') 11459 11460 def test_parse_none_type_annotation(self): 11461 cu = torch.jit.CompilationUnit(''' 11462 def foo(x : NoneType) -> NoneType: 11463 return x 11464 ''') 11465 11466 foo_code = cu.find_function('foo').code 11467 FileCheck().check(": NoneType").check("-> 
NoneType").run(foo_code) 11468 11469 def test_empty_tuple_str(self): 11470 empty_tuple_type = torch._C.TupleType([]) 11471 g = {'Tuple' : typing.Tuple} 11472 python_type = eval(empty_tuple_type.annotation_str, g) 11473 assert python_type is typing.Tuple[()] 11474 11475 def test_tuple_str(self): 11476 tuple1_type = torch._C.TupleType([torch._C.StringType.get()]) 11477 self.assertEqual(tuple1_type.annotation_str, "Tuple[str]") 11478 tuple2_type = torch._C.TupleType([torch._C.StringType.get(), torch._C.StringType.get()]) 11479 self.assertEqual(tuple2_type.annotation_str, "Tuple[str, str]") 11480 11481 def test_dict_str(self): 11482 dict_type = torch._C.DictType(torch._C.StringType.get(), torch._C.StringType.get()) 11483 self.assertEqual(dict_type.annotation_str, "Dict[str, str]") 11484 11485 def test_none_type_str(self): 11486 none_type = torch._C.NoneType.get() 11487 g = {'NoneType' : type(None)} 11488 python_type = eval(none_type.annotation_str, g) 11489 assert python_type is type(None) 11490 11491 @skipIfTorchDynamo("TorchDynamo fails with unknown reason") 11492 def test_zip_enumerate_modulelist(self): 11493 class Sub(torch.nn.Module): 11494 def forward(self, thing): 11495 return thing - 2 11496 11497 class Double(torch.nn.Module): 11498 def forward(self, thing): 11499 return thing * 2 11500 11501 # zipping over two 11502 class ZipModLists(torch.nn.Module): 11503 def __init__(self, mods, mods2): 11504 super().__init__() 11505 self.mods = mods 11506 self.mods2 = mods2 11507 11508 def forward(self, x): 11509 iter = 0 11510 for mod1, mod2 in zip(self.mods, self.mods2): 11511 x = mod2(mod1(x)) 11512 iter += 1 11513 return x, iter 11514 11515 class ZipWithValues(torch.nn.Module): 11516 __constants__ = ['tup_larger', 'tup_smaller'] 11517 11518 def __init__(self, mods, mods2): 11519 super().__init__() 11520 self.mods = mods 11521 self.mods2 = mods2 11522 self.tup_larger = list(range(len(mods2) + 1)) 11523 self.tup_smaller = list(range(max(len(mods2) + 1, 1))) 11524 11525 
def forward(self, x): 11526 iter = 0 11527 x2 = x 11528 for val, mod1, mod2 in zip(self.tup_larger, self.mods, self.mods2): 11529 x = mod2(mod1(x)) + val 11530 iter += 1 11531 for val, mod1, mod2 in zip(self.tup_smaller, self.mods, self.mods2): 11532 x2 = mod2(mod1(x2)) + val 11533 iter += 1 11534 return x, iter 11535 11536 mods = nn.ModuleList([Double()]), nn.ModuleList([Double(), Sub(), Sub()]), nn.ModuleList([Sub(), Double()]) 11537 for i in range(len(mods)): 11538 for j in range(len(mods)): 11539 mod = ZipModLists(mods[i], mods[j]) 11540 self.checkModule(mod, (torch.tensor(.5),)) 11541 mod2 = ZipWithValues(mods[i], mods[j]) 11542 self.checkModule(mod2, (torch.tensor(.5),)) 11543 11544 11545 def test_enumerate_modlist_range(self): 11546 class Double(torch.nn.Module): 11547 def forward(self, thing): 11548 return thing * 2 11549 11550 class Mod(torch.nn.Module): 11551 def __init__(self) -> None: 11552 super().__init__() 11553 self.mods = nn.ModuleList([Double(), Double()]) 11554 11555 def forward(self, x): 11556 x2 = x 11557 iter = 0 11558 for val, mod in enumerate(self.mods): 11559 x2 = mod(x2) * val 11560 iter += 1 11561 return iter, x, x2 11562 11563 self.checkModule(Mod(), (torch.tensor(.5),)) 11564 11565 # variable length, modulelist 11566 class Mod2(Mod): 11567 def forward(self, x): 11568 for val, mod in zip(range(int(x)), self.mods): 11569 x = mod(x) * val 11570 return x 11571 11572 with self.assertRaisesRegex(Exception, "that does not have a statically determinable length"): 11573 torch.jit.script(Mod2()) 11574 11575 # modulelist, variable length 11576 class Mod3(Mod): 11577 def forward(self, x): 11578 for val, mod in zip(self.mods, range(int(x))): 11579 x = mod(x) * val 11580 return x 11581 11582 with self.assertRaisesRegex(Exception, "that does not have a statically determinable length"): 11583 torch.jit.script(Mod3()) 11584 11585 def test_for_in_enumerate(self): 11586 def fn(x): 11587 # type: (List[int]) -> int 11588 sum = 0 11589 for (i, v) in 
enumerate(x): 11590 sum += i * v 11591 11592 return sum 11593 11594 self.checkScript(fn, ([1, 2, 3, 4, 5],)) 11595 11596 def fn_enumerate_start_arg(x): 11597 # type: (List[int]) -> int 11598 sum = 0 11599 for (i, v) in enumerate(x, 1): 11600 sum += i * v 11601 11602 return sum 11603 11604 self.checkScript(fn_enumerate_start_arg, ([1, 2, 3, 4, 5],)) 11605 11606 def fn_enumerate_start_kwarg(x): 11607 # type: (List[int]) -> int 11608 sum = 0 11609 for (i, v) in enumerate(x, start=1): 11610 sum += i * v 11611 11612 return sum 11613 11614 self.checkScript(fn_enumerate_start_kwarg, ([1, 2, 3, 4, 5],)) 11615 11616 def fn_nested_enumerate(x): 11617 # type: (List[int]) -> int 11618 sum = 0 11619 for (i, (j, v)) in enumerate(enumerate(x)): 11620 sum += i * j * v 11621 11622 return sum 11623 11624 self.checkScript(fn_nested_enumerate, ([1, 2, 3, 4, 5],)) 11625 11626 with self.assertRaisesRegex(RuntimeError, r'enumerate expected at least 1 arguments, got 0'): 11627 @torch.jit.script 11628 def enumerate_no_arg(x): 11629 # type: (List[int]) -> int 11630 sum = 0 11631 for _ in enumerate(): 11632 sum += 1 11633 11634 return sum 11635 11636 with self.assertRaisesRegex(RuntimeError, r'enumerate expected at most 2 arguments, got 3'): 11637 @torch.jit.script 11638 def enumerate_too_many_args(x): 11639 # type: (List[int]) -> int 11640 sum = 0 11641 for _ in enumerate(x, x, x): 11642 sum += 1 11643 11644 return sum 11645 11646 def test_list_comprehension_modulelist(self): 11647 class Inner(torch.nn.Module): 11648 def forward(self, x): 11649 return x + 10 11650 11651 class M(torch.nn.Module): 11652 def __init__(self, mod_list): 11653 super().__init__() 11654 self.module_list = mod_list 11655 11656 def forward(self, x): 11657 out = torch.jit.annotate(List[Tensor], [mod(x) for mod in self.module_list]) 11658 return out 11659 11660 mod = M(nn.ModuleList([Inner(), Inner()])) 11661 self.checkModule(mod, (torch.tensor(3),)) 11662 11663 mod = M(nn.ModuleList([])) 11664 torch.jit.script(mod) 
11665 11666 class M2(M): 11667 def __init__(self, mod_list): 11668 super().__init__(mod_list) 11669 11670 def forward(self, x): 11671 out = [mod(x) for mod in self.module_list] 11672 return out 11673 11674 mod = M2(nn.ModuleList([Inner(), Inner()])) 11675 self.checkModule(mod, (torch.tensor(3),)) 11676 11677 mod = M2(nn.ModuleList([])) 11678 # defaults to List of Tensor for empty modulelist 11679 self.assertEqual(torch.jit.script(mod)(torch.tensor(.5)), []) 11680 11681 def bad_type_annotation(): 11682 out = torch.jit.annotate(int, [x for x in [1, 2, 3]]) # noqa: C416 11683 return out 11684 11685 with self.assertRaisesRegex(Exception, "Expected an annotation" 11686 " of type List"): 11687 torch.jit.script(bad_type_annotation) 11688 11689 def test_list_comprehension_variable_write(self): 11690 # i in comprehension doesn't write to function scope 11691 def foo(): 11692 i = 1 11693 x = [i if i != 5 else 3 for i in range(7)] # noqa: C416 11694 return i, x 11695 11696 self.assertEqual(foo(), torch.jit.script(foo)()) 11697 11698 def test_for_in_zip(self): 11699 def fn(x, y): 11700 # type: (List[int], List[int]) -> int 11701 sum = 0 11702 for (i, j) in zip(x, y): 11703 sum += i * j 11704 11705 return sum 11706 11707 self.checkScript(fn, ([1, 2, 3, 4, 5], [2, 3, 4, 5, 6])) 11708 11709 def fn_multi_inputs(x, y, z): 11710 # type: (List[int], List[int], List[int]) -> int 11711 sum = 0 11712 for (i, j, k) in zip(x, y, z): 11713 sum += i * j * k 11714 11715 return sum 11716 11717 self.checkScript(fn_multi_inputs, ([1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6])) 11718 11719 def fn_nested_zip(x, y, z): 11720 # type: (List[int], List[int], List[int]) -> int 11721 sum = 0 11722 for (i, (j, k)) in zip(x, zip(y, z)): 11723 sum += i * j * k 11724 11725 return sum 11726 11727 self.checkScript(fn_multi_inputs, ([1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6])) 11728 11729 with self.assertRaisesRegex(RuntimeError, r'zip expected at least 1 arguments, got 0'): 11730 @torch.jit.script 11731 def 
zip_no_arg(x): 11732 # type: (List[int]) -> int 11733 sum = 0 11734 for _ in zip(): 11735 sum += 1 11736 11737 return sum 11738 11739 with self.assertRaisesRegex(RuntimeError, r'too many values to unpack: need 2 but found 3'): 11740 @torch.jit.script 11741 def fn_nested_zip_wrong_target_assign(x, y, z): 11742 # type: (List[int], List[int], List[int]) -> int 11743 sum = 0 11744 for (i, (j, k)) in zip(x, y, z): 11745 sum += i * j * k 11746 11747 return sum 11748 11749 def test_for_in_zip_enumerate(self): 11750 def fn_zip_enumerate(x, y): 11751 # type: (List[int], List[int]) -> int 11752 sum = 0 11753 for (i, (j, v), k) in zip(x, enumerate(y), range(0, 100)): 11754 sum += i * j * v * k 11755 11756 return sum 11757 11758 self.checkScript(fn_zip_enumerate, ([1, 2, 3, 4], [2, 3, 4, 5])) 11759 11760 def fn_enumerate_zip(x, y): 11761 # type: (List[int], List[int]) -> int 11762 sum = 0 11763 for (i, (j, v)) in enumerate(zip(x, y)): 11764 sum += i * j * v 11765 11766 return sum 11767 11768 self.checkScript(fn_enumerate_zip, ([1, 2, 3, 4], [2, 3, 4, 5])) 11769 11770 def test_for_in_tensors(self): 11771 def test_sizes(x): 11772 sumz = 0 11773 for s in x: 11774 sumz += 1 11775 return sumz 11776 self.checkScript(test_sizes, (torch.rand(5, 4, 3, 2, 1),)) 11777 self.checkScript(test_sizes, (torch.rand(777),)) 11778 self.checkScript(test_sizes, (torch.rand(0),)) 11779 11780 def test_for_in_tensors_rank0(self): 11781 with self.assertRaisesRegex(RuntimeError, "of a 0-d tensor"): 11782 @torch.jit.script 11783 def test_sizes(x): 11784 sumz = 0 11785 for s in x: 11786 sumz += 1 11787 return sumz 11788 11789 test_sizes(torch.tensor(1)) 11790 11791 def test_for_in_tensors_fail_scalar(self): 11792 with self.assertRaisesRegex(RuntimeError, "'float' object is not iterable"): 11793 @torch.jit.script 11794 def test_sizes(x): 11795 # type: (float) -> int 11796 sumz = 0 11797 for s in x: 11798 sumz += 1 11799 return sumz 11800 11801 test_sizes(0.0) 11802 11803 def 
test_for_in_tensors_nested(self): 11804 def test_sizes(x): 11805 sumz = 0 11806 for n in x: 11807 for t in n: 11808 sumz += 1 11809 return sumz 11810 11811 self.checkScript(test_sizes, (torch.rand(5, 4, 3, 2, 1),)) 11812 11813 # to avoid defining sum_list in multiple tests 11814 def get_sum_list_fn(self): 11815 def sum_list(a): 11816 # type: (List[int]) -> int 11817 sum = 0 11818 for i in a: 11819 sum += i 11820 11821 return sum 11822 11823 return sum_list 11824 11825 def test_sum_list_diff_elms(self): 11826 self.checkScript(self.get_sum_list_fn(), ([1, 2, 3, 4, 5],)) 11827 11828 def test_sum_list_empty(self): 11829 self.checkScript(self.get_sum_list_fn(), ([],)) 11830 11831 def test_sum_list_one(self): 11832 self.checkScript(self.get_sum_list_fn(), ([1],)) 11833 11834 def test_sum_list_literal(self): 11835 11836 def sum_list(): 11837 # type: () -> int 11838 sum = 0 11839 for i in [1, 2, 3, 4, 5]: 11840 sum += i 11841 11842 return sum 11843 11844 self.checkScript(sum_list, ()) 11845 11846 def test_sum_list_wrong_type(self): 11847 11848 with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"): 11849 @torch.jit.script 11850 def sum_list(a): 11851 # type: (int) -> int 11852 sum = 0 11853 for i in a: # noqa: T484 11854 sum += i 11855 11856 return sum 11857 11858 sum_list(1) 11859 11860 def test_list_iterables(self): 11861 with self.assertRaisesRegex(RuntimeError, 'List of iterables is not supported currently'): 11862 cu = torch.jit.CompilationUnit(''' 11863 def list_iterables(x): 11864 for i, j in [2, 3, 4], [5, 6, 7]: 11865 x += i 11866 x += j 11867 return x 11868 ''') 11869 11870 def test_for_in_string(self): 11871 def test_strings(x): 11872 # type: (str) -> str 11873 reverse = "" 11874 for c in x: 11875 reverse = c + reverse 11876 return reverse 11877 11878 self.checkScript(test_strings, ("hello",)) 11879 self.checkScript(test_strings, ("",)) 11880 11881 def test_list_strings(x): 11882 # type: (List[str]) -> str 11883 result = "" 11884 for sub_str 
in x: 11885 result += sub_str 11886 return result 11887 11888 self.checkScript(test_list_strings, (["hello", "world"],)) 11889 self.checkScript(test_list_strings, (["hello", " ", "world", ""],)) 11890 11891 def test_for_in_dict(self): 11892 def test_dicts(x): 11893 # type: (Dict[str, int]) -> int 11894 sum = 0 11895 for key in x: 11896 sum += x[key] 11897 return sum 11898 11899 self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},)) 11900 11901 def test_dict_keys_values(x): 11902 # type: (Dict[str, int]) -> Tuple[str, int] 11903 key_str = "" 11904 sum = 0 11905 for key in x.keys(): 11906 key_str += key 11907 for val in x.values(): 11908 sum += val 11909 return key_str, sum 11910 11911 self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},)) 11912 11913 def test_for_tuple_unpack(self): 11914 def for_tuple_unpack(x, y): 11915 for i, j in [[3, 4], [5, 6], [7, 8]]: 11916 x += i 11917 y += j 11918 return x, y 11919 11920 self.checkScript(for_tuple_unpack, (torch.tensor(3), torch.tensor(5))) 11921 11922 def nested_tuple_unpack(x, y): 11923 # type: (List[int], List[int]) -> int 11924 sum = 0 11925 for i, (j, k), v in zip(x, enumerate(x), y): 11926 sum += i + j + k + v 11927 return sum 11928 11929 self.checkScript(nested_tuple_unpack, ([1, 3, 5], [2, 4, 6])) 11930 11931 def test_for_tuple_assign(self): 11932 def test_simple_assign(x): 11933 # type: (Tuple[int, float]) -> float 11934 sum = 0.0 11935 for a in x: 11936 sum += float(a) 11937 return sum 11938 11939 self.checkScript(test_simple_assign, ((1, 2.5),)) 11940 11941 def test_tuple_assign(x): 11942 # type: (Tuple[Tuple[int, int], Tuple[int, int]]) -> int 11943 sum = 0 11944 for a in x: 11945 sum += a[0] 11946 sum += a[1] 11947 return sum 11948 11949 self.checkScript(test_tuple_assign, (((1, 2), (4, 7)), )) 11950 11951 def test_single_starred_lhs(self): 11952 with self.assertRaisesRegex(RuntimeError, 'A Starred expression may only appear on the lhs within the presence' 11953 ' of another non-starred expression'): 
11954 cu = torch.jit.CompilationUnit(''' 11955 def single_starred_lhs(x): 11956 a = (x, x, x) 11957 *b, = a 11958 return b 11959 ''') 11960 11961 def test_singleton_tuple_unpack(self): 11962 def foo(a): 11963 b, = (a,) 11964 return b + 1 11965 self.checkScript(foo, (torch.rand(3),)) 11966 11967 def test_tuple_assignments(self): 11968 def var_tuple_assign(x, y): 11969 # type: (Tuple[Tensor, Tensor], Tensor) -> Tensor 11970 (a, b), c = x, y 11971 return a + b + c 11972 11973 tuple_inputs = (torch.randn(1, 4), torch.randn(3, 4)) 11974 self.checkScript(var_tuple_assign, (tuple_inputs, torch.randn(3, 4))) 11975 11976 def nested_tuple_assign(x, y, z): 11977 # type: (int, Tuple[int, Tuple[int, int]], Tuple[int, int]) -> int 11978 a, (b, (c, d)), (e, f) = x, y, z 11979 return a + b + c + d + e + f 11980 11981 self.checkScript(nested_tuple_assign, ((1, (2, (3, 4)), (5, 6)))) 11982 11983 def subscript_tuple_assign(a, x, i): 11984 # type: (List[int], Tensor, int) -> Tuple[int, Tensor, int] 11985 a[i], (x[i], b) = 1, (2, 3) 11986 return a[i] + 1, x + 5, b 11987 11988 self.checkScript(subscript_tuple_assign, ([12, 7, 9, 11], torch.tensor((3, 13, 17)), 0)) 11989 11990 def star_tuple_assign(): 11991 # type: () -> Tuple[int, int, Tuple[int, int], Tuple[int, int]] 11992 a, (b, *c), *d = 1, (2, 3, 4), 5, 6 11993 return a, b, c, d 11994 11995 self.checkScript(star_tuple_assign, ()) 11996 11997 def subscript_tuple_augmented_assign(a): 11998 # type: (Tuple[int, int]) -> Tuple[int, int] 11999 a[0] += 1 12000 return a 12001 12002 with self.assertRaisesRegex(RuntimeError, 'does not support augmented assign'): 12003 scripted_aug_assign = torch.jit.script(subscript_tuple_augmented_assign) 12004 12005 class AttrTupleAssignmentTestClass: 12006 def __init__(self, a: int, b: int): 12007 self.a = a 12008 self.b = b 12009 12010 def set_ab(self, a: int, b: int): 12011 self.a, self.b = (a, b) 12012 12013 def get(self) -> Tuple[int, int]: 12014 return (self.a, self.b) 12015 12016 
make_global(AttrTupleAssignmentTestClass) 12017 12018 @torch.jit.script 12019 def attr_tuple_assignment(o: AttrTupleAssignmentTestClass, a: int, b: int): 12020 o.set_ab(a, b) 12021 return o 12022 12023 o = AttrTupleAssignmentTestClass(1, 2) 12024 self.assertEqual(attr_tuple_assignment(o, 3, 4).get(), (3, 4)) 12025 12026 def test_multiple_assign(self): 12027 def test(): 12028 a = b, c = d, f = (1, 1) 12029 12030 # side effect 12031 ten = torch.tensor(1) 12032 ten1 = ten2 = ten.add_(1) 12033 12034 # ordering 12035 x = 1 12036 y = 3 12037 x, y = y, x + y 12038 12039 return a, b, c, d, f, ten, ten1, ten2, x, y 12040 12041 self.checkScript(test, ()) 12042 12043 def test_multi_reduction(self): 12044 with self.assertRaisesRegex( 12045 RuntimeError, 12046 'augmented assignment can only have one LHS expression'): 12047 cu = torch.jit.CompilationUnit(''' 12048 def multi_reduction(x): 12049 a, b += x 12050 return a, b 12051 ''') 12052 12053 def test_invalid_call_arguments(self): 12054 with self.assertRaisesRegex(RuntimeError, 'but instead found type '): 12055 @torch.jit.script 12056 def invalid_call_arguments(x): 12057 return torch.unsqueeze(3, 4, 5, 6, 7, 8) 12058 12059 def test_invalid_lhs_assignment(self): 12060 with self.assertRaisesRegex(RuntimeError, 'unexpected expression'): 12061 cu = torch.jit.CompilationUnit(''' 12062 def invalid_lhs_assignment(x): 12063 x + 1 = x 12064 return x 12065 ''') 12066 12067 def test_multi_starred_expr_lhs(self): 12068 with self.assertRaisesRegex(RuntimeError, 'Only one starred expression is allowed on the lhs'): 12069 cu = torch.jit.CompilationUnit(''' 12070 def multi_starred_expr_lhs(): 12071 a, *b, *c = [1, 2, 3, 4, 5, 6] 12072 return a 12073 ''') 12074 12075 def test_pack_tuple_into_non_var(self): 12076 with self.assertRaisesRegex(RuntimeError, 'Cannot pack a tuple into a non-variable'): 12077 cu = torch.jit.CompilationUnit(''' 12078 def pack_tuple_into_non_var(x): 12079 a, *1 = (3, 4, 5) 12080 return x 12081 ''') 12082 12083 def 
test_print_kwargs(self): 12084 with self.assertRaisesRegex(RuntimeError, 'print doesn\'t accept any keyword arguments'): 12085 cu = torch.jit.CompilationUnit(''' 12086 def print_kwargs(x): 12087 print(x, flush=True) 12088 return x 12089 ''') 12090 12091 def test_builtin_use_as_value(self): 12092 with self.assertRaisesRegex(RuntimeError, 'builtin cannot be used as a value'): 12093 @torch.jit.script 12094 def builtin_use_as_value(x): 12095 return x.unsqueeze 12096 12097 def test_wrong_use_as_tuple(self): 12098 with self.assertRaisesRegex(RuntimeError, 'cannot be used as a tuple'): 12099 def test_fn(): 12100 return 3 12101 12102 @torch.jit.script 12103 def wrong_use_as_tuple(self): 12104 a, b = test_fn 12105 return a 12106 12107 def test_wrong_attr_lookup(self): 12108 with self.assertRaisesRegex(RuntimeError, 'attribute lookup is not defined on builtin'): 12109 @torch.jit.script 12110 def wrong_attr_lookup(self, x): 12111 a = x.unsqueeze.myattr 12112 return a 12113 12114 def test_wrong_use_as_callable(self): 12115 with self.assertRaisesRegex(RuntimeError, 'cannot call a value'): 12116 @torch.jit.script 12117 def wrong_use_as_callable(x): 12118 return x(3, 4, 5) 12119 12120 def test_python_val_doesnt_have_attr(self): 12121 with self.assertRaisesRegex(RuntimeError, 'object has no attribute abcd'): 12122 12123 @torch.jit.script 12124 def python_val_doesnt_have_attr(): 12125 # this has to be a module otherwise attr lookup would not be 12126 # allowed in the first place 12127 return shutil.abcd 12128 12129 def test_wrong_module_attr_lookup(self): 12130 with self.assertRaisesRegex(RuntimeError, 'python value of type \'type\' cannot be used as a value'): 12131 import io 12132 12133 @torch.jit.script 12134 def wrong_module_attr_lookup(): 12135 return io.BytesIO 12136 12137 def test_wrong_method_call_inputs(self): 12138 with self.assertRaisesRegex(RuntimeError, 'Argument y not provided'): 12139 class SomeModule(torch.jit.ScriptModule): 12140 12141 @torch.jit.script_method 
12142 def foo(self, x, y): 12143 return x 12144 12145 @torch.jit.script_method 12146 def forward(self, x, y): 12147 return self.foo(x) 12148 SomeModule() 12149 12150 def test_single_starred_expr_for_loop(self): 12151 with self.assertRaisesRegex(RuntimeError, 'A Starred expression may only appear'): 12152 cu = torch.jit.CompilationUnit(''' 12153 def test(): 12154 x = 0 12155 for *a in [1, 2, 3]: 12156 x = x + 1 12157 return x 12158 ''') 12159 12160 def test_call_ge(self): 12161 with self.assertRaisesRegex(RuntimeError, 'Expected at most 1 arguments but found 3'): 12162 @_trace(torch.zeros(1, 2, 3)) 12163 def foo(x): 12164 return x 12165 12166 @torch.jit.script 12167 def test_fn(): 12168 return foo(torch.full([1], 1), torch.full([1], 2), torch.full([1], 3)) 12169 12170 def test_wrong_return_type(self): 12171 with self.assertRaisesRegex(RuntimeError, 'but instead got value of type tuple'): 12172 @torch.jit.ignore 12173 def somefunc(): 12174 # type: () -> Tuple[Tuple[Tensor, Tensor]] 12175 return torch.zeros(3, 4), torch.zeros(4, 5) # noqa: T484 12176 12177 @torch.jit.script 12178 def wrong_return_type(): 12179 return somefunc() 12180 wrong_return_type() 12181 12182 # Tests for calling between different front-end modes 12183 def test_call_python_fn_from_tracing_fn(self): 12184 def python_fn(x): 12185 return torch.neg(x) 12186 12187 @_trace(torch.rand(3, 4)) 12188 def traced_fn(x): 12189 return python_fn(x) + 1 12190 12191 # The neg op in the python function should be properly inlined to the 12192 # graph 12193 FileCheck().check("aten::neg").run(str(traced_fn.graph)) 12194 12195 def test_call_python_mod_from_tracing_fn(self): 12196 class PythonMod(torch.nn.Module): 12197 def __init__(self) -> None: 12198 super().__init__() 12199 self.param = torch.nn.Parameter(torch.rand(4, 3), requires_grad=False) 12200 12201 def forward(self, x): 12202 return torch.mm(x, self.param) 12203 12204 pm = PythonMod() 12205 12206 @_trace(torch.rand(3, 4)) 12207 def traced_fn(x): 12208 return 
pm(x) + 1.0 12209 12210 # Note: the parameter self.param from the Python module is inlined 12211 # into the graph 12212 self.assertTrue(len(list(traced_fn.graph.inputs())) == 1) 12213 FileCheck().check("aten::mm").check("aten::add").run(str(traced_fn.graph)) 12214 12215 @_tmp_donotuse_dont_inline_everything 12216 def test_call_traced_fn_from_tracing_fn(self): 12217 @_trace(torch.rand(3, 4)) 12218 def traced_fn1(x): 12219 return torch.neg(x) 12220 12221 @_trace(torch.rand(3, 4)) 12222 def traced_fn(x): 12223 return traced_fn1(x) + 1 12224 12225 FileCheck().check("traced_fn").check("prim::CallFunction").check("aten::add") \ 12226 .run(str(traced_fn.graph)) 12227 12228 @unittest.skip("error in first class mode") 12229 def test_call_traced_mod_from_tracing_fn(self): 12230 class TracedModule(torch.nn.Module): 12231 def __init__(self) -> None: 12232 super().__init__() 12233 self.param = torch.nn.Parameter(torch.rand(4, 3), requires_grad=False) 12234 12235 def forward(self, x): 12236 return torch.mm(x, self.param) 12237 12238 tm = torch.jit.trace(TracedModule(), torch.rand(3, 4)) 12239 12240 with self.assertRaisesRegex(RuntimeError, "must be registered as submodules"): 12241 @_trace(torch.rand(3, 4)) 12242 def traced_fn(x): 12243 return tm(x) + 1.0 12244 12245 @_tmp_donotuse_dont_inline_everything 12246 def test_call_script_fn_from_tracing_fn(self): 12247 @torch.jit.script 12248 def script_fn(x): 12249 return torch.neg(x) 12250 12251 @_trace(torch.rand(3, 4)) 12252 def traced_fn(x): 12253 return script_fn(x) + 1 12254 12255 FileCheck().check("prim::CallFunction").check("aten::add").run(str(traced_fn.graph)) 12256 12257 @unittest.skip("error in first class mode") 12258 def test_call_script_mod_from_tracing_fn(self): 12259 with self.assertRaisesRegex(RuntimeError, "must be registered as submodules"): 12260 class ScriptMod(torch.jit.ScriptModule): 12261 def __init__(self) -> None: 12262 super().__init__() 12263 self.param = torch.nn.Parameter(torch.rand(3, 4), 
requires_grad=False) 12264 12265 @torch.jit.script_method 12266 def forward(self, x): 12267 for _i in range(4): 12268 x += self.param 12269 return x 12270 12271 sm = ScriptMod() 12272 12273 @_trace(torch.rand(3, 4)) 12274 def traced_fn(x): 12275 return sm(x) + 1.0 12276 12277 12278 def test_call_python_fn_from_traced_module(self): 12279 def python_fn(x): 12280 return torch.neg(x) 12281 12282 class TracedModule(torch.nn.Module): 12283 def __init__(self) -> None: 12284 super().__init__() 12285 self.param = torch.nn.Parameter(torch.rand(4, 3)) 12286 12287 def forward(self, x): 12288 return torch.mm(python_fn(x), self.param) 12289 12290 tm = torch.jit.trace(TracedModule(), torch.rand(3, 4)) 12291 12292 # Note: parameter self.param from the traced module should appear as 12293 # an input to the graph and the neg op from the Python function should 12294 # be properly inlined 12295 self.assertTrue(len(list(tm.graph.inputs())) == 2) 12296 FileCheck().check("aten::neg").check("aten::mm").run(str(tm.graph)) 12297 12298 def test_call_python_mod_from_traced_module(self): 12299 class PythonModule(torch.nn.Module): 12300 def __init__(self) -> None: 12301 super().__init__() 12302 self.param = torch.nn.Parameter(torch.rand(5, 7)) 12303 12304 def forward(self, x): 12305 return torch.mm(x, self.param) 12306 12307 class TracedModule(torch.nn.Module): 12308 def __init__(self) -> None: 12309 super().__init__() 12310 self.param = torch.nn.Parameter(torch.rand(4, 5)) 12311 self.mod = PythonModule() 12312 12313 def forward(self, x): 12314 return self.mod(torch.mm(x, self.param)) + 1.0 12315 12316 tm = torch.jit.trace(TracedModule(), torch.rand(3, 4)) 12317 12318 FileCheck().check_not("value=<Tensor>").check("aten::mm")\ 12319 .check('prim::CallMethod[name="forward"]').check("aten::add") \ 12320 .run(str(tm.graph)) 12321 FileCheck().check("aten::mm").run(str(tm.mod.graph)) 12322 12323 def test_op_dtype(self): 12324 12325 def check_equal_and_dtype(a, b): 12326 self.assertEqual(a, b) 12327 
self.assertEqual(a.dtype, b.dtype) 12328 12329 def fn(): 12330 a = torch.arange(10) 12331 b = torch.arange(10, dtype=torch.float) 12332 c = torch.arange(1, 10, 2) 12333 d = torch.arange(1, 10, 2, dtype=torch.float) 12334 e = torch.arange(1, 10., 2) 12335 f = torch.arange(1, 10., 2, dtype=torch.float) 12336 return a, b, c, d, e, f 12337 12338 scripted_fn = torch.jit.script(fn) 12339 eager_out = fn() 12340 script_out = scripted_fn() 12341 for a, b in zip(eager_out, script_out): 12342 check_equal_and_dtype(a, b) 12343 12344 def test_floor_div(self): 12345 @torch.jit.script 12346 def foo(a, b): 12347 # type: (int, int) -> int 12348 return a // b 12349 for i in range(-8, 8): 12350 for j in range(-8, 8): 12351 if j != 0: 12352 self.assertEqual(foo(i, j), i // j) 12353 12354 def test_floordiv(self): 12355 funcs_template = dedent(''' 12356 def fn(): 12357 ten = {a_construct} 12358 ten_or_scalar = {b_construct} 12359 return ten // ten_or_scalar, torch.floor_divide(ten, ten_or_scalar) 12360 ''') 12361 12362 lhs = ["torch.tensor([5.5, 3.2])", "torch.tensor([2, 2])", "torch.tensor([3, 2])"] 12363 rhs = ["1.5", "2", "4", "1.1"] + lhs 12364 for tensor in lhs: 12365 for tensor_or_scalar in rhs: 12366 funcs_str = funcs_template.format(a_construct=tensor, b_construct=tensor_or_scalar) 12367 scope = {} 12368 execWrapper(funcs_str, globals(), scope) 12369 cu = torch.jit.CompilationUnit(funcs_str) 12370 f_script = cu.fn 12371 f = scope['fn'] 12372 self.assertEqual(f_script(), f()) 12373 12374 def test_call_python_fn_from_script_fn(self): 12375 @torch.jit.ignore 12376 def python_fn(x): 12377 return torch.neg(x) 12378 12379 @torch.jit.script 12380 def script_fn(x): 12381 return python_fn(x) + 1 12382 12383 # Note: the call to python_fn appears as `^python_fn()` and is called 12384 # as a PythonOp in the interpreter 12385 a = torch.tensor(1) 12386 self.assertEqual(script_fn(a), torch.tensor(0)) 12387 FileCheck().check("python_fn").run(str(script_fn.graph)) 12388 12389 def 
test_call_python_mod_from_script_fn(self): 12390 class PythonModule(torch.nn.Module): 12391 def __init__(self) -> None: 12392 super().__init__() 12393 self.param = torch.nn.Parameter(torch.rand(5, 7)) 12394 12395 def forward(self, x): 12396 return torch.mm(x, self.param) 12397 12398 pm = PythonModule() 12399 12400 @torch.jit.script 12401 def script_fn(x): 12402 return pm(x) + 1 12403 12404 # Note: call to pm(x) appears as ^<python_value>() in the trace. 12405 # Parameters are NOT inlined. 12406 FileCheck().check("python_value").check("aten::add").run(str(script_fn.graph)) 12407 12408 @_tmp_donotuse_dont_inline_everything 12409 def test_call_script_fn_from_script_fn(self): 12410 @torch.jit.script 12411 def script_fn1(x): 12412 return torch.neg(x) 12413 12414 @torch.jit.script 12415 def script_fn(x): 12416 return script_fn1(x) + 1 12417 12418 FileCheck().check("prim::CallFunction").run(str(script_fn.graph)) 12419 12420 def test_call_script_mod_from_script_fn(self): 12421 with self.assertRaisesRegex(RuntimeError, "Cannot call a ScriptModule that is not a submodule of the caller"): 12422 class ScriptMod(torch.jit.ScriptModule): 12423 @torch.jit.script_method 12424 def forward(self, x): 12425 return torch.mm(x, torch.zeros([4, 3])) 12426 12427 sm = ScriptMod() 12428 12429 @torch.jit.script 12430 def script_fn(x): 12431 return sm(x) + 1 12432 12433 def test_call_python_fn_from_script_module(self): 12434 @torch.jit.ignore 12435 def python_fn(x): 12436 return torch.neg(x) 12437 12438 class ScriptMod(torch.jit.ScriptModule): 12439 def __init__(self) -> None: 12440 super().__init__() 12441 self.param = torch.nn.Parameter(torch.rand(4, 3)) 12442 12443 @torch.jit.script_method 12444 def forward(self, x): 12445 return python_fn(torch.mm(x, self.param)) 12446 12447 sm = ScriptMod() 12448 FileCheck().check("aten::mm").check("python_fn") \ 12449 .run(str(sm.forward.graph)) 12450 12451 def test_call_python_mod_from_script_module(self): 12452 class PythonMod(torch.nn.Module): 12453 
def __init__(self) -> None: 12454 super().__init__() 12455 self.param = torch.nn.Parameter(torch.rand(3, 5)) 12456 12457 @torch.jit.ignore 12458 def forward(self, x): 12459 return torch.mm(x, self.param) 12460 12461 class ScriptMod(torch.jit.ScriptModule): 12462 def __init__(self) -> None: 12463 super().__init__() 12464 self.param = torch.nn.Parameter(torch.rand(4, 3)) 12465 self.pm = PythonMod() 12466 12467 @torch.jit.script_method 12468 def forward(self, x): 12469 return self.pm(torch.mm(x, self.param)) 12470 12471 sm = ScriptMod() 12472 # Note: the call into PythonMod appears as ^forward(). Parameters 12473 # are NOT inlined 12474 FileCheck().check("aten::mm").check("forward").run(str(sm.graph)) 12475 12476 @_tmp_donotuse_dont_inline_everything 12477 def test_call_script_fn_from_script_module(self): 12478 @torch.jit.script 12479 def script_fn(x): 12480 return torch.neg(x) 12481 12482 class ScriptMod(torch.jit.ScriptModule): 12483 def __init__(self) -> None: 12484 super().__init__() 12485 self.param = torch.nn.Parameter(torch.rand(4, 3)) 12486 12487 @torch.jit.script_method 12488 def forward(self, x): 12489 return script_fn(torch.mm(x, self.param)) 12490 12491 sm = ScriptMod() 12492 graph = (sm.forward.graph) 12493 FileCheck().check("aten::mm").check("prim::CallFunction").run(str(graph)) 12494 12495 @_tmp_donotuse_dont_inline_everything 12496 def test_call_script_mod_from_script_module(self): 12497 class ScriptMod1(torch.jit.ScriptModule): 12498 def __init__(self) -> None: 12499 super().__init__() 12500 self.param = torch.nn.Parameter(torch.rand(3, 5)) 12501 12502 @torch.jit.script_method 12503 def forward(self, x): 12504 return torch.mm(x, self.param) 12505 12506 class ScriptMod(torch.jit.ScriptModule): 12507 def __init__(self) -> None: 12508 super().__init__() 12509 self.param = torch.nn.Parameter(torch.rand(4, 3)) 12510 self.tm = ScriptMod1() 12511 12512 @torch.jit.script_method 12513 def forward(self, x): 12514 return self.tm(torch.mm(x, self.param)) 12515 
12516 sm = ScriptMod() 12517 # Note: the parameters from both modules should appear in the flattened 12518 # input list to the graph. The mm op from ScriptMod1 should be properly 12519 # inlined 12520 # 3 % values in graph input lists, two mms in body 12521 FileCheck().check_count('%', 3).check(":").check_count("mm", 1).check("prim::CallMethod").run(str(sm.graph)) 12522 12523 def test_module_with_params_called_fails(self): 12524 with self.assertRaisesRegex(RuntimeError, "Cannot call a ScriptModule that is not a submodule of the caller"): 12525 class ScriptMod(torch.jit.ScriptModule): 12526 def __init__(self) -> None: 12527 super().__init__() 12528 self.param = torch.nn.Parameter(torch.rand(3, 3)) 12529 12530 @torch.jit.script_method 12531 def forward(self, x): 12532 return torch.mm(x, self.param) 12533 12534 sm = ScriptMod() 12535 12536 @torch.jit.script 12537 def some_func(x): 12538 return sm(x) 12539 12540 def test_tuple_index_to_list(self): 12541 def test_non_constant_input(a): 12542 # type: (bool) -> int 12543 if a: 12544 b = 1 12545 else: 12546 b = 0 12547 c = (0, 1) 12548 return c[b] 12549 12550 self.checkScript(test_non_constant_input, (True,)) 12551 self.checkScript(test_non_constant_input, (False,)) 12552 12553 with self.assertRaisesRegex(RuntimeError, "because we cannot resolve the output type"): 12554 @torch.jit.script 12555 def test_non_constant_input(a): 12556 # type: (bool) -> None 12557 if a: 12558 b = 1 12559 else: 12560 b = 0 12561 c = (0, 1.1) 12562 print(c[b]) 12563 12564 def test_tuple_indexing(self): 12565 def tuple_index(a): 12566 if bool(a): 12567 b = (1, 2) 12568 else: 12569 b = (0, 2) 12570 return b[-2], b[1] 12571 12572 self.checkScript(tuple_index, (torch.tensor([0]),)) 12573 self.checkScript(tuple_index, (torch.tensor([1]),)) 12574 self.checkScript(tuple_index, (torch.tensor([1]),), optimize=True) 12575 tuple_comp = torch.jit.script(tuple_index) 12576 FileCheck().check_count("TupleIndex", 2, exactly=True).run(str(tuple_comp.graph)) 12577 
12578 with self.assertRaisesRegex(RuntimeError, "index must be an integer"): 12579 @torch.jit.script 12580 def test_indexing_float(): 12581 c = (1, 2) 12582 return c[0.1] 12583 12584 def test_indexing_out_of_bounds_pos(): 12585 c = (1, 2) 12586 return c[2] 12587 12588 self.checkScriptRaisesRegex(test_indexing_out_of_bounds_pos, (), Exception, 12589 "out of range") 12590 12591 def test_indexing_out_of_bounds_neg(): 12592 c = (1, 2) 12593 return c[-3] 12594 12595 self.checkScriptRaisesRegex(test_indexing_out_of_bounds_pos, (), Exception, 12596 "out of range") 12597 12598 def negative_index(): 12599 tup = (1, 2, 3, 4) 12600 return tup[-1] 12601 12602 self.checkScript(negative_index, []) 12603 12604 def really_negative_index(): 12605 tup = (1, 2, 3, 4) 12606 return tup[-100] 12607 12608 self.checkScriptRaisesRegex(really_negative_index, [], Exception, "index out of range") 12609 12610 def negative_slice(): 12611 tup = (1, 2, 3, 4) 12612 return tup[-3:4] 12613 12614 self.checkScript(negative_slice, []) 12615 12616 def really_slice_out_of_bounds(): 12617 tup = (1, 2, 3, 4) 12618 return tup[-300:4000] 12619 12620 self.checkScript(really_slice_out_of_bounds, []) 12621 12622 def test_namedtuple_attr(self): 12623 def f(x): 12624 return x.max(dim=1).indices + torch.max(x, dim=1).indices 12625 12626 self.checkScript(f, (torch.rand(20, 20, 20),), optimize=True) 12627 12628 with self.assertRaisesRegex(RuntimeError, "object has no attribute or method"): 12629 @torch.jit.script 12630 def g1(x): 12631 return x.max(dim=1).unknown_symbol 12632 12633 with self.assertRaisesRegex(RuntimeError, "object has no attribute or method"): 12634 @torch.jit.script 12635 def g2(x): 12636 print((x, x, x).__doc__) 12637 return x 12638 12639 def test_tuple_len(self): 12640 @torch.jit.script 12641 def foo(): 12642 return len((1, "str", None)) 12643 12644 self.assertEqual(foo(), 3) 12645 12646 @torch.jit.script 12647 def test_indexing_end_out_of_bounds(): 12648 c = (1, 2) 12649 return c[2:10] 12650 
12651 self.assertEqual(test_indexing_end_out_of_bounds(), ()) 12652 12653 def test_lower_nested_tuples(self): 12654 @torch.jit.script 12655 def test(): 12656 return ((1, 2), 3) 12657 12658 self.run_pass('constant_propagation', test.graph) 12659 FileCheck().check("prim::Constant").check_not("TupleConstruct").run(test.graph) 12660 # fails if a tuple can't be lowered 12661 self.run_pass('lower_all_tuples', test.graph) 12662 12663 def test_unwrap_optional_builtin(self): 12664 def test(x): 12665 # type: (Optional[int]) -> int 12666 x = torch.jit._unwrap_optional(x) 12667 x = x + x # noqa: T484 12668 return x 12669 12670 self.checkScript(test, (3,)) 12671 12672 with self.assertRaisesRegex(AssertionError, "Unwrapping null optional"): 12673 test(None) 12674 12675 test_script = torch.jit.script(test) 12676 with self.assertRaisesRegex(RuntimeError, "Unwrapping null optional"): 12677 test_script(None) 12678 12679 @torch.jit.script 12680 def test_test(): 12681 return torch.jit._unwrap_optional(1) 12682 12683 with self.assertRaisesRegex(RuntimeError, r"could not be inferred from actual type None"): 12684 @torch.jit.script 12685 def test_no_type(): 12686 # type: () -> int 12687 return torch.jit._unwrap_optional(None) 12688 12689 def test_indexing_error(self): 12690 with self.assertRaisesRegex(RuntimeError, "'int' object is not subscriptable"): 12691 @torch.jit.script 12692 def test_wrong_type(): 12693 a = 8 12694 return a[0] 12695 12696 def test_unsupported_builtin_error(self): 12697 with self.assertRaisesRegex(RuntimeError, 12698 "Python builtin <built-in function hypot> is currently"): 12699 @torch.jit.script 12700 def test_unsupported(a): 12701 return math.hypot(a, 2.0) 12702 12703 def test_annotated_script_fn(self): 12704 @torch.jit.script 12705 def foo(x, y, z): 12706 # type: (Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tuple[Tensor, Tensor]]) -> Tensor 12707 return x 12708 12709 self.assertExpected(str(foo.schema)) 12710 12711 def 
test_annotated_script_method(self): 12712 class SM(torch.jit.ScriptModule): 12713 @torch.jit.script_method 12714 def forward(self, x, y): 12715 # type: (Tuple[Tensor, Tensor], Tensor) -> Tuple[Tensor, Tensor, Tensor] 12716 return y, y, y 12717 12718 sm = SM() 12719 12720 self.assertExpectedStripMangled(str(sm.forward.schema)) 12721 12722 def test_annotated_script_fn_return_mismatch(self): 12723 with self.assertRaisesRegex(RuntimeError, "but is actually of type"): 12724 @torch.jit.script 12725 def return_tup(x): 12726 # type: (Tensor) -> Tuple[Tuple[Tensor, Tensor], Tensor] 12727 return x, x # noqa: T484 12728 12729 def test_annotated_script_fn_arg_mismatch(self): 12730 with self.assertRaisesRegex(RuntimeError, r"Arguments for call are not valid"): 12731 @torch.jit.script 12732 def tuple_arg(x): 12733 # type: (Tuple[Tensor, Tensor]) -> Tensor 12734 return x + 1 # noqa: T484 12735 12736 def test_script_non_tensor_args_outputs(self): 12737 @torch.jit.script 12738 def fn(x, y): 12739 # type: (Tensor, float) -> float 12740 return float((x + y).sum()) 12741 12742 x = torch.ones(2, 2) 12743 z = fn(x, 1) 12744 self.assertIsInstance(z, float) 12745 self.assertEqual(z, 8.) 
12746 12747 @unittest.skip('https://github.com/pytorch/pytorch/issues/9595') 12748 def test_inline_and_run_annotated_script_fn(self): 12749 @torch.jit.script 12750 def to_inline(x, y): 12751 # type: (Tuple[Tensor, Tensor], Tensor) -> Tensor 12752 return y 12753 12754 @torch.jit.script 12755 def some_func(x): 12756 return to_inline((x, x), x) 12757 12758 x = torch.rand(3, 4) 12759 self.assertEqual(some_func(x), x) 12760 12761 def _make_filereader_test_file(self): 12762 filename = tempfile.mktemp() 12763 writer = torch._C.PyTorchFileWriter(filename) 12764 buffers = [os.urandom(size) for size in [random.randint(1, 100) for i in range(20)]] 12765 offsets = [] 12766 for i, buf in enumerate(buffers): 12767 writer.write_record(str(i), buf, len(buf)) 12768 offsets.append(i) 12769 serialized_offsets = pickle.dumps(offsets) 12770 writer.write_record("meta", serialized_offsets, len(serialized_offsets)) 12771 writer.write_end_of_file() 12772 return filename, buffers, serialized_offsets 12773 12774 def test_file_format_serialization(self): 12775 filename, buffers, serialized_offsets = self._make_filereader_test_file() 12776 12777 reader = torch._C.PyTorchFileReader(filename) 12778 serialized_offsets_read = reader.get_record("meta") 12779 parsed_serialized_offsets = pickle.loads(serialized_offsets) 12780 12781 for i, offset in enumerate(parsed_serialized_offsets): 12782 data = reader.get_record(str(offset)) 12783 assert data == buffers[i] 12784 12785 def test_file_reader_no_memory_leak(self): 12786 num_iters = 10000 12787 filename, _, _ = self._make_filereader_test_file() 12788 12789 # Load from filename 12790 tracemalloc.start() 12791 for i in range(num_iters): 12792 torch._C.PyTorchFileReader(filename) 12793 _, peak_from_string = tracemalloc.get_traced_memory() 12794 tracemalloc.stop() 12795 12796 # Load from stream 12797 tracemalloc.start() 12798 with open(filename, 'rb') as f: 12799 for i in range(num_iters): 12800 f.seek(0) 12801 torch._C.PyTorchFileReader(f) 12802 _, 
peak_from_file = tracemalloc.get_traced_memory() 12803 tracemalloc.stop() 12804 12805 # Check if the peak sizes at most differ by an empirically obtained factor 12806 self.assertLess(peak_from_file, peak_from_string * 500) 12807 12808 # for each type, the input type annotation and corresponding return type annotation 12809 def type_input_return_pairs(self): 12810 return [ 12811 ('Tensor', 'Tensor'), 12812 ('torch.Tensor', 'Tensor'), 12813 ('str', 'str'), 12814 ('int', 'int'), 12815 ('bool', 'bool'), 12816 ('BroadcastingList3[float]', 'List[float]'), 12817 ('BroadcastingList2[int]', 'List[int]'), 12818 ('List[int]', 'List[int]'), 12819 ('Optional[int]', 'Optional[int]'), 12820 ] 12821 12822 # replacing code input & return type pair 12823 def format_code(self, code, pair): 12824 return code.format(input=pair[0], output=pair[1]) 12825 12826 # ***** Type annotation tests **** 12827 # Test combinations of: 12828 # {String frontend, Python AST Frontend} 12829 # {Python 3-style type annotations, MyPy-style type comments} 12830 # {Script method, Script function} 12831 12832 # String frontend , Python 3-style type annotations , Script function 12833 def test_annot_string_py3_fn(self): 12834 code = ''' 12835 def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]: 12836 return x, x 12837 ''' 12838 test_str = [] 12839 for pair in self.type_input_return_pairs(): 12840 cu = torch.jit.CompilationUnit(self.format_code(code, pair)) 12841 test_str.append(str(cu.foo.schema)) 12842 self.assertExpected("\n".join(test_str) + "\n") 12843 12844 # String frontend , Python 3-style type annotations , Script method 12845 def test_annot_string_py3_method(self): 12846 class TestModule(torch.jit.ScriptModule): 12847 def __init__(self) -> None: 12848 super().__init__() 12849 12850 code = ''' 12851 def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]: 12852 return x, x 12853 ''' 12854 test_str = [] 12855 for pair in 
self.type_input_return_pairs(): 12856 # clear the class registry as we will be defining foo multiple times 12857 jit_utils.clear_class_registry() 12858 tm = TestModule() 12859 tm.define(self.format_code(code, pair)) 12860 test_str.append(str(tm.foo.schema)) 12861 self.assertExpectedStripMangled("\n".join(test_str) + "\n") 12862 12863 # String frontend , MyPy-style type comments , Script function 12864 def test_annot_string_mypy_fn(self): 12865 code = ''' 12866 def foo(x, y): 12867 # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}] 12868 return x, x 12869 ''' 12870 test_str = [] 12871 for pair in self.type_input_return_pairs(): 12872 cu = torch.jit.CompilationUnit(self.format_code(code, pair)) 12873 test_str.append(str(cu.foo.schema)) 12874 self.assertExpectedStripMangled("\n".join(test_str) + "\n") 12875 12876 # String frontend , MyPy-style type comments , Script method 12877 def test_annot_string_mypy_method(self): 12878 class TestModule(torch.jit.ScriptModule): 12879 def __init__(self) -> None: 12880 super().__init__() 12881 12882 code = ''' 12883 def foo(self, x, y): 12884 # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}] 12885 return x, x 12886 ''' 12887 12888 test_str = [] 12889 for pair in self.type_input_return_pairs(): 12890 # clear the class registry as we will be defining foo multiple times 12891 jit_utils.clear_class_registry() 12892 tm = TestModule() 12893 tm.define(self.format_code(code, pair)) 12894 test_str.append(str(tm.foo.schema)) 12895 self.assertExpectedStripMangled("\n".join(test_str) + "\n") 12896 12897 # Python AST Frontend , Python 3-style type annotations , Script function 12898 def test_annot_ast_py3_fn(self): 12899 code = dedent(''' 12900 from typing import Tuple, List, Optional 12901 from torch import Tensor 12902 from torch.jit.annotations import BroadcastingList2, BroadcastingList3 12903 import torch 12904 @torch.jit.script 12905 def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, 
{output}]: 12906 return x, x 12907 ''') 12908 test_str = [] 12909 for pair in self.type_input_return_pairs(): 12910 fn = jit_utils._get_py3_code(self.format_code(code, pair), 'foo') 12911 test_str.append(str(fn.schema)) 12912 self.assertExpectedStripMangled("\n".join(test_str) + "\n") 12913 12914 def test_multiline_annot_ast_py3_fn(self): 12915 code = dedent(''' 12916 from typing import Tuple, List, Optional 12917 from torch import Tensor 12918 from torch.jit.annotations import BroadcastingList2, BroadcastingList3 12919 import torch 12920 @torch.jit.script 12921 def foo(x, # type: {input} 12922 y # type: Tuple[Tensor, Tensor] 12923 ): 12924 # type: (...) -> Tuple[{output}, {output}] 12925 return x, x 12926 ''') 12927 test_str = [] 12928 12929 for pair in self.type_input_return_pairs(): 12930 fn = jit_utils._get_py3_code(self.format_code(code, pair), 'foo') 12931 args = fn.schema.arguments 12932 returns = fn.schema.returns 12933 self.assertEqual(str(args[0].type), pair[1]) 12934 self.assertEqual(str(args[1].type), "Tuple[Tensor, Tensor]") 12935 self.assertEqual(str(returns[0].type), f"Tuple[{pair[1]}, {pair[1]}]") 12936 12937 def test_bad_multiline_annotations(self): 12938 with self.assertRaisesRegex(RuntimeError, "Return type line"): 12939 @torch.jit.script 12940 def bad_type_line(a, # type: Tensor 12941 b, # type: Tensor 12942 c # type: Tensor 12943 ): 12944 # type: (int, int, int) -> Tensor 12945 # type: bad type line # noqa: F723 12946 12947 return a + b + c 12948 12949 with self.assertRaisesRegex(RuntimeError, "Return type line"): 12950 @torch.jit.script 12951 def bad_return_line(a, # type: Tensor 12952 b, 12953 c # type: Tensor 12954 ): 12955 # type: (int, int, int) -> Tensor 12956 return a + b + c 12957 12958 # TODO: this should be supported but is difficult to parse 12959 with self.assertRaisesRegex(RuntimeError, "Number of type annotations"): 12960 @torch.jit.script 12961 def missing_type(a, # type: Tensor 12962 b, 12963 c # type: Tensor 12964 ): 12965 # 
type: (...) -> Tensor 12966 return a + b + c 12967 12968 # Python AST Frontend , Python 3-style type annotations , Script method 12969 def test_annot_ast_py3_method(self): 12970 code = dedent(''' 12971 from typing import Tuple, List, Optional 12972 from torch import Tensor 12973 from torch.jit.annotations import BroadcastingList2, \\ 12974 BroadcastingList3 12975 import torch 12976 class FooModule(torch.jit.ScriptModule): 12977 @torch.jit.script_method 12978 def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]: 12979 return x, x 12980 instance = FooModule() 12981 ''') 12982 12983 test_str = [] 12984 for pair in self.type_input_return_pairs(): 12985 fn = jit_utils._get_py3_code(self.format_code(code, pair), 'instance') 12986 test_str.append(str(fn.foo.schema)) 12987 self.assertExpectedStripMangled("\n".join(test_str) + "\n") 12988 12989 # Python AST Frontend , MyPy-style type comments , Script function 12990 def test_annot_ast_mypy_fn(self): 12991 code = dedent(''' 12992 import torch 12993 @torch.jit.script 12994 def foo(x, y): 12995 # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}] 12996 return x, x 12997 ''') 12998 12999 test_str = [] 13000 for pair in self.type_input_return_pairs(): 13001 fn = jit_utils._get_py3_code(self.format_code(code, pair), 'foo') 13002 test_str.append(str(fn.schema)) 13003 self.assertExpected("\n".join(test_str) + "\n") 13004 13005 # Python AST Frontend , MyPy-style type comments , Script method 13006 def test_annot_ast_mypy_method(self): 13007 code = dedent(''' 13008 import torch 13009 class FooModule(torch.jit.ScriptModule): 13010 @torch.jit.script_method 13011 def foo(self, x, y): 13012 # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}] 13013 return x, x 13014 instance = FooModule() 13015 ''') 13016 13017 test_str = [] 13018 for pair in self.type_input_return_pairs(): 13019 fn = jit_utils._get_py3_code(self.format_code(code, pair), 'instance') 13020 
test_str.append(str(fn.foo.schema)) 13021 self.assertExpectedStripMangled("\n".join(test_str) + "\n") 13022 13023 # Tests that "# type: ignore[*]" is supported in type lines and is 13024 # properly ignored. 13025 def test_mypy_type_ignore(self): 13026 @torch.jit.script 13027 def foo(x): # type: ignore 13028 return x 13029 13030 @torch.jit.script 13031 def bar(x): # type: ignore[no-redef] 13032 return x 13033 13034 def test_method_casts_script(self): 13035 cast_types = [ 13036 'byte', 'char', 'double', 'float', 'int', 'long', 'short' 13037 ] 13038 13039 for cast_type in cast_types: 13040 cu = torch.jit.CompilationUnit(f''' 13041 def cast_to(x): 13042 return x.{cast_type}() 13043 ''') 13044 13045 x = torch.rand(3, 4, 5) * 128 13046 cu_result = cu.cast_to(x) 13047 reference = getattr(x, cast_type)() 13048 self.assertEqual(cu_result, reference) 13049 13050 def test_string_frontend_elif(self): 13051 code = ''' 13052 def func(niter): 13053 # type: (int) 13054 rv = 0 13055 for i in range(niter): 13056 if i % 3 == 0 and i % 5 == 0: 13057 rv += 35 13058 elif i % 3 == 0: 13059 rv += 3 13060 elif i % 5 == 0: 13061 rv += 5 13062 else: 13063 rv += i 13064 return rv 13065 ''' 13066 13067 self.checkScript(dedent(code), (101,)) 13068 13069 def test_module_parameters_and_buffers(self): 13070 weights = torch.randn(10, 10) 13071 bias = torch.randn(10) 13072 weights2 = torch.randn(10, 10) 13073 bias2 = torch.randn(10) 13074 13075 class TestLinear(torch.nn.Module): 13076 def __init__(self, in_features, out_features): 13077 super().__init__() 13078 self.in_features = in_features 13079 self.out_features = out_features 13080 self.weight = torch.nn.Parameter(torch.empty(out_features, in_features)) 13081 self.bias = torch.nn.Parameter(torch.empty(out_features)) 13082 self.counter = nn.Buffer(torch.ones(out_features)) 13083 self.reset_parameters() 13084 13085 def reset_parameters(self): 13086 torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 13087 if self.bias is not None: 13088 
fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight) 13089 bound = 1 / math.sqrt(fan_in) 13090 torch.nn.init.uniform_(self.bias, -bound, bound) 13091 13092 def forward(self, input): 13093 return F.linear(input, self.weight, self.bias) + self.counter 13094 13095 # Initialize a ScriptModule that uses the weak module above multiple times 13096 class Strong(torch.jit.ScriptModule): 13097 def __init__(self) -> None: 13098 super().__init__() 13099 self.fc1 = TestLinear(10, 10) 13100 self.fc1.weight = torch.nn.Parameter(weights) 13101 self.fc1.bias = torch.nn.Parameter(bias) 13102 self.fc2 = TestLinear(10, 10) 13103 self.fc2.weight = torch.nn.Parameter(weights2) 13104 self.fc2.bias = torch.nn.Parameter(bias2) 13105 13106 @torch.jit.script_method 13107 def forward(self, x): 13108 return x + self.fc1(x) + self.fc1(x) + self.fc2(x) 13109 13110 strong_mod = Strong() 13111 13112 # Run same calculation as module 13113 inp = torch.ones(10) 13114 lin = torch.nn.Linear(10, 10) 13115 lin.weight = torch.nn.Parameter(weights) 13116 lin.bias = torch.nn.Parameter(bias) 13117 lin2 = torch.nn.Linear(10, 10) 13118 lin2.weight = torch.nn.Parameter(weights2) 13119 lin2.bias = torch.nn.Parameter(bias2) 13120 expected_result = inp + (lin(inp) + torch.ones(10)) * 2 + lin2(inp) + torch.ones(10) 13121 13122 self.assertEqual(strong_mod(inp), expected_result) 13123 self.assertExportImportModule(strong_mod, (inp,)) 13124 13125 def test_module_copying(self): 13126 class Submodule(torch.nn.Module): 13127 def forward(self, x): 13128 return x + 100 13129 13130 class Weak(torch.nn.Module): 13131 def __init__(self, in_features, out_features): 13132 super().__init__() 13133 self.weight = torch.nn.Parameter(torch.ones(out_features, in_features)) 13134 self.bias = torch.nn.Parameter(torch.ones(out_features)) 13135 self.buffer = nn.Buffer(torch.ones(out_features)) 13136 self.submodule = Submodule() 13137 13138 def forward(self, x): 13139 return F.linear(x, self.weight, self.bias) \ 13140 + 
self.buffer + self.submodule(x) 13141 13142 class Strong(torch.jit.ScriptModule): 13143 def __init__(self, weak): 13144 super().__init__() 13145 self.weak = weak 13146 13147 @torch.jit.script_method 13148 def forward(self, x): 13149 return self.weak(x) 13150 13151 inp = torch.ones(5, 5) * 5 13152 weak_mod = Weak(5, 5) 13153 strong_mod = Strong(weak_mod) 13154 13155 self.assertTrue(isinstance(strong_mod.weak, torch.jit.ScriptModule)) 13156 self.assertFalse(isinstance(weak_mod, torch.jit.ScriptModule)) 13157 13158 self.assertIs(strong_mod.weak.weight, weak_mod.weight) 13159 self.assertIs(strong_mod.weak.buffer, weak_mod.buffer) 13160 # strong_mod.weak.submodule has been recursively scripted 13161 self.assertIsNot(strong_mod.weak.submodule, weak_mod.submodule) 13162 13163 weak_mod.weight.data += torch.ones(5, 5) * 100 13164 self.assertTrue(strong_mod(inp).allclose(weak_mod(inp))) 13165 13166 # Re-assignment is not tracked 13167 weak_mod.weight = torch.nn.Parameter(torch.ones(5, 5) * 100) 13168 self.assertFalse(strong_mod(inp).allclose(weak_mod(inp))) 13169 13170 def test_backend_cudnn_enabled(self): 13171 # Only test that this compiles 13172 @torch.jit.script 13173 def fn(x): 13174 if torch.backends.cudnn.enabled: 13175 x = x + 2 13176 else: 13177 x = x + 3 13178 return x 13179 13180 def test_inplace_add(self): 13181 13182 def foo(a, b): 13183 c = a + b 13184 c.add_(b) 13185 return c 13186 self.checkScript(foo, (torch.rand(3), torch.rand(3))) 13187 13188 def test_add_out(self): 13189 def foo(a, b): 13190 c = a + b 13191 e = 2 * a 13192 torch.add(c, b, out=e) 13193 return e 13194 self.checkScript(foo, (torch.rand(3), torch.rand(3))) 13195 13196 def test_tuple_error_msg(self): 13197 def fn(t: Any): 13198 if isinstance(t, tuple): 13199 a, b = t 13200 return a + b 13201 with self.assertRaisesRegexWithHighlight(RuntimeError, "Provided tuple is not fully defined/refined", "t"): 13202 s = torch.jit.script(fn) 13203 13204 def test_augmented_assign(self): 13205 def foo(a, b): 
13206 a += b 13207 a -= b 13208 a /= b 13209 a *= b 13210 return a, b 13211 self.checkScript(foo, (torch.rand(3), torch.rand(3))) 13212 13213 def test_ignored_props(self): 13214 class A(nn.Module): 13215 __jit_ignored_attributes__ = ["ignored", "ignored_return_val"] 13216 13217 @property 13218 def ignored(self): 13219 raise ValueError("shouldn't be called") 13220 13221 @property 13222 def ignored_return_val(self): 13223 return 1 13224 13225 @torch.jit.ignore 13226 def call(self): 13227 return self.ignored_return_val 13228 13229 f = torch.jit.script(A()) 13230 # jank way to test if there is no error 13231 self.assertTrue(isinstance(f, torch.jit.ScriptModule)) 13232 self.assertTrue(isinstance(f.call(), property)) 13233 13234 13235 def test_pass(self): 13236 def foo(x): 13237 # type: (bool) -> int 13238 for _i in range(3): 13239 pass 13240 if x: 13241 pass 13242 else: 13243 pass 13244 return 3 13245 13246 self.checkScript(foo, (True,)) 13247 13248 def test_lhs_indexing(self): 13249 def foo(a, b): 13250 a = a.clone() 13251 a[0] = b 13252 return a 13253 self.checkScript(foo, (torch.rand(2, 3), torch.rand(3))) 13254 13255 def test_lhs_advanced_indexing_assignment(self): 13256 def foo(x, y): 13257 a = torch.exp(x) 13258 b = x == 1 13259 a[b] = y[b] 13260 return a 13261 self.checkScript(foo, (torch.ones(4, 3), torch.ones(4, 3))) 13262 13263 def test_lhs_advanced_indexing_augmented_assignment(self): 13264 def foo(x, y): 13265 a = torch.exp(x) 13266 b = x == 1 13267 a[b] += y[b] 13268 return a 13269 self.checkScript(foo, (torch.ones(4, 3), torch.ones(4, 3))) 13270 13271 def test_lhs_indexing_list(self): 13272 def foo(a, b): 13273 ls = [a] 13274 ls[0] = b 13275 return ls 13276 self.checkScript(foo, (torch.rand(2, 3), torch.rand(3))) 13277 13278 def test_inplace_copy_script(self): 13279 def foo(x): 13280 a = torch.rand(3, 4) 13281 a.copy_(x) 13282 return a 13283 self.checkScript(foo, (torch.rand(3, 4),)) 13284 13285 def test_lhs_indexing_increment(self): 13286 def foo(a, b): 
13287 a[0] += b 13288 return a 13289 self.checkScript(foo, (torch.rand(2, 3), torch.rand(3))) 13290 13291 def test_lhs_indexing_increment_list(self): 13292 def foo(a, b): 13293 a = a.clone() 13294 ls = [a, b] 13295 ls[0] += b 13296 return ls 13297 self.checkScript(foo, (torch.rand(2, 3), torch.rand(3))) 13298 13299 def test_lhs_indexing_increment_list_prim(self): 13300 def foo(): 13301 ls = [1, 2, 3] 13302 ls[0] += 5 13303 return ls 13304 self.checkScript(foo, ()) 13305 13306 def test_lhs_indexing_multi(self): 13307 def foo(a, b): 13308 a = a.clone() 13309 foo, a[0], bar = (1, b, 3) 13310 return foo, a, bar 13311 self.checkScript(foo, (torch.rand(2, 3), torch.rand(3))) 13312 13313 def test_bool_dispatch(self): 13314 with torch._jit_internal._disable_emit_hooks(): # TODO: Python print broadcasting list 13315 def kwarg_false(x): 13316 # type: (Tensor) -> Tensor 13317 return F.max_pool1d(x, 1, 1, return_indices=False) 13318 self.checkScript(kwarg_false, (torch.randn(3, 3, 3),)) 13319 13320 def kwarg_true(x): 13321 # type: (Tensor) -> Tuple[Tensor, Tensor] 13322 return F.max_pool1d(x, 1, 1, return_indices=True) 13323 self.checkScript(kwarg_true, (torch.randn(3, 3, 3),)) 13324 13325 def full_kwarg_false(x): 13326 # type: (Tensor) -> Tensor 13327 return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=False) 13328 self.checkScript(full_kwarg_false, (torch.randn(3, 3, 3),)) 13329 13330 def full_kwarg_true(x): 13331 # type: (Tensor) -> Tuple[Tensor, Tensor] 13332 return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=True) 13333 self.checkScript(full_kwarg_true, (torch.randn(3, 3, 3),)) 13334 13335 def use_default(x): 13336 # type: (Tensor) -> Tensor 13337 return F.max_pool1d(x, 1, 1) 13338 self.checkScript(use_default, (torch.randn(3, 3, 3),)) 13339 13340 def arg_false(x): 13341 # type: (Tensor) -> Tensor 13342 return F.max_pool1d(x, 1, 1, 0, 1, False, False) 13343 self.checkScript(arg_false, (torch.randn(3, 3, 3),)) 13344 13345 def arg_true(x): 13346 # type: 
(Tensor) -> Tuple[Tensor, Tensor] 13347 return F.max_pool1d(x, 1, 1, 0, 1, False, True) 13348 self.checkScript(arg_true, (torch.randn(3, 3, 3),)) 13349 13350 def test_infer_size(self): 13351 from torch._C import _infer_size 13352 13353 def fn(x, y): 13354 # type: (Tensor, Tensor) -> List[int] 13355 return _infer_size(x.size(), y.size()) 13356 13357 self.checkScript(fn, (torch.ones(2, 4, 2), torch.ones(2, 4, 2))) 13358 13359 def test_hash(self): 13360 def tester(fn, inputs): 13361 for x in inputs: 13362 for y in inputs: 13363 if x == y: 13364 self.assertEqual(fn(x), fn(y)) 13365 else: 13366 self.assertNotEqual(fn(x), fn(y)) 13367 13368 @torch.jit.script 13369 def int_hash(x): 13370 # type: (int) -> int 13371 return hash(x) 13372 13373 @torch.jit.script 13374 def float_hash(x): 13375 # type: (float) -> int 13376 return hash(x) 13377 13378 @torch.jit.script 13379 def str_hash(x): 13380 # type: (str) -> int 13381 return hash(x) 13382 13383 tester(int_hash, (20, 21, 22)) 13384 tester(float_hash, (20.0, 21.00001, 22.443)) 13385 tester(str_hash, ("", "hello", "a")) 13386 13387 def test_id(self): 13388 with self.assertRaisesRegex(RuntimeError, "Expected a value"): 13389 @torch.jit.script 13390 def test_id_scalars(): 13391 return id(2) == id(None) 13392 13393 @torch.jit.script 13394 class FooTest: 13395 def __init__(self, x): 13396 self.foo = x 13397 13398 def getFooTest(self): 13399 return self.foo 13400 13401 @torch.jit.script 13402 def test_id_class_types(): 13403 obj1 = FooTest(torch.tensor(3)) 13404 obj2 = FooTest(torch.tensor(2)) 13405 assert obj1 is not obj2 13406 assert id(obj1) != id(obj2) 13407 assert id(obj1) != id(None) 13408 return True 13409 13410 self.assertTrue(test_id_class_types()) 13411 13412 def test_mutable_dce(self): 13413 @torch.jit.script 13414 def foo(): 13415 a = torch.rand(2, 3) 13416 a += torch.rand(2, 3) 13417 b = torch.rand(2, 3) 13418 b += torch.rand(2, 3) 13419 # b should be cleaned up but not a 13420 return a 13421 13422 
FileCheck().check_count("aten::rand", 2, exactly=True) \ 13423 .check_count("aten::add", 1, exactly=True).run(str(foo.graph)) 13424 13425 def test_mutable_dce_block(self): 13426 @torch.jit.script 13427 def foo(): 13428 a = torch.rand(2, 3) 13429 a += torch.rand(2, 3) 13430 b = torch.rand(2, 3) 13431 if bool(a > torch.zeros(2, 3)): 13432 b += torch.rand(2, 3) 13433 a += torch.rand(2, 3) 13434 # a should be cleaned up but not b 13435 return b 13436 13437 FileCheck().check("prim::If").check_count("aten::rand", 1, exactly=True) \ 13438 .run(str(foo.graph)) 13439 13440 def test_mutable_dce_graph_input(self): 13441 @torch.jit.script 13442 def foo(a): 13443 a += torch.rand(2, 3) 13444 # shouldn't clean up `a` even though it's not used in the output 13445 13446 FileCheck().check("aten::rand").check("aten::add").run(str(foo.graph)) 13447 13448 def test_mutable_dce_list(self): 13449 @torch.jit.script 13450 def foo(a): 13451 l = [] 13452 l.append(a) 13453 c = l[0] 13454 b = torch.rand(2, 3) 13455 c += torch.rand(2, 3) 13456 return b 13457 13458 # c does not get cleaned up because there is a wildcard + mutation 13459 FileCheck().check_count("aten::rand", 2, exactly=True).run(str(foo.graph)) 13460 13461 def test_mutable_dce_loop(self): 13462 @torch.jit.script 13463 def foo(a): 13464 l = [] 13465 l.append(a) 13466 i = 0 13467 b = torch.rand(2, 3) 13468 while i < 1: 13469 dead = torch.rand(2, 3) 13470 c = l[0] 13471 c += torch.rand(2, 3) 13472 i += 1 13473 return b 13474 13475 FileCheck().check("prim::Loop").check_not("aten::rand").check("aten::__getitem__") \ 13476 .check_count("aten::rand", 1, exactly=True).run(str(foo.graph)) 13477 13478 def test_mutable_dce_indirect_wildcards(self): 13479 def fn(): 13480 x = torch.ones(2, 3) 13481 x_1 = x.view(-1) 13482 l = [] 13483 l.append(x_1) 13484 x_view = l[0] 13485 x.add_(torch.ones(2, 3)) 13486 return x_view 13487 self.checkScript(fn, ()) 13488 13489 def test_mutable_dce_indirect_wildcard_write(self): 13490 def fn(): 13491 indexes = 
torch.jit.annotate(List[Tensor], []) 13492 word_ids = torch.zeros(10, dtype=torch.int32) 13493 word_ids[1] = 1 13494 indexes.append(word_ids) 13495 13496 return word_ids 13497 self.checkScript(fn, ()) 13498 13499 def test_mutable_dce_wildcards(self): 13500 def fn(): 13501 x = torch.ones(2, 3) 13502 l = [] 13503 l.append(x) 13504 x_view = l[0] 13505 x.add_(torch.ones(2, 3)) 13506 return x_view 13507 13508 self.checkScript(fn, (), profiling=ProfilingMode.SIMPLE) 13509 13510 def test_cpp_function_tensor_str(self): 13511 x = torch.randn(2, 2) 13512 scale = torch.randn(2, 2, requires_grad=True) 13513 shift = torch.randn(2, 2, requires_grad=True) 13514 13515 @torch.jit.script 13516 def fn(x, scale, shift): 13517 return scale * x + shift 13518 13519 with self.capture_stdout() as captured: 13520 print(fn(x, scale, shift)) 13521 13522 def test_string_index(self): 13523 def fn(x): 13524 # type: (str) 13525 return x[2], x[-1] 13526 13527 self.checkScript(fn, ("abcde",)) 13528 13529 def test_ord(self): 13530 def fn(x): 13531 # type: (str) -> int 13532 return ord(x) 13533 13534 self.checkScript(fn, ("h")) 13535 self.checkScript(fn, ("y")) 13536 13537 def index_str_to_tensor(s): 13538 # type: (str) -> Tensor 13539 return torch.tensor(ord(s)) # noqa: T484 13540 13541 s = '\u00a3'.encode()[:1] 13542 self.checkScript(index_str_to_tensor, (s,)) 13543 13544 def test_chr(self): 13545 def fn(x): 13546 # type: (int) -> str 13547 return chr(x) 13548 13549 self.checkScript(fn, (1,)) 13550 self.checkScript(fn, (97,)) 13551 13552 def test_round(self): 13553 def round_float(x): 13554 # type: (float) -> float 13555 return round(x) 13556 13557 def round_int(x): 13558 # type: (int) -> float 13559 return round(x) 13560 13561 self.checkScript(round_float, (1.5,)) 13562 self.checkScript(round_int, (2,)) 13563 13564 def test_convert_base(self): 13565 def test_hex(x): 13566 # type: (int) -> str 13567 return hex(x) 13568 13569 def test_oct(x): 13570 # type: (int) -> str 13571 return oct(x) 13572 
13573 def test_bin(x): 13574 # type: (int) -> str 13575 return bin(x) 13576 13577 numbers = [-1000, -10, 0, 1, 10, 2343] 13578 for n in numbers: 13579 self.checkScript(test_bin, (n,)) 13580 self.checkScript(test_oct, (n,)) 13581 self.checkScript(test_hex, (n,)) 13582 13583 @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle") 13584 def test_get_set_state(self): 13585 class Root(torch.jit.ScriptModule): 13586 __constants__ = ['number'] 13587 13588 def __init__(self, number): 13589 super().__init__() 13590 self.buffer1 = nn.Buffer(torch.ones(2, 2)) 13591 self.buffer2 = nn.Buffer(torch.ones(2, 2)) 13592 self.number = number 13593 13594 @torch.jit.script_method 13595 def __getstate__(self): 13596 return (self.buffer1, self.buffer2, 74, self.training) 13597 13598 @torch.jit.script_method 13599 def __setstate__(self, state): 13600 self.buffer1 = state[0] + 10 13601 self.buffer2 = state[1] + 10 13602 self.training = state[3] 13603 13604 class M(torch.jit.ScriptModule): 13605 __constants__ = ['number'] 13606 13607 def __init__(self, number, submodule): 13608 super().__init__() 13609 self.buffer1 = nn.Buffer(torch.ones(2, 2)) 13610 self.buffer2 = nn.Buffer(torch.ones(2, 2)) 13611 self.number = number 13612 self.submodule = submodule 13613 13614 @torch.jit.script_method 13615 def __getstate__(self): 13616 return (self.buffer1, self.buffer2, 74, self.submodule, self.training) 13617 13618 @torch.jit.script_method 13619 def __setstate__(self, state): 13620 self.buffer1 = state[0] + 10 13621 self.buffer2 = state[1] + 10 13622 self.submodule = state[3] 13623 self.training = state[4] 13624 13625 with TemporaryFileName() as fname: 13626 m = M(23, submodule=Root(99)) 13627 m.save(fname) 13628 loaded = torch.jit.load(fname) 13629 13630 # Check original module 13631 self.assertEqual(m.buffer1, torch.ones(2, 2)) 13632 self.assertEqual(m.buffer2, torch.ones(2, 2)) 13633 13634 # Check top level module 13635 
self.assertEqual(loaded.buffer1, torch.ones(2, 2) + 10) 13636 self.assertEqual(loaded.buffer2, torch.ones(2, 2) + 10) 13637 13638 # Check submodule 13639 self.assertEqual(loaded.submodule.buffer1, torch.ones(2, 2) + 10) 13640 self.assertEqual(loaded.submodule.buffer2, torch.ones(2, 2) + 10) 13641 13642 # Check simpler module 13643 class NoArgState(torch.nn.Module): 13644 def __init__(self) -> None: 13645 super().__init__() 13646 self.buffer1 = nn.Buffer(torch.ones(2, 2)) 13647 self.buffer2 = nn.Buffer(torch.ones(2, 2)) 13648 13649 def forward(self): 13650 pass 13651 13652 @torch.jit.export 13653 def __getstate__(self): 13654 return 5, self.training 13655 13656 @torch.jit.export 13657 def __setstate__(self, state): 13658 self.buffer1 = torch.ones(2, 2) + state[0] 13659 self.buffer2 = torch.ones(2, 2) + 10 13660 self.training = state[1] 13661 13662 with TemporaryFileName() as fname: 13663 m = torch.jit.script(NoArgState()) 13664 m.save(fname) 13665 loaded = torch.jit.load(fname) 13666 self.assertEqual(loaded.buffer1, torch.ones(2, 2) + 5) 13667 self.assertEqual(loaded.buffer2, torch.ones(2, 2) + 10) 13668 13669 13670 13671 def test_string_slicing(self): 13672 def fn1(x): 13673 # type: (str) -> str 13674 return x[1:3] 13675 13676 def fn2(x): 13677 # type: (str) -> str 13678 return x[-1:3] 13679 13680 def fn3(x): 13681 # type: (str) -> str 13682 return x[3:1] 13683 13684 def fn4(x): 13685 # type: (str) -> str 13686 return x[3:100] 13687 13688 self.checkScript(fn1, ("abcdefghi",)) 13689 self.checkScript(fn2, ("abcdefghi",)) 13690 self.checkScript(fn3, ("abcdefghi",)) 13691 self.checkScript(fn4, ("abcdefghi",)) 13692 13693 def test_early_return_closure(self): 13694 code = dedent(''' 13695 def tanh(self): 13696 output = torch.tanh(self) 13697 def backward(grad_output): 13698 pass 13699 return output, backward 13700 ''') 13701 cu = torch.jit.CompilationUnit(code) 13702 g = cu.tanh.graph 13703 FileCheck().check_count("prim::Closure_0", 2).check("NoneType = prim::Constant") 
\ 13704 .check_next("return").run(g) 13705 13706 code = dedent(''' 13707 def tanh(self): 13708 output = torch.tanh(self) 13709 def backward(grad_output): 13710 a = 1 13711 if output: 13712 return 1 13713 else: 13714 a = 2 13715 return a 13716 return output, backward 13717 ''') 13718 cu = torch.jit.CompilationUnit(code) 13719 g = cu.tanh.graph 13720 FileCheck().check_count("prim::Closure_0", 2).check("int = prim::If") \ 13721 .run(g) 13722 13723 code = dedent(''' 13724 def loop_in_closure(self): 13725 output = torch.tanh(self) 13726 def backward(grad_output): 13727 for i in range(3): 13728 return 1 13729 return 4 13730 return output, backward 13731 ''') 13732 cu = torch.jit.CompilationUnit(code) 13733 fc = FileCheck() 13734 fc.check("prim::Closure").check("(Tensor, NoneType) = prim::TupleConstruct") 13735 # Loop then two if's added in exit transform 13736 fc.check("prim::Closure").check("prim::Loop").check_count("prim::If", 2) 13737 fc.run(cu.loop_in_closure.graph) 13738 13739 code = dedent(''' 13740 def tanh(self): 13741 output = torch.tanh(self) 13742 def backward(grad_output): 13743 if 1 == 1: 13744 return 1 13745 else: 13746 return 1. 
13747 return output, backward 13748 ''') 13749 with self.assertRaisesRegex(RuntimeError, "returned a value of type int but"): 13750 cu = torch.jit.CompilationUnit(code) 13751 13752 @_inline_everything 13753 def test_early_return_fork_join(self): 13754 @torch.jit.script 13755 def foo(x): 13756 if x.dim() == 2: 13757 return torch.neg(x), x 13758 else: 13759 return torch.neg(x), x + 1 13760 13761 x = torch.rand(3, 4) 13762 13763 @torch.jit.script 13764 def wait_script(x): 13765 fut = torch.jit._fork(foo, x) 13766 y_hat = foo(x) 13767 y = torch.jit._wait(fut) 13768 return y, y_hat 13769 13770 FileCheck().check("with prim::fork").check("prim::If").check("return")\ 13771 .run(wait_script.graph) 13772 13773 def test_early_return_type_refinement(self): 13774 @torch.jit.script 13775 def test(x): 13776 # type: (Optional[int]) -> int 13777 if x is None: 13778 return 1 13779 else: 13780 return x 13781 self.assertEqual(test(None), 1) 13782 self.assertEqual(test(2), 2) 13783 13784 def test_exceptions_with_control_flow(self): 13785 def test_num_ifs(func, num_ifs): 13786 g = torch.jit.script(func).graph 13787 FileCheck().check_count("prim::If", num_ifs, exactly=True).run(g) 13788 13789 def no_guard_ifs_added(x): 13790 # type: (int) -> int 13791 if x == 1: 13792 return 1 13793 else: 13794 if x == 2: 13795 raise RuntimeError("hi") 13796 else: 13797 raise RuntimeError("hi") 13798 13799 self.checkScript(no_guard_ifs_added, (1,)) 13800 self.checkScriptRaisesRegex(no_guard_ifs_added, (2,), Exception, "") 13801 test_num_ifs(no_guard_ifs_added, 2) 13802 13803 # FUNCTION LOOKS LIKE: 13804 # graph(%x.1 : int): 13805 # %7 : str = prim::Constant[value="Exception"]() 13806 # %2 : int = prim::Constant[value=1]() 13807 # %5 : int = prim::Constant[value=2]() 13808 # %19 : int = prim::Uninitialized() 13809 # %3 : bool = aten::eq(%x.1, %2) 13810 # %20 : int = prim::If(%3) 13811 # block0(): 13812 # -> (%2) 13813 # block1(): 13814 # %6 : bool = aten::eq(%x.1, %5) 13815 # = prim::If(%6) 13816 # 
block0(): 13817 # = prim::RaiseException(%7) 13818 # -> () 13819 # block1(): 13820 # = prim::RaiseException(%7) 13821 # -> () 13822 # -> (%19) 13823 # return (%20) 13824 13825 def no_ifs_added(x): 13826 # type: (int) -> int 13827 if x < 0: 13828 raise RuntimeError("hi") 13829 return x 13830 13831 self.checkScript(no_ifs_added, (1,)) 13832 self.checkScriptRaisesRegex(no_ifs_added, (-2,), Exception, "") 13833 test_num_ifs(no_ifs_added, 1) 13834 13835 def test_if_might(x): 13836 # type: (int) 13837 if x > 0: 13838 if x == 1: 13839 return 1 13840 else: 13841 a = 2 13842 else: 13843 raise RuntimeError("hi") 13844 return a + 2 13845 13846 self.checkScript(test_if_might, (1,)) 13847 self.checkScript(test_if_might, (3,)) 13848 self.checkScriptRaisesRegex(no_ifs_added, (-2,), Exception, "") 13849 test_num_ifs(test_if_might, 3) # one if added to guard a + 2 13850 13851 def test_loop_no_escape(x): 13852 # type: (int) 13853 if x >= 0: 13854 for i in range(x): 13855 raise RuntimeError("hi") 13856 else: 13857 return 5 13858 return x + 3 13859 13860 self.checkScript(test_loop_no_escape, (0,)) 13861 self.checkScript(test_loop_no_escape, (-1,)) 13862 self.checkScriptRaisesRegex(test_loop_no_escape, (1,), Exception, "") 13863 13864 # if guard gets optimized away 13865 test_num_ifs(test_loop_no_escape, 1) 13866 13867 def test_loop_exception_with_continue(x): 13868 # type: (int) 13869 i = 0 13870 for i in range(5): 13871 if i == x: 13872 raise RuntimeError("hi") 13873 else: 13874 continue 13875 print(i) 13876 return i + 5 13877 13878 self.checkScript(test_loop_exception_with_continue, (-1,)) 13879 self.checkScriptRaisesRegex(test_loop_exception_with_continue, (1,), Exception, "") 13880 test_num_ifs(test_loop_exception_with_continue, 1) # no ifs added to guard print 13881 13882 13883 def test_exception_exits_closure(self): 13884 code = dedent(''' 13885 def no_return_func(self): 13886 # type: (Tensor) -> Tensor 13887 output = torch.tanh(self) 13888 def backward(grad_output): 13889 raise 
RuntimeError("Hi") 13890 ''') 13891 with self.assertRaisesRegex(RuntimeError, "does not return along all"): 13892 cu = torch.jit.CompilationUnit(code) 13893 13894 code = dedent(''' 13895 def test_exit_pair_reset(x): 13896 # type: (int) -> int 13897 if x > 0: 13898 a = 0 13899 def backward(grad_output): 13900 raise RuntimeError("Hi") 13901 a = a + 1 13902 else: 13903 return x 13904 return a + 1 13905 ''') 13906 func = torch.jit.CompilationUnit(code).test_exit_pair_reset 13907 self.assertEqual(func(1,), 2) 13908 self.assertEqual(func(-1,), -1) 13909 # final a + 1 gets inlined into the first branch and optimized away 13910 FileCheck().check_count("prim::If", 1, exactly=True).run(func.graph) 13911 13912 def test_non_final_return(self): 13913 def simple(x): 13914 if bool(x > 3): 13915 return x + 1 13916 else: 13917 return x + 2 13918 raise RuntimeError("nope") 13919 13920 def nest(x): 13921 x = x + 1 13922 if bool(x > 3): 13923 if bool(x > 4): 13924 x += 1 13925 return x + 1 13926 else: 13927 return x + 2 13928 13929 def early_ret(x): 13930 x = x + 1 13931 if bool(x > 3): 13932 return x + 1 13933 x = x + 1 13934 return x + 2 13935 13936 def nest_early_ret(x): 13937 x = x + 1 13938 if bool(x > 3): 13939 if bool(x > 4): 13940 return x + 2 13941 return x + 1 13942 x = x + 1 13943 return x + 2 13944 13945 def not_early_ret(x): 13946 s = "" 13947 if bool(x > 3): 13948 if bool(x > 4): 13949 return 1, s 13950 s += "foo" 13951 else: 13952 s += "5" 13953 s += "hi" 13954 return 7, s 13955 13956 def not_total_ret(x): 13957 s = "" 13958 if bool(x > 3): 13959 if bool(x > 4): 13960 return 1, s 13961 else: 13962 return 2, s 13963 else: 13964 s += "5" 13965 return 7, s 13966 13967 for i in range(3): 13968 for func in [simple, nest, early_ret, nest_early_ret, not_early_ret, 13969 not_total_ret]: 13970 self.checkScript(func, (torch.tensor(2.5 + i),)) 13971 13972 def vars_used_after_ret(x): 13973 # type: (int) -> int 13974 if x == 0: 13975 return x 13976 else: 13977 y = 2 13978 z = 3 
13979 return x + y * z 13980 13981 self.checkScript(vars_used_after_ret, (1,)) 13982 self.checkScript(vars_used_after_ret, (0,)) 13983 13984 def complicated(x): 13985 # type: (int) -> int 13986 if x: 13987 if x == 2: 13988 return 1 13989 assert 1 == 2 13990 else: 13991 if x == 3: 13992 return 2 13993 assert 1 == 2 13994 else: 13995 a = 2 13996 b = 3 13997 else: 13998 a = 4 13999 b = 1 14000 return a + b 14001 assert 1 == 2 14002 14003 for i in range(4): 14004 self.checkScript(complicated, (i,)) 14005 14006 def test_partial_returns(self): 14007 with self.assertRaisesRegex(RuntimeError, "does not return along all"): 14008 @torch.jit.script 14009 def no_ret(): 14010 # type: () -> int 14011 pass 14012 14013 with self.assertRaisesRegex(RuntimeError, "does not return along all"): 14014 @torch.jit.script 14015 def partial(x): 14016 # type: (Tensor) -> int 14017 if x: 14018 return 1 14019 14020 with self.assertRaisesRegex(RuntimeError, "does not return along all"): 14021 @torch.jit.script 14022 def typed_none(): 14023 # type: () -> Optional[int] 14024 pass 14025 14026 @torch.jit.script 14027 def none_ret(): 14028 pass 14029 14030 self.assertIs(none_ret(), None) 14031 FileCheck().check(": None").run(none_ret.graph) 14032 14033 def test_early_returns_loops(self): 14034 def nest_while_ret(x): 14035 # type: (int) -> int 14036 y = 4 14037 while x < 4: 14038 if x < 3: 14039 return y 14040 else: 14041 y = y + 1 14042 break 14043 y = y + 2 14044 y = y + 1 14045 return y 14046 14047 self.checkScript(nest_while_ret, (2,)) 14048 self.checkScript(nest_while_ret, (3,)) 14049 self.checkScript(nest_while_ret, (4,)) 14050 14051 def loop_ret(x, y): 14052 # type: (int, int) -> (int) 14053 i = 0 14054 for i in range(x): 14055 if x == y: 14056 return x + y 14057 i = i + y 14058 i = i - 1 14059 return i 14060 14061 self.checkScript(loop_ret, (3, 3)) 14062 self.checkScript(loop_ret, (2, 3)) 14063 self.checkScript(loop_ret, (3, 1)) 14064 14065 def test_will_ret(y): 14066 # type: (int) -> int 
14067 for i in range(y): 14068 return 2 14069 return 1 14070 14071 self.checkScript(test_will_ret, (0,)) 14072 self.checkScript(test_will_ret, (1,)) 14073 14074 def test_loop_nest_ret(y): 14075 # type: (int) -> int 14076 for i in range(y): 14077 for i in range(y - 2): 14078 return 10 14079 return 5 14080 return 0 14081 14082 self.checkScript(test_loop_nest_ret, (0,)) 14083 self.checkScript(test_loop_nest_ret, (1,)) 14084 self.checkScript(test_loop_nest_ret, (2,)) 14085 14086 def test_nn_init(self): 14087 tests = ( 14088 ('constant_', (lambda: (torch.ones(2, 2), 2.5)), "Tensor, float"), 14089 ('ones_', (lambda: (torch.ones(2, 2),)), "Tensor"), 14090 ('zeros_', (lambda: (torch.ones(2, 2),)), "Tensor"), 14091 ('uniform_', (lambda: (torch.ones(2, 2),)), "Tensor"), 14092 ('normal_', (lambda: (torch.ones(2, 2),)), "Tensor"), 14093 ('xavier_normal_', (lambda: (torch.ones(2, 2),)), "Tensor"), 14094 ('xavier_uniform_', (lambda: (torch.ones(2, 2),)), "Tensor"), 14095 ) 14096 14097 for name, args_fn, type_str in tests: 14098 # Build test code 14099 arg_str = ', '.join([chr(i + ord('a')) for i in range(len(args_fn()))]) 14100 14101 code = dedent(''' 14102 def test({arg_str}): 14103 # type: ({type_str}) 14104 return torch.nn.init.{name}({arg_str}) 14105 ''').format(arg_str=arg_str, type_str=type_str, name=name) 14106 cu = torch.jit.CompilationUnit(code) 14107 14108 # Compare functions 14109 init_fn = getattr(torch.nn.init, name) 14110 script_out = self.runAndSaveRNG(cu.test, args_fn()) 14111 eager_out = self.runAndSaveRNG(init_fn, args_fn()) 14112 self.assertEqual(script_out, eager_out) 14113 14114 FileCheck().check_not("prim::PythonOp").run(cu.test.graph) 14115 14116 def test_nn_init_generator(self): 14117 init_fns = ( 14118 'uniform_', 'normal_', 'xavier_normal_', 'xavier_uniform_', 14119 ) 14120 14121 for name in init_fns: 14122 # Build test code 14123 code = dedent(''' 14124 def test(tensor, generator): 14125 # type: (Tensor, Generator) 14126 return 
torch.nn.init.{name}(tensor, generator=generator) 14127 ''').format(name=name) 14128 cu = torch.jit.CompilationUnit(code) 14129 14130 # Compare functions 14131 init_fn = getattr(torch.nn.init, name) 14132 14133 torch.manual_seed(1) 14134 14135 g = torch.Generator() 14136 g.manual_seed(2023) 14137 script_out = cu.test(torch.ones(2, 2), g) 14138 14139 # Change the seed of the default generator to make 14140 # sure that we're using the provided generator 14141 torch.manual_seed(2) 14142 14143 g = torch.Generator() 14144 g.manual_seed(2023) 14145 eager_out = init_fn(torch.ones(2, 2), generator=g) 14146 14147 self.assertEqual(script_out, eager_out) 14148 14149 FileCheck().check_not("prim::PythonOp").run(cu.test.graph) 14150 14151 def test_early_return_rewrite(self): 14152 def test_foo(x: bool): 14153 if x: 14154 return 1 14155 return 2 14156 14157 self.checkScript(test_foo, (True,)) 14158 self.checkScript(test_foo, (False,)) 14159 FileCheck().check_count("prim::If", 1, exactly=True).run(torch.jit.script(test_foo).graph) 14160 14161 def test_multiple(x: int): 14162 if x == 5: 14163 return x * x 14164 else: 14165 y = 2 * x 14166 14167 z = y * 2 14168 if z == 8: 14169 return 1 14170 14171 if z != 16: 14172 z = z - 2 14173 abc = 4 14174 else: 14175 return 3 14176 14177 z = z * abc 14178 return z * z * z 14179 14180 self.checkScript(test_multiple, (5,)) 14181 self.checkScript(test_multiple, (2,)) 14182 self.checkScript(test_multiple, (4,)) 14183 self.checkScript(test_multiple, (3,)) 14184 self.checkScript(test_multiple, (10,)) 14185 14186 graph = torch.jit.script(test_multiple).graph 14187 FileCheck().check_count("prim::If", 3, exactly=True).run(graph) 14188 14189 def test_is_scripting_metacompile(self): 14190 @torch.jit.script 14191 def foo(): 14192 if torch.jit.is_scripting(): 14193 return 1 14194 else: 14195 print("hello") + 2 # will not be compiled 14196 14197 self.assertEqual(foo(), 1) 14198 14199 def test_boolean_literal_constant_metacompile(self): 14200 class 
Mod(torch.nn.Module): 14201 __constants__ = ['val'] 14202 14203 def __init__(self, val): 14204 super().__init__() 14205 self.val = val 14206 14207 def forward(self): 14208 if self.val: 14209 return 1 14210 else: 14211 return "2" 14212 14213 self.checkModule(Mod(True), ()) 14214 self.checkModule(Mod(False), ()) 14215 14216 @torch.jit.script 14217 def foo(): 14218 if True: 14219 return 1 14220 else: 14221 return "2" 14222 14223 self.assertEqual(foo(), 1) 14224 14225 def test_assert_is_scripting_metacompile(self): 14226 def foo(): 14227 assert not torch.jit.is_scripting(), "TestErrorMsg" 14228 print("hello") + 2 # will not be compiled 14229 14230 f = torch.jit.script(foo) 14231 with self.assertRaisesRegex(torch.jit.Error, "TestErrorMsg"): 14232 f() 14233 14234 def test_isinstance_metacompile(self): 14235 @torch.jit.script 14236 def test_primitive_type(x): 14237 # type: (int) -> int 14238 if isinstance(x, int): 14239 return x + 1 14240 else: 14241 return x - 1 14242 14243 self.assertEqual(test_primitive_type(1), 2) 14244 with self.assertRaisesRegex(Exception, "Expected a value of type"): 14245 test_primitive_type(1.5) 14246 14247 _MyNamedTuple = namedtuple('_MyNamedTuple', ['value']) 14248 14249 @torch.jit.script 14250 def test_non_primitive_types(x): 14251 # type: (_MyNamedTuple) -> Tensor 14252 if isinstance(1, _MyNamedTuple): 14253 return 10 14254 14255 if isinstance(x, _MyNamedTuple): 14256 return x.value + 1 14257 else: 14258 return 1 14259 14260 out = test_non_primitive_types(_MyNamedTuple(value=torch.tensor(5.0))) 14261 self.assertEqual(out, torch.tensor(6.0)) 14262 14263 def test_namedtuple_type_inference(self): 14264 _AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('value', int)]) # noqa: UP014 14265 _UnannotatedNamedTuple = namedtuple('_NamedTupleUnAnnotated', ['value']) 14266 14267 def test_check_named_tuple_value(): 14268 named_tuple = _AnnotatedNamedTuple(1) 14269 return named_tuple.value 14270 14271 
self.checkScript(test_check_named_tuple_value, ()) 14272 14273 def test_error(): 14274 return _UnannotatedNamedTuple(1) 14275 14276 with self.assertRaisesRegex(RuntimeError, r"Expected a value of type \'Tensor \(inferred\)\' " 14277 r"for argument \'value\' but instead found type \'int\'."): 14278 torch.jit.script(test_error) 14279 14280 def test_namedtuple_default_values_simple_type(self): 14281 14282 class Point(NamedTuple): 14283 x: Optional[int] = None 14284 y: int = 2 14285 14286 make_global(Point) 14287 14288 class M(torch.nn.Module): 14289 def forward(self, point: Point): 14290 return point 14291 14292 p = Point(x=3, y=2) 14293 14294 self.checkModule(M(), (p,)) 14295 self.checkModule(M(), (Point(),)) 14296 14297 m = torch.jit.script(M()) 14298 14299 FileCheck().check(r"NamedTuple(x : int? = None, y : int = 2))") \ 14300 .run(m.graph) 14301 14302 def test_namedtuple_default_values_missing(self): 14303 14304 class Point(NamedTuple): 14305 x: Optional[int] 14306 y: int 14307 z: int = 3 14308 14309 make_global(Point) 14310 14311 class M(torch.nn.Module): 14312 def forward(self, point: Point): 14313 return point 14314 14315 p1 = Point(x=3, y=2) 14316 p2 = Point(x=3, y=2, z=1) 14317 14318 self.checkModule(M(), (p1,)) 14319 self.checkModule(M(), (p2,)) 14320 14321 m = torch.jit.script(M()) 14322 14323 FileCheck().check(r"NamedTuple(x : int?, y : int, z : int = 3))") \ 14324 .run(m.graph) 14325 14326 def test_namedtuple_default_values_container_type(self): 14327 14328 class Point(NamedTuple): 14329 x: Optional[List[int]] = None 14330 y: List[int] = [1, 2, 3] 14331 z: Optional[Dict[str, int]] = {"a": 1} 14332 14333 make_global(Point) 14334 14335 class M(torch.nn.Module): 14336 def forward(self, point: Point): 14337 return point 14338 14339 p = Point(x=[4, 5, 6], y=[3, 2, 1], z={"b": 2}) 14340 14341 self.checkModule(M(), (p,)) 14342 self.checkModule(M(), (Point(),)) 14343 14344 m = torch.jit.script(M()) 14345 14346 first_line = r"NamedTuple(x : int[]? 
= None, y : int[] = " \ 14347 r"[1, 2, 3], z : Dict(str, int)? = {a: 1}))" 14348 14349 FileCheck().check(first_line) \ 14350 .run(m.graph) 14351 14352 def test_namedtuple_default_values_Tensor_type(self): 14353 14354 class Point(NamedTuple): 14355 x: torch.Tensor = torch.rand(2, 3) 14356 14357 make_global(Point) 14358 14359 class M(torch.nn.Module): 14360 def forward(self, point: Point): 14361 return point 14362 14363 p = Point(x=torch.rand(2, 3)) 14364 14365 with self.assertRaisesRegex(RuntimeError, "Tensors are not " 14366 "supported as default NamedTuple " 14367 "fields"): 14368 m = torch.jit.script(M()) 14369 m(p) 14370 14371 def test_namedtuple_default_values_using_factory_constructor(self): 14372 Pair = namedtuple("Pair", ["x", "y"], defaults=(1, 2)) 14373 14374 make_global(Pair) 14375 14376 @torch.jit.script 14377 def fn(x: Pair) -> Pair: 14378 return x 14379 14380 # TODO: We can't use `checkScript` with the NamedTuple factory 14381 # constructor. Using the factory constructor with TorchScript 14382 # TorchScript creates an anonymous `NamedTuple` class instead of 14383 # preserving the actual name. For example, the actual generated 14384 # signature in this case is: 14385 # graph(%x.1 : NamedTuple(x : Tensor, y : Tensor)) 14386 # It looks like similar test cases have had this issue as well 14387 # (see: `test_namedtuple_python`). 
14388 FileCheck().check(r"NamedTuple(x : Tensor = 1, y : Tensor = 2))") \ 14389 .check_next(r"return (%x.1)") \ 14390 .run(fn.graph) 14391 14392 def test_isinstance_dynamic(self): 14393 @torch.jit.script 14394 def foo(a): 14395 # type: (Optional[List[int]]) -> int 14396 b = 0 14397 if isinstance(a, (int, (float,), list, str)): 14398 b += 1 14399 if isinstance(a, (int, str)): 14400 b += 1 14401 if isinstance(a, List[int]): 14402 b += 1 14403 return b 14404 self.assertEqual(foo([3, 4]), 2) 14405 self.assertEqual(foo(None), 0) 14406 14407 def test_function_overloads(self): 14408 # TODO: pyflakes currently does not compose @overload annotation with other 14409 # decorators. This is fixed on master but not on version 2.1.1. 14410 # Next version update remove noqa and add @typing.overload annotation 14411 14412 @torch.jit._overload # noqa: F811 14413 def test_simple(x1): # noqa: F811 14414 # type: (int) -> int 14415 pass 14416 14417 @torch.jit._overload # noqa: F811 14418 def test_simple(x1): # noqa: F811 14419 # type: (float) -> float 14420 pass 14421 14422 def test_simple(x1): # noqa: F811 14423 return x1 14424 14425 def invoke_function(): 14426 return test_simple(1.0), test_simple(.5) 14427 14428 self.checkScript(invoke_function, ()) 14429 14430 # testing that the functions are cached 14431 compiled_fns_1 = torch.jit._script._get_overloads(test_simple) 14432 compiled_fns_2 = torch.jit._script._get_overloads(test_simple) 14433 for a, b in zip(compiled_fns_1, compiled_fns_2): 14434 self.assertIs(a.graph, b.graph) 14435 14436 old_func = test_simple 14437 14438 # testing that new functions added work with caching 14439 @torch.jit._overload # noqa: F811 14440 def test_simple(x1): # noqa: F811 14441 # type: (str) -> str 14442 pass 14443 14444 @torch.jit.script 14445 def my_func(): 14446 return old_func("hi") 14447 14448 # testing new function same qualified name 14449 @torch.jit._overload # noqa: F811 14450 def test_simple(a, b): # noqa: F811 14451 # type: (int, int) -> int 
14452 pass 14453 14454 def test_simple(a, b): 14455 return a + b 14456 14457 @torch.jit.script 14458 def fn(): 14459 return test_simple(3, 4) 14460 14461 self.assertEqual(fn(), 7) 14462 14463 # currently we take the default values have to be specified in the 14464 # overload as well - TODO take them from implementation and apply 14465 # where the type is valid. 14466 @torch.jit._overload # noqa: F811 14467 def identity(x1): # noqa: F811 14468 # type: (str) -> str 14469 pass 14470 14471 @torch.jit._overload # noqa: F811 14472 def identity(x1): # noqa: F811 14473 # type: (float) -> float 14474 pass 14475 14476 def identity(x1=1.0): # noqa: F811 14477 return x1 14478 14479 def invoke(): 14480 return identity(), identity(.5), identity("hi") 14481 14482 self.checkScript(invoke, ()) 14483 14484 def schema_match_failure(): 14485 return identity((1, 2)) 14486 14487 thrown = False 14488 try: 14489 torch.jit.script(schema_match_failure) 14490 except Exception as e: 14491 thrown = True 14492 self.assertTrue(r"of type 'str'" in str(e) and r"of type 'float" in str(e)) 14493 self.assertTrue(thrown) 14494 14495 with self.assertRaisesRegex(Exception, "cannot be directly compiled"): 14496 torch.jit.script(identity) 14497 14498 @torch.jit._overload # noqa: F811 14499 def impl_compile_failure(x, y): # noqa: F811 14500 # type: (str, str) -> (str) 14501 pass 14502 14503 @torch.jit._overload # noqa: F811 14504 def impl_compile_failure(x, y): # noqa: F811 14505 # type: (int, int) -> (int) 14506 pass 14507 14508 def impl_compile_failure(x, y): # noqa: F811 14509 return x - y 14510 14511 def test(): 14512 impl_compile_failure("one", "two") 14513 14514 14515 with self.assertRaisesRegex(Exception, "Arguments for call are not valid"): 14516 torch.jit.script(test) 14517 14518 @torch.jit._overload # noqa: F811 14519 def good_overload(x=1): # noqa: F811 14520 # type: (int) -> (int) 14521 pass 14522 14523 def good_overload(x=1): # noqa: F811 14524 return x 14525 14526 @torch.jit.script 14527 def 
foo(): 14528 return good_overload() 14529 14530 self.assertEqual(foo(), 1) 14531 14532 14533 with self.assertRaisesRegex(Exception, "must equal to the default parameter"): 14534 @torch.jit._overload # noqa: F811 14535 def bad_default_on_overload(x, y=2): # noqa: F811 14536 # type: (int, int) -> (int) 14537 pass 14538 14539 def bad_default_on_overload(x, y=1): # noqa: F811 14540 # type: (int, int) -> (int) 14541 pass 14542 14543 @torch.jit.script 14544 def test(): 14545 return bad_default_on_overload(1, 2) 14546 14547 @torch.jit._overload # noqa: F811 14548 def diff_default(x): # noqa: F811 14549 # type: (int) -> int 14550 pass 14551 14552 @torch.jit._overload # noqa: F811 14553 def diff_default(x): # noqa: F811 14554 # type: (str) -> str 14555 pass 14556 14557 def diff_default(x="hi"): # noqa: F811 14558 return x 14559 14560 def test(): 14561 return diff_default(), diff_default(2), diff_default("abc") 14562 14563 self.assertEqual(test(), torch.jit.script(test)()) 14564 14565 @torch.jit._overload # noqa: F811 14566 def diff_num_params(x): # noqa: F811 14567 # type: (float) -> float 14568 pass 14569 14570 @torch.jit._overload # noqa: F811 14571 def diff_num_params(x, y): # noqa: F811 14572 # type: (int, int) -> int 14573 pass 14574 14575 def diff_num_params(x, y=2, z=3): # noqa: F811 14576 # type: (Union[float, int], int, int) 14577 return x + y + z 14578 14579 def test(): 14580 return diff_num_params(1.0), diff_num_params(1, 2), diff_num_params(1), diff_num_params(1, 2, 3) 14581 14582 self.assertEqual(test(), torch.jit.script(test)()) 14583 14584 @torch.jit._overload # noqa: F811 14585 def diff_num_params_no_annot(): 14586 # type: () -> int 14587 pass 14588 14589 def diff_num_params_no_annot(x=1): # noqa: F811 14590 return x 14591 14592 def test(): 14593 return diff_num_params_no_annot(1.0) 14594 14595 with self.assertRaisesRegex(Exception, "Parameters not specified"): 14596 torch.jit.script(test) 14597 14598 def test_function_overload_misuse(self): 14599 with 
self.assertRaisesRegex(RuntimeError, "Only `pass` statement or `...` can be the body"): 14600 @torch.jit._overload 14601 def wrong_decl_body(x: str) -> str: 14602 return x + "0" 14603 14604 with self.assertRaisesRegex(RuntimeError, "Only `pass` statement or `...` can be the body"): 14605 class MyClass: 14606 @torch.jit._overload_method 14607 def method(self): 14608 return 0 14609 14610 @torch.jit._overload 14611 def null_overload(x: int) -> int: ... # noqa: E704 14612 14613 @torch.jit._overload # noqa: F811 14614 def null_overload(x: str) -> str: # noqa: F811 14615 pass 14616 14617 def null_overload_driver(): 14618 return null_overload(0) 14619 14620 with self.assertRaisesRegex(RuntimeError, 'Implementation for the function ".+" is missing.'): 14621 torch.jit.script(null_overload_driver) 14622 14623 class OverloadMisuse(torch.nn.Module): 14624 @torch.jit._overload_method 14625 def forward(self, x: int): 14626 pass 14627 14628 @torch.jit._overload_method # noqa: F811 14629 def forward(self, x: Tensor): # noqa: F811 14630 pass 14631 14632 with self.assertRaisesRegex(RuntimeError, 'Implementation for the method ".+" is missing.'): 14633 m = torch.jit.script(OverloadMisuse()) 14634 14635 14636 def test_script_method_torch_function_overload(self): 14637 class MyCustomTensor(torch.Tensor): 14638 pass 14639 14640 class MyCustomModule(torch.nn.Module): 14641 def forward(self, x): 14642 return torch.relu(x) 14643 14644 scripted_mod = torch.jit.script(MyCustomModule()) 14645 t = torch.tensor([3.0]) 14646 ref_out = scripted_mod(t) 14647 14648 t_custom = MyCustomTensor([3.0]) 14649 out1 = scripted_mod(t_custom) 14650 self.assertEqual(out1, ref_out) 14651 14652 out2 = scripted_mod.forward(t_custom) 14653 self.assertEqual(out2, ref_out) 14654 14655 def test_function_overloading_isinstance(self): 14656 @torch.jit._overload # noqa: F811 14657 def my_conv(x, y): # noqa: F811 14658 # type: (float, str) -> (float) 14659 pass 14660 14661 @torch.jit._overload # noqa: F811 14662 def 
my_conv(x, y): # noqa: F811 14663 # type: (float, float) -> (float) 14664 pass 14665 14666 def my_conv(x, y=2.0): # noqa: F811 14667 if isinstance(y, str): 14668 if y == "hi": 14669 return 4.0 - x 14670 else: 14671 return 5.0 - x 14672 else: 14673 return 2.0 + x 14674 14675 def test_uses(): 14676 return my_conv(1.5), my_conv(1.5, "hi"), my_conv(1.5, 5.0) 14677 14678 self.checkScript(test_uses, ()) 14679 14680 def test_method_overloading(self): 14681 class Over(torch.nn.Module): 14682 @torch.jit._overload_method # noqa: F811 14683 def forward(self, x): # noqa: F811 14684 # type: (Tuple[Tensor, Tensor]) -> Tensor 14685 pass 14686 14687 @torch.jit._overload_method # noqa: F811 14688 def forward(self, x): # noqa: F811 14689 # type: (Tensor) -> Tensor 14690 pass 14691 14692 def forward(self, x): # noqa: F811 14693 if isinstance(x, Tensor): 14694 return x + 20 14695 else: 14696 return x[0] + 5 14697 14698 class S(torch.jit.ScriptModule): 14699 def __init__(self) -> None: 14700 super().__init__() 14701 self.weak = Over() 14702 14703 @torch.jit.script_method 14704 def forward(self, x): 14705 return self.weak(x) + self.weak((x, x)) 14706 14707 s_mod = S() 14708 x = torch.ones(1) 14709 self.assertEqual(s_mod(x), x + 20 + 5 + x) 14710 14711 over = Over() 14712 self.assertEqual(over((x, x)), x + 5) 14713 self.assertEqual(over(x), x + 20) 14714 14715 class Unannotated(torch.nn.Module): 14716 @torch.jit._overload_method # noqa: F811 14717 def hello(self, x): # noqa: F811 14718 pass 14719 14720 @torch.jit._overload_method # noqa: F811 14721 def hello(self, x): # noqa: F811 14722 # type: (int) -> (int) 14723 pass 14724 14725 def hello(self, x): # noqa: F811 14726 return x + 3 14727 14728 def forward(self): 14729 return self.hello(1), self.hello(.5) 14730 14731 w = Unannotated() 14732 with self.assertRaisesRegex(Exception, "explicitly add type annotations to overloaded functions"): 14733 torch.jit.script(w) 14734 14735 class CompileOverloadError(torch.nn.Module): 14736 
@torch.jit._overload_method # noqa: F811 14737 def hello(self, x): # noqa: F811 14738 # type: (str) -> (int) 14739 pass 14740 14741 @torch.jit._overload_method # noqa: F811 14742 def hello(self, x): # noqa: F811 14743 # type: (int) -> (int) 14744 pass 14745 14746 def hello(self, x): # noqa: F811 14747 return x + 1 14748 14749 def forward(self): 14750 return self.hello("hi"), self.hello(.5) 14751 14752 w = CompileOverloadError() 14753 with self.assertRaisesRegex(Exception, "but instead found type 'str'"): 14754 torch.jit.script(w) 14755 14756 # testing overload declared first, then non-overload 14757 with self.assertRaisesRegex(Exception, "Overloads are not useable when a module"): 14758 class W3(torch.nn.Module): 14759 @torch.jit._overload_method # noqa: F811 14760 def forward(self, x): # noqa: F811 14761 # type: (int) -> int 14762 pass 14763 14764 @torch.jit._overload_method # noqa: F811 14765 def forward(self, x): # noqa: F811 14766 # type: (Tensor) -> Tensor 14767 pass 14768 14769 def forward(self, x): # noqa: F811 14770 return x + 5 14771 14772 a = W3() 14773 b = torch.jit.script(a) 14774 14775 class W3(torch.nn.Module): 14776 def forward(self, x): # noqa: F811 14777 return x + 5 + 10 14778 14779 a = W3() 14780 b = torch.jit.script(a) 14781 14782 # testing non-overload declared first, then overload 14783 class W2(torch.nn.Module): 14784 def hello(self, x1, x2): 14785 return x1 + x2 14786 14787 def forward(self, x): 14788 return self.hello(x, x) 14789 14790 a = torch.jit.script(W2()) 14791 self.assertEqual(a(torch.tensor(1)), torch.tensor(2)) 14792 14793 class W2(torch.nn.Module): 14794 @torch.jit._overload_method # noqa: F811 14795 def hello(self, x): # noqa: F811 14796 pass 14797 14798 @torch.jit._overload_method # noqa: F811 14799 def hello(self, x): # noqa: F811 14800 # type: (int) -> (int) 14801 pass 14802 14803 def hello(self, x): # noqa: F811 14804 return x + 5 + 10 14805 14806 def forward(self, x): 14807 return self.hello(1), self.hello(x) 14808 14809 
with self.assertRaisesRegex(Exception, "Overloads are not useable when a module"): 14810 a = torch.jit.script(W2()) 14811 14812 def test_narrow_copy(self): 14813 def foo(a): 14814 return a.narrow_copy(0, 0, 5) 14815 14816 self.checkScript(foo, [torch.rand(10)]) 14817 14818 def test_select_after_chunk(self): 14819 def foo(x): 14820 chunked = torch.chunk(x, 1) 14821 foo = chunked[0] 14822 foo.add_(5) 14823 return x 14824 14825 self.checkScript(foo, [torch.rand(2, 3)]) 14826 14827 def test_nn_LSTM_with_layers(self): 14828 class M(torch.jit.ScriptModule): 14829 def __init__(self) -> None: 14830 super().__init__() 14831 self.rnn = nn.LSTM(2, 3, 2, dropout=0) 14832 14833 @torch.jit.script_method 14834 def forward(self, x, lengths, h0, c0): 14835 return self.rnn(x, (h0, c0))[0] 14836 14837 class Eager(torch.nn.Module): 14838 def __init__(self) -> None: 14839 super().__init__() 14840 self.rnn = nn.LSTM(2, 3, 2, dropout=0) 14841 14842 def forward(self, x, lengths, h0, c0): 14843 return self.rnn(x, (h0, c0))[0] 14844 14845 inputs = (torch.randn(1, 1, 2), torch.LongTensor([7]), torch.randn(2, 1, 3), torch.randn(2, 1, 3)) 14846 eager_out = self.runAndSaveRNG(lambda: Eager()(*inputs), ())[0] 14847 script_out = self.runAndSaveRNG(lambda: M()(*inputs), ())[0] 14848 14849 self.assertEqual(eager_out, script_out) 14850 14851 def test_nn_LSTM(self): 14852 input = torch.nn.utils.rnn.pack_sequence([torch.randn(5, 5)]) 14853 14854 class S(torch.jit.ScriptModule): 14855 def __init__(self) -> None: 14856 super().__init__() 14857 self.x = torch.nn.LSTM(5, 5) 14858 14859 @torch.jit.script_method 14860 def forward(self, input: PackedSequence) -> Tuple[PackedSequence, Tuple[torch.Tensor, torch.Tensor]]: 14861 return self.x(input) 14862 14863 eager_out = self.runAndSaveRNG(lambda x: torch.nn.LSTM(5, 5)(x), (input,))[0] 14864 script_out = self.runAndSaveRNG(lambda x: S()(x), (input,))[0] 14865 14866 self.assertEqual(eager_out, script_out) 14867 14868 def test_nn_GRU(self): 14869 seq_input = 
torch.nn.utils.rnn.pack_sequence([torch.randn(5, 5)]) 14870 tensor_input = torch.randn(5, 5, 5) 14871 14872 class SeqLengthGRU(torch.jit.ScriptModule): 14873 def __init__(self) -> None: 14874 super().__init__() 14875 self.x = torch.nn.GRU(5, 5) 14876 14877 @torch.jit.script_method 14878 def forward(self, input: PackedSequence) -> Tuple[PackedSequence, torch.Tensor]: 14879 return self.x(input) 14880 14881 class TensorGRU(torch.jit.ScriptModule): 14882 def __init__(self) -> None: 14883 super().__init__() 14884 self.x = torch.nn.GRU(5, 5) 14885 14886 @torch.jit.script_method 14887 def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 14888 return self.x(input) 14889 14890 seq_eager_out = self.runAndSaveRNG(lambda x: torch.nn.GRU(5, 5)(x), (seq_input,))[0] 14891 seq_script_out = self.runAndSaveRNG(lambda x: SeqLengthGRU()(x), (seq_input,))[0] 14892 tensor_eager_out = self.runAndSaveRNG(lambda x: torch.nn.GRU(5, 5)(x), (tensor_input,))[0] 14893 tensor_script_out = self.runAndSaveRNG(lambda x: TensorGRU()(x), (tensor_input,))[0] 14894 14895 self.assertEqual(seq_eager_out, seq_script_out) 14896 self.assertEqual(tensor_eager_out, tensor_script_out) 14897 14898 def test_torchscript_memoryformat(self): 14899 @torch.jit.script 14900 def fn(x): 14901 return x.contiguous(memory_format=torch.channels_last) 14902 x = torch.randn(4, 3, 6, 6) 14903 y = fn(x) 14904 self.assertTrue(y.is_contiguous(memory_format=torch.channels_last)) 14905 14906 def test_torchscript_multi_head_attn(self): 14907 @torch.jit.script 14908 def jit_multihead_attn_forward(query, # type: Tensor 14909 key, # type: Tensor 14910 value, # type: Tensor 14911 embed_dim_to_check, # type: int 14912 num_heads, # type: int 14913 in_proj_weight, # type: Tensor 14914 in_proj_bias, # type: Tensor 14915 bias_k, # type: Optional[Tensor] 14916 bias_v, # type: Optional[Tensor] 14917 add_zero_attn, # type: bool 14918 dropout, # type: float 14919 out_proj_weight, # type: Tensor 14920 out_proj_bias, # 
type: Tensor 14921 training=True, # type: bool 14922 key_padding_mask=None, # type: Optional[Tensor] 14923 need_weights=True, # type: bool 14924 attn_mask=None # type: Optional[Tensor] 14925 ): 14926 # type: (...) -> Tuple[Tensor, Optional[Tensor]] 14927 return torch.nn.functional.multi_head_attention_forward(query, key, value, 14928 embed_dim_to_check, num_heads, 14929 in_proj_weight, in_proj_bias, 14930 bias_k, bias_v, 14931 add_zero_attn, dropout, 14932 out_proj_weight, out_proj_bias, 14933 training, key_padding_mask, 14934 need_weights, attn_mask) 14935 14936 src_l = 3 14937 bsz = 5 14938 embed_size = 8 14939 nhead = 2 14940 multi_head_attn = torch.nn.MultiheadAttention(embed_size, nhead) 14941 query = torch.rand((src_l, bsz, embed_size)) 14942 key = torch.rand((src_l, bsz, embed_size)) 14943 value = torch.rand((src_l, bsz, embed_size)) 14944 14945 mask = (torch.triu(torch.ones(src_l, src_l)) == 1).transpose(0, 1) 14946 mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, 0.0).to(torch.get_default_dtype()) 14947 14948 jit_out = jit_multihead_attn_forward(query, key, value, 14949 embed_size, nhead, 14950 multi_head_attn.in_proj_weight, 14951 multi_head_attn.in_proj_bias, 14952 multi_head_attn.bias_k, multi_head_attn.bias_v, 14953 multi_head_attn.add_zero_attn, multi_head_attn.dropout, 14954 multi_head_attn.out_proj.weight, 14955 multi_head_attn.out_proj.bias, attn_mask=mask)[0] 14956 14957 py_out = torch.nn.functional.multi_head_attention_forward(query, key, value, 14958 embed_size, nhead, 14959 multi_head_attn.in_proj_weight, 14960 multi_head_attn.in_proj_bias, 14961 multi_head_attn.bias_k, 14962 multi_head_attn.bias_v, 14963 multi_head_attn.add_zero_attn, 14964 multi_head_attn.dropout, 14965 multi_head_attn.out_proj.weight, 14966 multi_head_attn.out_proj.bias, 14967 attn_mask=mask)[0] 14968 # print("rel. 
error: ") 14969 # print(jit_out / py_out - 1) 14970 self.assertEqual(jit_out, py_out, atol=5e-4, rtol=1e-4) 14971 14972 def test_torchscript_multi_head_attn_fast_path(self): 14973 src_l = 3 14974 bsz = 5 14975 embed_size = 8 14976 nhead = 2 14977 multi_head_attn = torch.nn.MultiheadAttention(embed_size, nhead, batch_first=True) 14978 multi_head_attn = multi_head_attn.eval() 14979 14980 query = key = value = torch.rand((bsz, src_l, embed_size)) 14981 14982 with torch.no_grad(): 14983 py_out = multi_head_attn(query, key, value) 14984 mha = torch.jit.script(multi_head_attn) 14985 jit_out = mha(query, key, value) 14986 torch.testing.assert_close(jit_out, py_out) 14987 14988 @unittest.skipIf(not RUN_CUDA, "no CUDA") 14989 def test_scriptmodule_multi_head_attn_cuda(self): 14990 14991 class MyModule(torch.jit.ScriptModule): 14992 def __init__(self, embed_dim, num_heads): 14993 super().__init__() 14994 sample_q = torch.randn(3, 2, embed_dim) 14995 sample_kv = torch.randn(3, 2, embed_dim) 14996 attention = nn.MultiheadAttention(embed_dim, num_heads) 14997 attention.eval() 14998 14999 self.mod = torch.jit.trace(attention, 15000 (sample_q, sample_kv, sample_kv)) 15001 15002 @torch.jit.script_method 15003 def forward(self, q, k, v): 15004 return self.mod(q, k, v) 15005 15006 embed_dim = 8 15007 num_heads = 2 15008 sl = 3 15009 bs = 2 15010 model = MyModule(embed_dim, num_heads).cuda() 15011 q = torch.randn(sl, bs, embed_dim, device="cuda") 15012 kv = torch.randn(sl, bs, embed_dim, device="cuda") 15013 15014 jit_out = model(q, kv, kv)[0] 15015 py_out = torch.nn.functional.multi_head_attention_forward(q, kv, kv, 15016 embed_dim, num_heads, 15017 model.mod.in_proj_weight, 15018 model.mod.in_proj_bias, 15019 None, None, None, 0.0, 15020 model.mod.out_proj.weight, 15021 model.mod.out_proj.bias)[0] 15022 self.assertEqual(jit_out, py_out, atol=5e-4, rtol=1e-4) 15023 15024 @unittest.skipIf(not RUN_CUDA, "no CUDA") 15025 def test_scriptmodule_transformer_cuda(self): 15026 15027 class 
MyModule(torch.jit.ScriptModule): 15028 def __init__(self, transformer, sample_q, sample_kv): 15029 super().__init__() 15030 transformer.eval() 15031 15032 self.mod = torch.jit.trace(transformer, 15033 (sample_q, sample_kv)) 15034 15035 @torch.jit.script_method 15036 def forward(self, q, k): 15037 return self.mod(q, k) 15038 15039 d_model = 8 15040 nhead = 2 15041 num_encoder_layers = 2 15042 num_decoder_layers = 2 15043 dim_feedforward = 16 15044 bsz = 2 15045 seq_length = 5 15046 tgt_length = 3 15047 15048 with torch.no_grad(): 15049 src = torch.randn(seq_length, bsz, d_model) 15050 tgt = torch.randn(tgt_length, bsz, d_model) 15051 transformer = nn.Transformer(d_model, nhead, num_encoder_layers, 15052 num_decoder_layers, dim_feedforward, dropout=0.0) 15053 model = MyModule(transformer, tgt, src) 15054 15055 src = torch.randn(seq_length, bsz, d_model) 15056 tgt = torch.randn(tgt_length, bsz, d_model) 15057 jit_out = model(tgt, src) 15058 py_out = transformer(tgt, src) 15059 15060 # print(jit_out/py_out-1) 15061 # print(torch.allclose(jit_out, py_out, atol=5e-4, rtol=1e-4)) 15062 self.assertEqual(jit_out, py_out, atol=5e-4, rtol=1e-4) 15063 15064 def test_list_python_op(self): 15065 def python_list_op(lst): 15066 # type: (List[Tensor]) -> Tensor 15067 return lst[0] 15068 15069 def fn(lst): 15070 # type: (List[Tensor]) -> Tensor 15071 return python_list_op(lst) 15072 15073 self.checkScript(fn, ([torch.ones(2) + 2, torch.ones(2)],)) 15074 15075 @unittest.skipIf(not RUN_CUDA, "no CUDA") 15076 def test_weak_cuda(self): 15077 class M(torch.jit.ScriptModule): 15078 def __init__(self) -> None: 15079 super().__init__() 15080 self.lstm = torch.nn.LSTM(5, 5) 15081 self.lstm.cuda() 15082 15083 @torch.jit.script_method 15084 def forward(self, x): 15085 return self.lstm(x) 15086 15087 m = M() 15088 m.cuda() 15089 out = m(torch.ones(5, 5, 5).cuda()) 15090 self.assertTrue(out[0].is_cuda) 15091 15092 def test_ignore_decorator(self): 15093 with warnings.catch_warnings(record=True) 
as warns: 15094 class M(torch.jit.ScriptModule): 15095 def __init__(self) -> None: 15096 super().__init__() 15097 tensor = torch.zeros(1, requires_grad=False) 15098 self.some_state = nn.Buffer(torch.nn.Parameter(tensor)) 15099 15100 @torch.jit.script_method 15101 def forward(self, x): 15102 self.ignored_code(x) 15103 return x 15104 15105 @torch.jit.ignore(drop_on_export=True) 15106 def ignored_code(self, x): 15107 self.some_state = torch.tensor((100,)) 15108 15109 FileCheck().check("TorchScript will now drop the function").run(str(warns[0])) 15110 15111 # Assert ignored code is run 15112 m = M() 15113 15114 m2 = self.getExportImportCopy(m) 15115 pp = str(m2.forward.code) 15116 self.assertNotIn('ignored_code', pp) 15117 15118 with self.assertRaisesRegex(torch.jit.Error, "annotated to be ignored and cannot be run"): 15119 m2.forward(torch.ones(1)) 15120 15121 def test_ignored_as_value(self): 15122 class Model(nn.Module): 15123 @torch.jit.unused 15124 def tuple_ignored(self, x): 15125 # type: (Tensor) -> Tuple[Tensor, Tensor] 15126 return x, x 15127 15128 @torch.jit.unused 15129 def single_val_ignored(self, x, y): 15130 # type: (Tensor, Tensor) -> Tensor 15131 return x 15132 15133 def forward(self, x, use_ignore_path): 15134 # type: (Tensor, bool) -> Tuple[Tensor, Tensor] 15135 if 1 == 2: 15136 return self.tuple_ignored(x) 15137 if use_ignore_path: 15138 return self.single_val_ignored(x, x), self.single_val_ignored(x, x) 15139 return x, x 15140 15141 original = Model() 15142 scripted = torch.jit.script(original) 15143 self.assertEqual(scripted(torch.tensor(.5), False), (torch.tensor(.5), torch.tensor(.5))) 15144 15145 buffer = io.BytesIO() 15146 torch.jit.save(scripted, buffer) 15147 buffer.seek(0) 15148 loaded = torch.jit.load(buffer) 15149 15150 with self.assertRaisesRegex(torch.jit.Error, "annotated to be ignored and cannot be run"): 15151 loaded(torch.tensor(.5), True) 15152 15153 def test_module_error(self): 15154 class MyModule(torch.nn.Module): 15155 def 
forward(self, foo): 15156 return foo 15157 15158 with self.assertRaisesRegex(RuntimeError, "cannot be compiled since it inherits from nn.Module"): 15159 torch.jit.script(MyModule) 15160 15161 def test_view_write(self): 15162 def fn(x, y): 15163 l = [] 15164 l.append(x) 15165 x_view = l[0] 15166 a = x + x 15167 x_view.add_(y) 15168 b = x + x 15169 return a == b 15170 self.checkScript(fn, (torch.rand(2, 3), torch.rand(2, 3))) 15171 15172 def test_module_attrs(self): 15173 class M(torch.jit.ScriptModule): 15174 def __init__(self, table): 15175 super().__init__() 15176 self.table = torch.jit.Attribute(table, Dict[str, torch.Tensor]) 15177 self.x = torch.nn.Parameter(torch.tensor([100.0])) 15178 15179 @torch.jit.script_method 15180 def forward(self, key): 15181 # type: (str) -> Tensor 15182 return self.table[key] + self.x 15183 15184 with torch._jit_internal._disable_emit_hooks(): 15185 # TODO: re-enable module hook when Python printing of attributes is 15186 # supported 15187 m = M({char : torch.ones(1) + ord(char) - ord("a") for char in "abcdefg"}) 15188 self.assertEqual(m("c"), torch.tensor([103.])) 15189 15190 def test_module_none_attrs(self): 15191 class MyMod(torch.jit.ScriptModule): 15192 def __init__(self) -> None: 15193 super().__init__() 15194 self.optional_value = None 15195 15196 @torch.jit.script_method 15197 def forward(self): 15198 return self.optional_value 15199 15200 graph = MyMod().forward.graph 15201 FileCheck().check("prim::GetAttr").run(graph) 15202 self.run_pass('peephole', graph) 15203 FileCheck().check_not("prim::GetAttr").run(graph) 15204 15205 def test_tensor_import_export(self): 15206 @torch.jit.script 15207 def foo(x): 15208 a = torch.tensor(1) 15209 b = torch.tensor([1, 2]) 15210 c = [a, b] 15211 return c 15212 15213 self.run_pass('constant_propagation', foo.graph) 15214 m = self.createFunctionFromGraph(foo.graph) 15215 self.getExportImportCopy(m) 15216 15217 def get_pickle_values(self): 15218 return (('dict', {"I": "am", "a test": "test"}, 
Dict[str, str]), 15219 ('float', 2.3, float), 15220 ('int', 99, int), 15221 ('bool', False, bool), 15222 ('tuple', (1, 2, 3, 4), Tuple[int, int, int, int]), 15223 ('list', [(1, 2), (3, 4)], List[Tuple[int, int]]), 15224 ('tensor', torch.randn(2, 2), torch.Tensor), 15225 ('int_list', [1, 2, 3, 4], List[int]), 15226 ('tensor_list', [torch.ones(2, 2) + i for i in range(4)], List[torch.Tensor]), 15227 ('bool_list', [True, True, False, True], List[bool]), 15228 ('float_list', [1., 2., 3., 4.], List[float]), 15229 ('str_list', ['hello', 'bye'], List[str]), 15230 ('none', None, Optional[int]), 15231 ('a_device', torch.device('cpu'), torch.device), 15232 ('another_device', torch.device('cuda:1'), torch.device)) 15233 15234 def test_attribute_serialization(self): 15235 tester = self 15236 15237 class M(torch.jit.ScriptModule): 15238 def __init__(self) -> None: 15239 super().__init__() 15240 for name, value, the_type in tester.get_pickle_values(): 15241 setattr(self, name, torch.jit.Attribute(value, the_type)) 15242 15243 @torch.jit.script_method 15244 def forward(self): 15245 return (self.dict, self.float, self.int, self.bool, self.tuple, 15246 self.list, self.int_list, self.tensor_list, self.bool_list, 15247 self.float_list, self.str_list, self.none) 15248 15249 m = M() 15250 imported_m = self.getExportImportCopy(m) 15251 self.assertEqual(m(), imported_m()) 15252 15253 def test_string_len(self): 15254 def fn(x): 15255 # type: (str) -> int 15256 return len(x) 15257 15258 self.checkScript(fn, ("",)) 15259 self.checkScript(fn, ("h",)) 15260 self.checkScript(fn, ("hello",)) 15261 15262 def test_multiline_optional_future_refinement(self): 15263 @torch.jit.script 15264 def fun() -> int: 15265 future: Optional[ 15266 torch.jit.Future[Tuple[torch.Tensor]] 15267 ] = None 15268 15269 return 1 15270 self.assertEqual(fun(), 1) 15271 15272 @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle") 15273 def test_attribute_unpickling(self): 
15274 tensor = torch.randn(2, 2) 15275 tester = self 15276 15277 class M(torch.jit.ScriptModule): 15278 def __init__(self) -> None: 15279 super().__init__() 15280 for name, value, the_type in tester.get_pickle_values(): 15281 setattr(self, "_" + name, torch.jit.Attribute(value, the_type)) 15282 15283 @torch.jit.script_method 15284 def forward(self): 15285 return (self._dict, self._float, self._int, self._bool, self._tuple, 15286 self._list, self._int_list, self._tensor_list, self._bool_list, 15287 self._float_list, self._str_list, self._none) 15288 15289 with TemporaryFileName() as fname: 15290 M().save(fname) 15291 loaded = torch.jit.load(fname) 15292 15293 def is_tensor_value(item): 15294 if isinstance(item, torch.Tensor): 15295 return True 15296 if isinstance(item, list): 15297 return is_tensor_value(item[0]) 15298 return False 15299 for name, value, the_type in self.get_pickle_values(): 15300 if is_tensor_value(value): 15301 continue 15302 self.assertEqual(value, getattr(loaded, "_" + name)) 15303 15304 15305 def test_submodule_attribute_serialization(self): 15306 class S(torch.jit.ScriptModule): 15307 def __init__(self, list_data): 15308 super().__init__() 15309 self.table = torch.jit.Attribute({"I": "am", "a test": "test"}, Dict[str, str]) 15310 self.list = torch.jit.Attribute(list_data, List[Tuple[int, int]]) 15311 15312 @torch.jit.script_method 15313 def forward(self): 15314 return (self.table, self.list) 15315 15316 class M(torch.jit.ScriptModule): 15317 def __init__(self) -> None: 15318 super().__init__() 15319 self.table = torch.jit.Attribute({"this": "is", "a different": "dict"}, Dict[str, str]) 15320 self.tensor = torch.jit.Attribute(torch.randn(2, 2), torch.Tensor) 15321 self.s1 = S([(1, 2)]) 15322 self.s2 = S([(4, 5)]) 15323 15324 @torch.jit.script_method 15325 def forward(self): 15326 return (self.table, self.tensor, self.s1.table, self.s2.list, self.s1.list) 15327 15328 m = M() 15329 imported_m = self.getExportImportCopy(m) 15330 
self.assertEqual(m(), imported_m()) 15331 15332 def test_serialization_big_ints(self): 15333 class M(torch.jit.ScriptModule): 15334 def __init__(self) -> None: 15335 super().__init__() 15336 self.int32_max = torch.jit.Attribute(2**31 - 1, int) 15337 self.int32_min = torch.jit.Attribute(-2**31, int) 15338 self.uint32_max = torch.jit.Attribute(2**32, int) 15339 15340 self.int64_max = torch.jit.Attribute(2**63 - 1, int) 15341 self.int64_min = torch.jit.Attribute(-2**63, int) 15342 15343 self.tensor = torch.nn.Parameter(torch.ones(2, 2)) 15344 15345 @torch.jit.script_method 15346 def forward(self, x): 15347 # type: (int) -> (int) 15348 return x + (self.int32_max + self.int32_min) + (self.int64_max + self.int64_min) 15349 15350 m = M() 15351 imported = self.getExportImportCopy(m) 15352 self.assertEqual(m(10), imported(10)) 15353 15354 self.assertEqual(m.int32_max, imported.int32_max) 15355 self.assertEqual(m.int32_min, imported.int32_min) 15356 self.assertEqual(m.uint32_max, imported.uint32_max) 15357 self.assertEqual(m.int64_max, imported.int64_max) 15358 self.assertEqual(m.int64_min, imported.int64_min) 15359 15360 def test_script_scope(self): 15361 scripted = torch.jit.script(torch.nn.functional.triplet_margin_loss) 15362 15363 @unittest.skipIf(IS_WINDOWS, "NYI: TemporaryFileName on Windows") 15364 def test_serialization_sharing(self): 15365 class M(torch.jit.ScriptModule): 15366 def __init__(self) -> None: 15367 super().__init__() 15368 self.list = torch.jit.Attribute([], List[str]) 15369 15370 @torch.jit.script_method 15371 def forward(self, key): 15372 # type: (str) -> List[str] 15373 self.list.append(key) 15374 self.list.append(key) 15375 self.list.append(key) 15376 return self.list 15377 15378 # the text of the string should only appear once in the pickling 15379 m = M() 15380 s1 = "a long string" 15381 s2 = "a different, even longer string" 15382 self.assertEqual(m(s1), [s1] * 3) 15383 self.assertEqual(m(s2), [s1] * 3 + [s2] * 3) 15384 with TemporaryFileName() 
as fname: 15385 m.save(fname) 15386 archive_name = os.path.basename(os.path.normpath(fname)) 15387 archive = zipfile.ZipFile(fname, 'r') 15388 pickled_data = archive.read(os.path.join(archive_name, 'data.pkl')) 15389 15390 out = io.StringIO() 15391 pickletools.dis(pickled_data, out=out) 15392 disassembled = out.getvalue() 15393 15394 FileCheck().check_count(s1, 1, exactly=True) \ 15395 .check_count("BINGET", 2, exactly=True) \ 15396 .check_count(s2, 1, exactly=True) \ 15397 .check_count("BINGET", 2, exactly=True).run(out.getvalue()) 15398 15399 def test_sys_stdout_override(self): 15400 @torch.jit.script 15401 def foo(): 15402 print('foo') 15403 15404 class Redirect: 15405 def __init__(self) -> None: 15406 self.s = '' 15407 15408 def write(self, s): 15409 self.s += s 15410 15411 old_stdout = sys.stdout 15412 redirect = Redirect() 15413 try: 15414 sys.stdout = redirect 15415 foo() 15416 finally: 15417 sys.stdout = old_stdout 15418 15419 FileCheck().check('foo').run(redirect.s) 15420 15421 def test_dtype_attr(self): 15422 class Foo(torch.nn.Module): 15423 def __init__(self) -> None: 15424 super().__init__() 15425 self.dtype = torch.zeros([]).dtype 15426 15427 def forward(self): 15428 return torch.zeros(3, 4, dtype=self.dtype) 15429 15430 f = Foo() 15431 torch.jit.script(f) 15432 15433 15434 def test_named_buffers_are_iterable(self): 15435 class MyMod(torch.nn.Module): 15436 def __init__(self) -> None: 15437 super().__init__() 15438 self.mod = (torch.nn.ReLU()) 15439 self.mod2 = (torch.nn.ReLU()) 15440 self.mod3 = torch.nn.Sequential(torch.nn.Sequential(torch.nn.ReLU())) 15441 self.x = nn.Buffer(torch.zeros(3)) 15442 self.y = nn.Buffer(torch.zeros(3)) 15443 self.z = torch.zeros(3) 15444 15445 def bleh(self): 15446 return self.z + 4 15447 15448 @torch.jit.export 15449 def method(self): 15450 names = [""] 15451 vals = [] 15452 for name, buffer in self.named_buffers(): 15453 names.append(name) 15454 vals.append(buffer + 2) 15455 15456 return names, vals 15457 15458 def 
forward(self, x): 15459 return x 15460 15461 model = MyMod() 15462 x = torch.jit.script(model) 15463 z = self.getExportImportCopy(x) 15464 15465 self.assertEqual(z.method(), x.method()) 15466 self.assertEqual(z.method(), model.method()) 15467 self.assertEqual(x.method(), model.method()) 15468 names = x.method() 15469 for name in names: 15470 self.assertNotEqual('z', name) 15471 15472 15473 def test_static_if_prop(self): 15474 class MaybeHasAttr(torch.nn.Module): 15475 def __init__(self, add_attr): 15476 super().__init__() 15477 if add_attr: 15478 self.maybe_attr = 1 15479 15480 def forward(self): 15481 if hasattr(self, "maybe_attr") and True: 15482 return self.maybe_attr 15483 else: 15484 return 0 15485 15486 class MaybeHasAttr2(torch.nn.Module): 15487 def __init__(self, add_attr): 15488 super().__init__() 15489 if add_attr: 15490 self.maybe_attr = 1 15491 15492 def forward(self): 15493 if not hasattr(self, "maybe_attr") or False: 15494 return 0 15495 else: 15496 return self.maybe_attr 15497 15498 torch.jit.script(MaybeHasAttr(True)) 15499 torch.jit.script(MaybeHasAttr(False)) 15500 torch.jit.script(MaybeHasAttr2(True)) 15501 torch.jit.script(MaybeHasAttr2(False)) 15502 15503 class MyMod(torch.nn.Module): 15504 def forward(self): 15505 if hasattr(self, "foo"): 15506 return 1 15507 else: 15508 return 0 15509 15510 @torch.jit.export 15511 def fee(self): 15512 return 1 15513 15514 self.checkModule(MyMod(), ()) 15515 15516 class HasAttrMod(torch.nn.Module): 15517 __constants__ = ["fee"] 15518 15519 def __init__(self) -> None: 15520 super().__init__() 15521 self.fee = 3 15522 15523 def forward(self): 15524 a = hasattr(self, "fee") 15525 b = hasattr(self, "foo") 15526 c = hasattr(self, "hi") 15527 d = hasattr(self, "nonexistant") 15528 return (a, b, c, d) 15529 15530 def foo(self): 15531 return 1 15532 15533 @torch.jit._overload_method 15534 def hi(self, x: Tensor): ... 
# noqa: E704 15535 15536 def hi(self, x): # noqa: F811 15537 return 2 15538 15539 self.checkModule(HasAttrMod(), ()) 15540 15541 @torch.jit.script 15542 class FooTest: 15543 def __init__(self) -> None: 15544 self.x = 1 15545 15546 def foo(self, y): 15547 return self.x + y 15548 15549 def foo(): 15550 a = FooTest() 15551 val1 = hasattr(a, "foo"), hasattr(a, "x"), hasattr(a, "bla") 15552 val2 = hasattr(FooTest, "foo"), hasattr(FooTest, "a") 15553 return val1, val2 15554 15555 self.assertEqual(foo(), torch.jit.script(foo)()) 15556 15557 def _test_pickle_checkpoint(self, device): 15558 with TemporaryFileName() as fname: 15559 class M(torch.jit.ScriptModule): 15560 __constants__ = ['fname'] 15561 15562 def __init__(self, tensor): 15563 super().__init__() 15564 self.fname = fname 15565 self.tensor = torch.nn.Parameter(tensor) 15566 15567 @torch.jit.script_method 15568 def forward(self, x): 15569 y = self.tensor + x 15570 torch.save(y, self.fname) 15571 return y 15572 15573 param = torch.randn(2, 2).to(device) 15574 input = torch.randn(2, 2).to(device) 15575 m = M(param) 15576 m(input) 15577 with open(fname, "rb") as handle: 15578 loaded_tensor = torch.load(fname) 15579 self.assertEqual(loaded_tensor, input + param) 15580 15581 def _test_pickle_checkpoint_views(self, device): 15582 with TemporaryFileName() as fname: 15583 class M(torch.jit.ScriptModule): 15584 __constants__ = ['fname'] 15585 15586 def __init__(self, tensor): 15587 super().__init__() 15588 self.fname = fname 15589 self.tensor = torch.nn.Parameter(tensor) 15590 15591 @torch.jit.script_method 15592 def forward(self, x): 15593 y = self.tensor + x 15594 y_view = y.view(4) 15595 torch.save((y, y_view, y), self.fname) 15596 return y 15597 15598 param = torch.randn(2, 2).to(device) 15599 input = torch.randn(2, 2).to(device) 15600 m = M(param) 15601 m(input) 15602 with open(fname, "rb") as handle: 15603 loaded_y, loaded_y_view, loaded_y_2 = torch.load(fname) 15604 self.assertEqual(loaded_y, input + param) 15605 
with torch.no_grad(): 15606 loaded_y_view[1] += 20 15607 # assert that loaded_y changed as well 15608 self.assertEqual(loaded_y.view(4), loaded_y_view) 15609 self.assertEqual(loaded_y_2.view(4), loaded_y_view) 15610 15611 @unittest.skipIf(not RUN_CUDA, "no CUDA") 15612 def test_pickle_checkpoint_cuda(self): 15613 self._test_pickle_checkpoint('cuda') 15614 self._test_pickle_checkpoint_views('cuda') 15615 15616 def test_pickle_checkpoint(self): 15617 self._test_pickle_checkpoint('cpu') 15618 self._test_pickle_checkpoint_views('cpu') 15619 15620 def test_pickle_checkpoint_tup(self): 15621 @torch.jit.script 15622 def foo(fname): 15623 # type: (str) -> None 15624 torch.save((3, 4), fname) 15625 with TemporaryFileName() as name: 15626 foo(name) 15627 self.assertEqual(torch.load(name), (3, 4)) 15628 15629 def test_string_list(self): 15630 def fn(string): 15631 # type: (str) -> List[str] 15632 return list(string) 15633 15634 self.checkScript(fn, ("abcdefgh",)) 15635 15636 def test_unicode_comments(self): 15637 @torch.jit.script 15638 def test(self, a): 15639 # 15640 return torch.nn.functional.relu(a) 15641 15642 def test_get_set_state_with_tensors(self): 15643 class M(torch.nn.Module): 15644 def __init__(self) -> None: 15645 super().__init__() 15646 self.tensor = torch.randn(2, 2) 15647 15648 @torch.jit.export 15649 def __getstate__(self): 15650 return (self.tensor, self.training) 15651 15652 @torch.jit.export 15653 def __setstate__(self, state): 15654 self.tensor = state[0] 15655 self.training = state[1] 15656 15657 def forward(self, x): 15658 return x + self.tensor 15659 15660 with TemporaryFileName() as fname: 15661 m = torch.jit.script(M()) 15662 m.save(fname) 15663 loaded = torch.jit.load(fname) 15664 self.assertEqual(loaded.tensor, m.tensor) 15665 15666 def test_in_for_and_comp_expr(self): 15667 def fn(d): 15668 # type: (Dict[str, int]) -> List[int] 15669 out = [1] 15670 for i in range(d["hi"] if "hi" in d else 6): 15671 out.append(i) # noqa: PERF402 15672 return out 
15673 15674 self.checkScript(fn, ({'hi': 2, 'bye': 3},)) 15675 self.checkScript(fn, ({'bye': 3},)) 15676 15677 def test_for_else(self): 15678 def fn(): 15679 c = 0 15680 for i in range(4): 15681 c += 10 15682 else: 15683 print("In else block of for...else") 15684 15685 with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "else branches of for loops aren't supported"): 15686 torch.jit.script(fn) 15687 15688 def test_split(self): 15689 def split_two(tensor): 15690 a, b, c = torch.split(tensor, 2, dim=1) 15691 return a, b, c 15692 x = torch.randn(3, 6) 15693 y = torch.randn(3, 6) 15694 self.checkScript(split_two, [(x + y)]) 15695 15696 def test_conv_error(self): 15697 @torch.jit.script 15698 def fn(x, y): 15699 return F.conv2d(x, y) 15700 15701 try: 15702 fn(torch.ones(2, 2), torch.ones(4, 4)) 15703 except RuntimeError as e: 15704 self.assertFalse('frame' in str(e)) 15705 15706 def test_python_op_name(self): 15707 import random 15708 15709 with self.assertRaisesRegex(RuntimeError, "randint"): 15710 @torch.jit.script 15711 def fn(): 15712 return random.randint() 15713 15714 def test_dir(self): 15715 class M(torch.jit.ScriptModule): 15716 def forward(self, t): 15717 return t 15718 15719 self.assertTrue('forward' in dir(M())) 15720 15721 def test_kwarg_expansion_error(self): 15722 @torch.jit.ignore 15723 def something_else(h, i): 15724 pass 15725 15726 def fn(x): 15727 something_else(**x) 15728 15729 with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "keyword-arg expansion is not supported"): 15730 torch.jit.script(fn) 15731 15732 def test_kwargs_error_msg(self): 15733 def other(**kwargs): 15734 print(kwargs) 15735 15736 def fn(): 15737 return other() 15738 15739 with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, 'variable number'): 15740 torch.jit.script(fn) 15741 15742 def another_other(*args): 15743 print(args) 15744 15745 def another_fn(): 15746 return another_other() 15747 15748 with 
self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, 'variable number'): 15749 torch.jit.script(another_fn) 15750 15751 def test_inferred_error_msg(self): 15752 """ 15753 Test that when we get a type mismatch on a function where we inferred 15754 the type to be tensor, a good error message is given. 15755 """ 15756 @torch.jit.script 15757 def foo(a): 15758 return a 15759 15760 with self.assertRaisesRegex(RuntimeError, (r"Expected a value of type \'Tensor \(inferred\)\'" 15761 r"[\S\s]*Inferred \'a\' to be of type \'Tensor\'")): 15762 foo("1") 15763 15764 def test_type_comments_in_body(self): 15765 @torch.jit.script 15766 def foo(a, # type: int 15767 b, # type: int 15768 ): 15769 # type: (...) -> int 15770 # type: int 15771 return a + b 15772 15773 class M(torch.nn.Module): 15774 def __init__(self, 15775 a, # type: int 15776 b # type: int 15777 ): 15778 # type: (...) -> None 15779 super().__init__() 15780 self.a = a # type: int 15781 self.b = b # type: int 15782 15783 torch.jit.script(M(2, 3)) 15784 15785 def test_input_keyword_in_schema(self): 15786 def f(x): 15787 return torch.ceil(input=x) 15788 15789 inp = torch.randn(10) 15790 self.checkScript(f, (inp, )) 15791 15792 def test_module_method_reassignment(self): 15793 class Foo(torch.nn.Module): 15794 def _forward(self, x): 15795 return x 15796 15797 forward = _forward 15798 15799 sm = torch.jit.script(Foo()) 15800 input = torch.ones(2, 2) 15801 self.assertEqual(input, sm(input)) 15802 15803 # Tests the case where a torch.Tensor subclass (like Parameter) is used as 15804 # input. 
15805 def test_script_module_tensor_subclass_argument(self): 15806 @torch.jit.script 15807 def parameter_script(x: torch.nn.Parameter): 15808 return x 15809 15810 input = torch.ones(2, 2) 15811 self.assertEqual(input, parameter_script(input)) 15812 15813 def test_save_load_attr_error(self): 15814 class Inner(nn.Module): 15815 def forward(self, x): 15816 return x 15817 15818 class Wrapper(nn.Module): 15819 def __init__(self, inner): 15820 super().__init__() 15821 self.inner = inner 15822 15823 def forward(self, x): 15824 # this attribute doesn't exist on `Inner` 15825 return self.inner.b(x) 15826 15827 inner_module = torch.jit.script(Inner()) 15828 inner_module = self.getExportImportCopy(inner_module) 15829 wrapped = Wrapper(inner_module) 15830 # This should properly complain that `self.inner` doesn't have the attribute `b` 15831 with self.assertRaisesRegex(RuntimeError, 'has no attribute'): 15832 torch.jit.script(wrapped) 15833 15834 def test_rescripting_loaded_modules(self): 15835 class InnerSubmod(nn.Module): 15836 __constants__ = ['my_constant'] 15837 15838 def __init__(self) -> None: 15839 super().__init__() 15840 self.foo = torch.nn.Buffer(torch.ones(1)) 15841 self.register_parameter("bar", torch.nn.Parameter(torch.ones(1))) 15842 self.baz = torch.ones(1) 15843 self.my_constant = 1 15844 15845 def forward(self, x): 15846 return x + x 15847 15848 class Inner(nn.Module): 15849 def __init__(self) -> None: 15850 super().__init__() 15851 self.submod = InnerSubmod() 15852 15853 def forward(self, x): 15854 return self.submod(x) 15855 15856 class Wrapper(nn.Module): 15857 def __init__(self, inner): 15858 super().__init__() 15859 self.inner = inner 15860 15861 def forward(self, x): 15862 # access inner elements 15863 ret = self.inner.submod(x) + self.inner.submod.foo + self.inner.submod.bar + self.inner.submod.baz 15864 ret = ret + self.inner.submod.my_constant 15865 return ret 15866 15867 inner_module = torch.jit.script(Inner()) 15868 wrapped = Wrapper(inner_module) 
15869 self.checkModule(wrapped, torch.ones(1)) 15870 15871 inner_module_loaded = self.getExportImportCopy(inner_module) 15872 wrapped_loaded = Wrapper(inner_module_loaded) 15873 self.assertEqual(wrapped(torch.ones(1)), wrapped_loaded(torch.ones(1))) 15874 15875 def test_interpret_graph(self): 15876 def fn(x): 15877 return x.unfold(0, 1, 1) 15878 15879 graph_str = """ 15880 graph(%a : Tensor, %b : Tensor): 15881 %c : Tensor = aten::mul(%a, %b) 15882 return (%c) 15883 """ 15884 graph = parse_ir(graph_str) 15885 a = torch.rand(10) 15886 b = torch.rand(10) 15887 test = torch._C._jit_interpret_graph(graph, (a, b)) 15888 ref = a * b 15889 self.assertEqual(test, ref) 15890 15891 def test_signed_float_zero(self): 15892 15893 class MyModule(torch.nn.Module): 15894 def forward(self, x): 15895 return torch.div(x, -0.) 15896 15897 inp = torch.ones(1) 15898 self.checkModule(MyModule(), inp) 15899 15900 def test_index_with_tuple(self): 15901 class MyModule(torch.nn.Module): 15902 def forward(self, x): 15903 return x[(1,)] 15904 15905 self.checkModule(MyModule(), (torch.ones(2, 3),)) 15906 15907 def test_context_manager(self): 15908 class MyModule(torch.nn.Module): 15909 def forward(self, x, y): 15910 p = x + y 15911 q = p + 2.0 15912 return q 15913 15914 x = torch.randn(3, 2, dtype=torch.float) 15915 y = torch.randn(3, 2, dtype=torch.float) 15916 for fuser_name in ['fuser0', 'fuser1', 'none']: 15917 with torch.jit.fuser(fuser_name): 15918 self.checkModule(MyModule(), (x, y)) 15919 15920 def test_zero_dimension_tensor_trace(self): 15921 def f(x): 15922 return x[x > 0] 15923 jf = torch.jit.trace(f, torch.tensor(2., device="cpu")) 15924 15925# known to be failing in tracer 15926EXCLUDE_TRACED = { 15927 # The following fail due to #12024. 15928 # A prim::ListConstruct is involved and the indices get traced as TensorType, 15929 # which always require_grad. This causes a crash in autodiff. 
15930 'test___getitem___adv_index', 15931 'test___getitem___adv_index_beg', 15932 'test___getitem___adv_index_comb', 15933 'test___getitem___adv_index_dup', 15934 'test___getitem___adv_index_sub', 15935 'test___getitem___adv_index_sub_2', 15936 'test___getitem___adv_index_sub_3', 15937 'test___getitem___adv_index_var', 15938 15939 # jit doesn't support sparse tensors. 15940 'test_to_sparse', 15941 'test_to_sparse_dim', 15942} 15943 15944EXCLUDE_TYPE_CHECK = { 15945 # slogdet tests use itemgetter to select its only differentiable output, 15946 # but this happens outside of the graph we handle, so there are fewer 15947 # reference outputs than graph outputs. 15948 'test_slogdet_1x1_neg_det', 15949 'test_slogdet_1x1_pos_det', 15950 'test_slogdet_distinct_singular_values', 15951 'test_slogdet_neg_det', 15952 'test_slogdet_pos_det', 15953 'test_slogdet_symmetric', 15954 'test_slogdet_symmetric_pd', 15955 'test_slogdet_batched_1x1_neg_det', 15956 'test_slogdet_batched_pos_det', 15957 'test_slogdet_batched_symmetric', 15958 'test_slogdet_batched_symmetric_pd', 15959 'test_slogdet_batched_distinct_singular_values' 15960} 15961 15962# chunk returns a list in scripting and we don't unpack the list, 15963# Thus it won't be replaced by ConstantChunk and run AD. 15964# It's explicitly checked in test_chunk_constant_script_ad 15965# Similary for split, it's replaced by split_with_sizes in tracing, 15966# but we don't have AD formula for aten::split(Tensor, int[], int), 15967# an op registered in JIT so AD is not triggered in scripting. 
EXCLUDE_SCRIPT_AD_CHECK = {
    # chunk variants
    'test_chunk',
    'test_chunk_dim',
    'test_chunk_dim_neg0',
    # split with an explicit size list
    'test_split_size_list',
    'test_split_size_list_dim',
    'test_split_size_list_dim_neg0',
    # tensor_split driven by an index list
    'test_tensor_indices_sections',
    'test_tensor_indices_sections_dim',
    'test_tensor_indices_sections_dim_neg0',
    # tensor_split driven by a section count
    'test_tensor_split_sections',
    'test_tensor_split_sections_dim',
    'test_tensor_split_sections_dim_neg0',
}

EXCLUDE_PYTHON_PRINT = {
    # no support for BroadcastingList in python printer
    'test_nn_max_unpool1d',
    'test_nn_max_unpool2d',
    'test_nn_max_unpool3d',
    'test_nn_max_pool1d',
    'test_nn_max_pool2d',
    'test_nn_max_pool3d',
    'test_nn_max_pool1d_with_indices',
}

EXCLUDE_ALIAS = {
    # aliases, which may appear in method_tests but are tested elsewhere
    'true_divide',

    # Disable tests for lu from common_methods_invocations.py
    # TODO(@nikitaved) Enable jit tests once autograd.Function does support scripting
    'lu',
}


class TestJitGeneratedModule(JitTestCase):
    """Holder class: generated nn-module JIT tests are attached here via setattr."""


class TestJitGeneratedFunctional(JitTestCase):
    """Holder class for generated functional JIT tests (presumably populated elsewhere)."""


# UBSAN per-function exclusions don't seem to work with OpenMP pragmas,
# and we have to disable the failing tests here instead.
UBSAN_DISABLED_TESTS = [
    "test___rdiv___constant",
    "test___rdiv___scalar_constant",
    "test_addcdiv",
    "test_addcdiv_broadcast_all",
    "test_addcdiv_broadcast_rhs",
    "test_addcdiv_scalar",
    "test_addcdiv_scalar_broadcast_lhs",
    "test_addcdiv_scalar_broadcast_rhs",
    "test_addcdiv_scalar_scale",
    "test_addcdiv_scalar_scale_broadcast_lhs",
    "test_addcdiv_scalar_scale_broadcast_rhs",
    "test_addcdiv_scale",
    "test_addcdiv_scale_broadcast_all",
    "test_addcdiv_scale_broadcast_rhs",
    "test_add_broadcast_all",
    "test_add_broadcast_lhs",
    "test_add_broadcast_rhs",
    "test_add_constant",
    "test_add_scalar",
    "test_add_scalar_broadcast_lhs",
    "test_add_scalar_broadcast_rhs",
    "test_div",
    "test_div_broadcast_all",
    "test_div_broadcast_lhs",
    "test_div_broadcast_rhs",
    "test_div_scalar",
    "test_div_scalar_broadcast_lhs",
    "test_div_scalar_broadcast_rhs",
    "test_rsqrt",
    "test_rsqrt_scalar",
    "test_add",
    "test_reciprocal",
    "test_reciprocal_scalar",
]

L = 20
M = 10
S = 5


def add_nn_module_test(*args, **kwargs):
    """Generate a JIT test for one nn module / criterion test spec.

    The spec is given as keyword arguments. Keys read here include
    'constructor', 'constructor_args', 'constructor_args_fn', 'input_size',
    'input_fn', 'target_size', 'target_fn', 'target', 'extra_args',
    'no_grad', 'check_jit', 'default_dtype', 'desc' and 'slowTest'.
    Specs whose 'desc' contains 'eval' are ignored.

    The generated test wraps the module in a ScriptModule, which forwards to
    ``self.submodule``. It then compares the wrapper against the eager module
    with ``check_against_reference``. The test is registered on
    ``TestJitGeneratedModule`` via ``post_add_test``. Positional ``args`` are
    accepted but unused.
    """
    no_grad = kwargs.get('no_grad', False)

    if 'desc' in kwargs and 'eval' in kwargs['desc']:
        # eval() is not supported, so skip these tests
        return

    test_name = get_nn_mod_test_name(**kwargs)

    @suppress_warnings
    def do_test(self):
        if test_name in EXCLUDE_SCRIPT_MODULES:
            return
        if not kwargs.get('check_jit', True):
            raise unittest.SkipTest('module test skipped on JIT')

        default_dtype = torch.get_default_dtype()
        if kwargs.get('default_dtype') is not None:
            default_dtype = kwargs['default_dtype']

        module_name = get_nn_module_name_from_kwargs(**kwargs)

        if 'constructor' in kwargs:
            nn_module = kwargs['constructor']
        else:
            nn_module = getattr(torch.nn, module_name)

        if "FunctionalModule" in str(nn_module):
            return

        with set_default_dtype(default_dtype):
            if 'constructor_args_fn' in kwargs:
                constructor_args = kwargs['constructor_args_fn']()
            else:
                constructor_args = kwargs.get('constructor_args', ())

            def create_script_module(*args, **kwargs):
                """Construct a script module that passes arguments through to self.submodule"""
                formals, tensors, actuals = get_script_args(args)

                method_args = ', '.join(['self'] + actuals)
                call_args_str = ', '.join(actuals)
                call = f"self.submodule({call_args_str})"
                script = script_method_template.format(method_args, call)

                # Note: `kwargs` here is this helper's own kwargs, not the spec.
                submodule_constants = []
                if kwargs.get('is_constant'):
                    submodule_constants = ['submodule']

                # Create module to use the script method
                class TheModule(torch.jit.ScriptModule):
                    __constants__ = submodule_constants

                    def __init__(self) -> None:
                        super().__init__()
                        self.submodule = nn_module(*constructor_args)

                def make_module(script):
                    module = TheModule()
                    # check __repr__
                    str(module)
                    module.define(script)
                    return module

                module = make_module(script)
                self.assertExportImportModule(module, tensors)
                create_script_module.last_graph = module.graph
                mod = module(*args)
                return mod

            # Construct a normal nn module to stay consistent with create_script_module
            # and make use of a single global rng_state in module initialization
            def create_nn_module(*args, **kwargs):
                module = nn_module(*constructor_args)
                return module(*args)

            # Set up inputs from tuple of sizes or constructor fn
            dtype = torch.get_default_dtype()
            if 'input_fn' in kwargs:
                inputs = kwargs['input_fn']()
                if isinstance(inputs, Tensor):
                    inputs = (inputs,)

                # All-complex inputs need the complex counterpart of the default dtype.
                if all(tensor.is_complex() for tensor in inputs):
                    if dtype == torch.float:
                        dtype = torch.cfloat
                    elif dtype == torch.double:
                        dtype = torch.cdouble
                    else:
                        raise AssertionError(f"default_dtype {default_dtype} is not supported")

            else:
                inputs = (kwargs['input_size'],)

            if 'target_size' in kwargs:
                inputs = inputs + (kwargs['target_size'],)
            elif 'target_fn' in kwargs:
                if torch.is_tensor(inputs):
                    inputs = (inputs,)
                inputs = inputs + (kwargs['target_fn'](),)
            elif 'target' in kwargs:
                inputs = inputs + (kwargs['target'],)

            # Extra parameters to forward()
            if 'extra_args' in kwargs:
                inputs = inputs + kwargs['extra_args']

            args_variable, kwargs_variable = create_input(inputs, dtype=dtype)
            f_args_variable = deepcopy(unpack_variables(args_variable))

            # TODO(issue#52052) Neither this nor no_grad should be required
            # if check_against_reference() is updated to check gradients
            # w.r.t. weights and then only check w.r.t. inputs if any
            # inputs require it.
            any_requires_grad = any(arg.requires_grad for arg in f_args_variable)

            # Check against Python module as reference
            check_against_reference(self, create_script_module, create_nn_module,
                                    lambda x: x, f_args_variable,
                                    no_grad=no_grad or not any_requires_grad)

    if 'slowTest' in kwargs:
        do_test = slowTest(do_test)

    post_add_test(test_name, (), do_test, TestJitGeneratedModule)


def post_add_test(test_name, skipTestIf, do_test, test_class):
    """Attach ``do_test`` to ``test_class`` as ``test_name``.

    Each decorator in ``skipTestIf`` is applied to ``do_test`` in order. When
    running under UBSAN, names listed in ``UBSAN_DISABLED_TESTS`` are not
    attached.

    Raises:
        AssertionError: if ``test_class`` already has an attribute named
            ``test_name``.
    """
    # Explicit raise (not `assert`) so the duplicate check survives `python -O`.
    if hasattr(test_class, test_name):
        raise AssertionError('Two tests have the same name: ' + test_name)

    for skip in skipTestIf:
        do_test = skip(do_test)

    if not (TEST_WITH_UBSAN and test_name in UBSAN_DISABLED_TESTS):
        setattr(test_class, test_name, do_test)


def normalize_check_ad(check_ad, name):
    """Normalize a ``check_ad`` spec into ``[bool, List[str], List[str]]``.

    Missing trailing entries default to ``False``, ``['aten::' + name]`` and
    ``[]``. Any entry given as a single string is wrapped in a one-element
    list.

    Raises:
        Exception: if ``check_ad`` has more than three entries.
    """
    # normalized check_ad is a 3-element list: [bool, List[str], List[str]]
    if len(check_ad) == 0:
        check_ad = [False, ['aten::' + name], []]
    elif len(check_ad) == 1:
        check_ad = [check_ad[0], ['aten::' + name], []]
    elif len(check_ad) == 2:
        check_ad = [check_ad[0], check_ad[1], []]
    elif len(check_ad) == 3:
        check_ad = list(check_ad)
    else:
        raise Exception('Invalid check_ad, requires (bool, str|List[str], str|List[str])')  # noqa: TRY002

    check_ad = [[t] if isinstance(t, str) else t for t in check_ad]

    return check_ad


class TestProducerVersion(TestCase):

    def test_version(self):
        # issue gh-32561
        self.assertTrue(torch.__version__.startswith(torch.onnx.producer_version))


for test in module_tests + new_module_tests + additional_module_tests:
    add_nn_module_test(**test)

for test in criterion_tests:
    test['no_grad'] = True
    add_nn_module_test(**test)

if __name__ == '__main__':
    TestCase._default_dtype_check_enabled = True
    run_tests()
    # NOTE(review): if run_tests() terminates the interpreter (e.g. via
    # unittest.main's default sys.exit), the lines below never execute.
    # TestModuleInterface is also imported at the top of this file, so its
    # tests may already be collected; confirm whether this re-run is needed.
    import jit.test_module_interface
    suite = unittest.findTestCases(jit.test_module_interface)
    unittest.TextTestRunner().run(suite)