# Owner(s): ["oncall: jit"]

import torch

# This is how we include tests located in test/jit/...
# They are included here so that they are invoked when you call `test_jit.py`;
# do not run these test files directly.
from jit.test_tracer import TestTracer, TestMixTracingScripting  # noqa: F401
from jit.test_recursive_script import TestRecursiveScript  # noqa: F401
from jit.test_type_sharing import TestTypeSharing  # noqa: F401
from jit.test_logging import TestLogging  # noqa: F401
from jit.test_backends import TestBackends, TestBackendsWithCompiler  # noqa: F401
from jit.test_backend_nnapi import TestNnapiBackend  # noqa: F401
from jit.test_list_dict import TestList, TestDict, TestNamedTuple, TestScriptDict, TestScriptList  # noqa: F401
from jit.test_async import TestAsync  # noqa: F401
from jit.test_await import TestAwait  # noqa: F401
from jit.test_data_parallel import TestDataParallel  # noqa: F401
from jit.test_models import TestModels  # noqa: F401
from jit.test_modules import TestModules  # noqa: F401
from jit.test_autodiff import TestAutodiffJit  # noqa: F401
from jit.test_autodiff_subgraph_slicing import TestAutodiffSubgraphSlicing  # noqa: F401
from jit.test_custom_operators import TestCustomOperators  # noqa: F401
from jit.test_graph_rewrite_passes import TestGraphRewritePasses  # noqa: F401
from jit.test_class_type import TestClassType  # noqa: F401
from jit.test_builtins import TestBuiltins, TestTensorBuiltins  # noqa: F401
from jit.test_ignore_context_manager import TestIgnoreContextManager  # noqa: F401
from jit.test_symbolic_shape_analysis import TestSymbolicShapeAnalysis  # noqa: F401
from jit.test_op_decompositions import TestOpDecompositions  # noqa: F401
from jit.test_unsupported_ops import TestUnsupportedOps  # noqa: F401
from jit.test_freezing import TestFreezing, TestFrozenOptimizations, TestMKLDNNReinplacing  # noqa: F401
from jit.test_peephole import TestPeephole  # noqa: F401
from jit.test_alias_analysis import TestAliasAnalysis  # noqa: F401
from jit.test_save_load import TestSaveLoad, TestSaveLoadFlatbuffer  # noqa: F401
from jit.test_save_load_for_op_version import TestSaveLoadForOpVersion  # noqa: F401
from jit.test_module_containers import TestModuleContainers  # noqa: F401
from jit.test_python_bindings import TestPythonBindings  # noqa: F401
from jit.test_python_ir import TestPythonIr  # noqa: F401
from jit.test_functional_blocks import TestFunctionalBlocks  # noqa: F401
from jit.test_remove_mutation import TestRemoveMutation  # noqa: F401
from jit.test_torchbind import TestTorchbind  # noqa: F401
from jit.test_module_interface import TestModuleInterface  # noqa: F401
from jit.test_with import TestWith  # noqa: F401
from jit.test_enum import TestEnum  # noqa: F401
from jit.test_string_formatting import TestStringFormatting  # noqa: F401
from jit.test_profiler import TestProfiler  # noqa: F401
from jit.test_slice import TestSlice  # noqa: F401
from jit.test_ignorable_args import TestIgnorableArgs  # noqa: F401
from jit.test_hooks import TestHooks  # noqa: F401
from jit.test_warn import TestWarn  # noqa: F401
from jit.test_isinstance import TestIsinstance  # noqa: F401
from jit.test_cuda import TestCUDA  # noqa: F401
from jit.test_python_builtins import TestPythonBuiltinOP  # noqa: F401
from jit.test_typing import TestTyping  # noqa: F401
from jit.test_hash import TestHash  # noqa: F401
from jit.test_complex import TestComplex  # noqa: F401
from jit.test_jit_utils import TestJitUtils  # noqa: F401
from jit.test_scriptmod_ann import TestScriptModuleInstanceAttributeTypeAnnotation  # noqa: F401
from jit.test_types import TestTypesAndAnnotation  # noqa: F401
from jit.test_misc import TestMisc  # noqa: F401
from jit.test_upgraders import TestUpgraders  # noqa: F401
from jit.test_pdt import TestPDT  # noqa: F401
from jit.test_tensor_creation_ops import TestTensorCreationOps  # noqa: F401
from jit.test_module_apis import TestModuleAPIs  # noqa: F401
from jit.test_script_profile import TestScriptProfile  # noqa: F401
from jit.test_convert_activation import TestFunctionalToInplaceActivation, TestInplaceToFunctionalActivation  # noqa: F401
from jit.test_parametrization import TestParametrization  # noqa: F401
from jit.test_attr import TestGetDefaultAttr  # noqa: F401
from jit.test_aten_pow import TestAtenPow  # noqa: F401
from jit.test_optimize_for_mobile_preserve_debug_info import TestOptimizeForMobilePreserveDebugInfo  # noqa: F401
from jit.test_union import TestUnion  # noqa: F401
from jit.test_batch_mm import TestBatchMM  # noqa: F401
from jit.test_dtype_analysis import TestDtypeAnalysis, TestDtypeCustomRulesCPU  # noqa: F401
from jit.test_device_analysis import TestDeviceAnalysis  # noqa: F401
from jit.test_dce import TestDCE  # noqa: F401
from jit.test_sparse import TestSparse  # noqa: F401
from jit.test_tensor_methods import TestTensorMethods  # noqa: F401
from jit.test_dataclasses import TestDataclasses  # noqa: F401
from jit.test_generator import TestGenerator  # noqa: F401

# Torch
from torch import Tensor
from torch._C import TensorType, BoolType, parse_ir, _propagate_shapes
from torch.autograd import Variable
from torch.jit.annotations import BroadcastingList2, BroadcastingList3, Any  # noqa: F401
from torch.nn.utils.rnn import PackedSequence
from torch.testing import FileCheck, make_tensor
import torch.autograd.profiler
import torch.cuda
import torch.jit
import torch.jit._logging
import torch.jit.frontend
import torch.nn as nn
import torch.nn.functional as F

# Testing utils
from torch.testing._internal import jit_utils
from torch.testing._internal.common_jit import check_against_reference
from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
    suppress_warnings, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, TestCase, \
    freeze_rng_state, slowTest, TemporaryFileName, \
    enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \
    skipIfCrossRef, skipIfTorchDynamo
from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, disable_autodiff_subgraph_inlining, \
    _trace, do_input_map, get_execution_plan, make_global, \
    execWrapper, _inline_everything, _tmp_donotuse_dont_inline_everything, \
    RUN_CUDA
from torch.testing._internal.jit_metaprogramming_utils import (
    get_script_args,
    create_input, unpack_variables,
    additional_module_tests, EXCLUDE_SCRIPT_MODULES,
    get_nn_module_name_from_kwargs, get_nn_mod_test_name, script_method_template)

from torch.testing._internal.common_nn import module_tests, new_module_tests, criterion_tests

# For testing truediv in python 2
from torch.testing._internal.test_module.future_div import div_int_future, div_float_future
from torch.testing._internal.test_module.no_future_div import div_int_nofuture, div_float_nofuture

# Standard library
from collections import defaultdict, namedtuple, OrderedDict
from copy import deepcopy
from itertools import product
from textwrap import dedent
from typing import List, Dict, NamedTuple, Optional, Tuple, Union
import copy
import functools
import inspect
import io
import itertools
import math
import numpy as np
import os
import pickle
import pickletools
import random
import re
import shutil
import string
import sys
import tempfile
import types
import typing
import unittest
import warnings
import zipfile
import tracemalloc


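# Helper: canonicalize a graph and return its string representation.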
def canonical(graph):
    return torch._C._jit_pass_canonicalize(graph).str(False)

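# Wrapper around LSTMCell (defined below) that takes the hidden and cell states as separate arguments.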
def LSTMCellF(input, hx, cx, *params):
    return LSTMCell(input, (hx, cx), *params)

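# Decide whether the autodiff check should run for a given generated test name;
# see the exclusion list below for tests that do not work under the profiling executor.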
def doAutodiffCheck(testname):
    # TODO: setting this to False on the test itself is not working
    if "test_t_" in testname or testname == "test_t":
        return False

    if GRAPH_EXECUTOR == ProfilingMode.SIMPLE:
        return False

    if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
        return True


    # These tests are disabled because BailOut nodes
    # inserted by ProfilingExecutor interfere with
    # subgraph slicing of Differentiable Graphs
    test_exceptions = (
        # functional
        'test_nn_dropout',
        'test_nn_log_softmax',
        'test_nn_relu',
        'test_nn_softmax',
        'test_nn_threshold',
        'test_nn_lp_pool2d',
        'test_nn_lp_pool1d',
        'test_nn_gumbel_softmax_hard',
        'test_nn_gumbel_softmax',
        'test_nn_multilabel_soft_margin_loss',
        'test_nn_batch_norm',
        'test_nn_max_pool2d_with_indices',
        # AutogradJitGenerated
        'test___rdiv___constant',
        'test___rdiv___scalar_constant',
        'test_split',
        'test_split_dim',
        'test_split_dim_neg0',
        'test_split_size_list',
        'test_split_size_list_dim',
        'test_split_size_list_dim_neg0',
        'test_split_with_sizes',
        'test_split_with_sizes_dim',
        'test_split_with_sizes_dim_neg0',
        'test_split_with_sizes_size_0',
        'test_nn_max_pool2d_with_indices',
    )

    return testname not in test_exceptions


# TODO: enable TE in PE when all tests are fixed
torch._C._jit_set_texpr_fuser_enabled(GRAPH_EXECUTOR == ProfilingMode.PROFILING)
torch._C._jit_set_profiling_executor(GRAPH_EXECUTOR != ProfilingMode.LEGACY)

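# Reference eager-mode LSTM cell implemented with F.linear; used as a workload by the JIT tests.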
def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    hx, cx = hidden
    gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)

    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)
    return hy, cy


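# Like LSTMCellF, but concatenates the hidden and cell states into a single output tensor.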
def LSTMCellC(*args, **kwargs):
    hy, cy = LSTMCellF(*args, **kwargs)
    return torch.cat((hy, cy))


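# LSTM cell written with explicit matrix multiplies instead of F.linear.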
def LSTMCellS(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
    gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)
    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)
    return hy, cy


# Code reference: https://github.com/pytorch/translate/blob/master/pytorch_translate/rnn_cell.py#L27:44
def MiLSTMCell(x, hx, cx, w_ih, w_hh, alpha, beta_i, beta_h, bias):
    Wx = x.mm(w_ih.t())
    Uz = hx.mm(w_hh.t())
    # Section 2.1 in https://arxiv.org/pdf/1606.06630.pdf
    gates = alpha * Wx * Uz + beta_i * Wx + beta_h * Uz + bias
    # Same as LSTMCell after this point
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
    ingate = ingate.sigmoid()
    forgetgate = forgetgate.sigmoid()
    cellgate = cellgate.tanh()
    outgate = outgate.sigmoid()
    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * cy.tanh()
    return hy, cy


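# Build random LSTM-cell inputs on the given device; an nn.LSTMCell is instantiated only to
# obtain parameters with the right shapes.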
def get_lstm_inputs(device, training=False, seq_length=None):
    input_shape = (3, 10) if seq_length is None else (seq_length, 3, 10)
    input = torch.randn(*input_shape, dtype=torch.float, device=device, requires_grad=training)
    hx = torch.randn(3, 20, dtype=torch.float, device=device, requires_grad=training)
    cx = torch.randn(3, 20, dtype=torch.float, device=device, requires_grad=training)
    module = nn.LSTMCell(10, 20).to(device, torch.float)  # Just to allocate weights with correct sizes
    if training:
        params = tuple(module.parameters())
    else:
        params = tuple(p.requires_grad_(False) for p in module.parameters())
    return (input, hx, cx) + params


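# Build random inputs and parameters matching the MiLSTMCell signature above.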
def get_milstm_inputs(device, training=False):
    minibatch = 3
    input_size = 10
    hidden_size = 20
    x = torch.randn(minibatch, input_size, device=device, dtype=torch.float)
    hx = torch.randn(minibatch, hidden_size, device=device, dtype=torch.float)
    cx = torch.randn(minibatch, hidden_size, device=device, dtype=torch.float)

    ih = torch.randn(4 * hidden_size, input_size, device=device, dtype=torch.float, requires_grad=training)
    hh = torch.randn(4 * hidden_size, hidden_size, device=device, dtype=torch.float, requires_grad=training)
    alpha = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training)
    ibeta = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training)
    hbeta = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training)
    bias = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training)
    return x, hx, cx, ih, hh, alpha, ibeta, hbeta, bias


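# Import a module from an arbitrary file path and return its `fn` attribute.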
def get_fn(file_name, script_path):
    import importlib.util
    spec = importlib.util.spec_from_file_location(file_name, script_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    fn = module.fn
    return fn

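# Return the grad executor state for the differentiable subgraph of an execution plan,
# optionally checking that the plan's graph really is a single differentiable graph.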
def get_grad_executor(plan_state, diff_graph_idx=None, skip_check=False):
    if diff_graph_idx is None:
        nodes = list(plan_state.graph.nodes())

        if not skip_check:
            nodes = list(filter(lambda n: n.kind() != "prim::BailOut" and n.kind() != "prim::BailoutTemplate", nodes))
            if len(nodes) == 1 or (len(nodes) == 2 and nodes[1].kind() == "prim::TupleConstruct"):
                pass
            elif len(nodes) == 2 and nodes[0].kind() == "prim::RequiresGradCheck" and nodes[1].kind() == "prim::If":
                pass
            else:
                raise RuntimeError("Can't get a grad_executor for a non-differentiable graph")
    grad_executors = list(plan_state.code.grad_executor_states())
    return grad_executors[diff_graph_idx or 0]


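# Collect copies of all backward graphs compiled for a scripted module's differentiable subgraph.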
def all_backward_graphs(script_module, diff_graph_idx=None):
    # Note: for Python 2 the order seems to be unstable
    ge_state = script_module.get_debug_state()
    fwd_plan = get_execution_plan(ge_state)
    grad_executor_state = get_grad_executor(fwd_plan, diff_graph_idx=diff_graph_idx)
    bwd_plans = list(grad_executor_state.execution_plans.values())
    return [p.graph.copy() for p in bwd_plans]


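# Return a copy of the backward graph for a scripted module's differentiable subgraph.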
def backward_graph(script_module, diff_graph_idx=None, skip_check=False):
    ge_state = script_module.get_debug_state()
    fwd_plan = get_execution_plan(ge_state)
    grad_executor_state = get_grad_executor(fwd_plan, diff_graph_idx=diff_graph_idx, skip_check=skip_check)
    bwd_plan = get_execution_plan(grad_executor_state)
    # Running JIT passes requires that we own the graph (with a shared_ptr).
    # The debug state struct does not own its graph so we make a copy of it.
    return bwd_plan.graph.copy()


# helper function to get sum of List[Tensor]
def _sum_of_list(tensorlist):
    s = 0
    for t in tensorlist:
        s += t.sum()
    return s


# has to be at top level or Pickle complains
class FooToPickle(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.bar = torch.jit.ScriptModule()


class TestJitProfiler(JitTestCase):
    """
    Runs tests that require setting global state such as torch._C._set_graph_executor_optimize
    and restoring it afterward, e.g. test_profiler. This addresses the flakiness reported in
    https://github.com/pytorch/pytorch/issues/91483, where test_profiler could fail partway
    through without restoring torch._C._set_graph_executor_optimize to its original value,
    breaking all tests that ran after it.

    A separate test class is used so that this setup and teardown does not need to run for
    every test in TestJit.
    """

    def setUp(self):
        super().setUp()
        self.graph_executor_optimize_opt = torch._C._get_graph_executor_optimize()

    def tearDown(self):
        super().tearDown()
        # Resetting
        torch._C._set_graph_executor_optimize(
            self.graph_executor_optimize_opt
        )

    def test_profiler(self):
        torch._C._set_graph_executor_optimize(False)

        def other_fn(x):
            return x * 2

        x = torch.rand(3, 4)
        traced_other_fn = torch.jit.trace(other_fn, x)

        def fn(x):
            y = traced_other_fn(x)
            fut = torch.jit._fork(traced_other_fn, x)
            y = torch.jit._wait(fut)
            return y

        traced_fn = torch.jit.trace(fn, x)
        with torch.autograd.profiler.profile() as prof:
            traced_fn(x)

        # Expect to see a TorchScript call to other_fn whose CPU time is at least
        # the aten::mul CPU time, plus a forked call to other_fn.

        mul_events = defaultdict(int)
        other_fn_events = defaultdict(int)
        for e in prof.function_events:
            if e.name == "aten::mul":
                self.assertTrue(e.thread not in mul_events)
                mul_events[e.thread] = e.time_range.elapsed_us()
            elif e.name == "other_fn":
                self.assertTrue(e.thread not in other_fn_events)
                other_fn_events[e.thread] = e.time_range.elapsed_us()

        self.assertTrue(len(mul_events) == 2)
        self.assertTrue(len(other_fn_events) == 2)

        for thread, mul_time in mul_events.items():
            self.assertTrue(thread in other_fn_events)
            self.assertTrue(other_fn_events[thread] >= mul_time)


class TestJit(JitTestCase):
    @unittest.skip("Requires a lot of RAM")
    def test_big(self):
        m = torch.jit.ScriptModule()
        gig = int(1024 * 1024 * 1024 / 4)
        # a small tensor in the first 4GB
        m.v0 = nn.Parameter(torch.full((2,), 1, dtype=torch.float))
        # a large tensor in the first 4GB that ends outside of it
        m.v1 = nn.Parameter(torch.full((5, gig), 2, dtype=torch.float))
        # a small tensor in >4GB space
        m.v2 = nn.Parameter(torch.full((2,), 3, dtype=torch.float))
        # a large tensor in the > 4GB space
        m.v3 = nn.Parameter(torch.full((5, gig), 4, dtype=torch.float))

        m2 = self.getExportImportCopy(m)

        self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))

    def test_inferred_as_tensor(self):
        with self.assertRaisesRegex(RuntimeError, "Inferred the value for argument 'dim' to be of type 'Tensor' "
                                                  "because it was not annotated with an explicit type"):
            @torch.jit.script
            def dot(points, query, dim):
                return (points * query).sum(dim)

    def test_constants_pkl(self):
        # This test asserts that the serialization archive includes a `constants.pkl`
        # file. This file is used by `torch.load` to determine whether a zip file
        # is a normal eager-mode serialization zip or a jit serialization zip. If
        # you are deleting `constants.pkl`, make sure to update `torch.serialization.load`
        # so it is still able to figure out which is which.
        @torch.jit.script
        def fn(x):
            return x

        buf = io.BytesIO()
        torch.jit.save(fn, buf)
        buf.seek(0)

        files = zipfile.ZipFile(buf).filelist
        self.assertTrue(any('archive/constants.pkl' == f.filename for f in files))

    def test_script_fn_pkl(self):
        with self.assertRaisesRegex(pickle.PickleError, "ScriptFunction cannot be pickled"):

            @torch.jit.script
            def fn(x: torch.Tensor) -> torch.Tensor:
                return x

            pkl_fn = pickle.dumps(fn, protocol=0)

    def test_restore_device(self):
        class M(torch.jit.ScriptModule):
            def __init__(self, cpu_device_str):
                super().__init__()
                self.p0 = nn.Parameter(torch.tensor([0.3], dtype=torch.float,
                                                    device=cpu_device_str))
                self.b0 = torch.tensor([0.9], dtype=torch.float,
                                       device=cpu_device_str)

        # main purpose is checking map_location works
        m = M("cpu")
        m2 = self.getExportImportCopy(m)
        self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
        self.assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
        self.assertFalse(m2.p0.is_cuda)
        self.assertFalse(m2.b0.is_cuda)

    @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA")
    def test_restore_device_cuda(self):
        class MyModule(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.b0 = nn.Buffer(torch.randn(1, 3))
                self.p0 = nn.Parameter(torch.randn(2, 3))

            @torch.jit.script_method
            def forward(self, x):
                return x + self.b0 + self.p0

        m = MyModule()
        m.cuda(torch.cuda.device_count() - 1)
        cuda_device_str = 'cuda:' + str(torch.cuda.device_count() - 1)

        self.assertTrue(m.p0.is_cuda)
        self.assertTrue(m.b0.is_cuda)

        # restore to the saved devices
        m2 = self.getExportImportCopy(m)
        self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
        self.assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
        self.assertEqual(str(m2.p0.device), cuda_device_str)
        self.assertEqual(str(m2.b0.device), cuda_device_str)

        # restore all to cpu using string
        cpu_device_str = 'cpu'
        m3 = self.getExportImportCopy(m, map_location=cpu_device_str)
        self.assertEqual(str(m3.p0.device), cpu_device_str)
        self.assertEqual(str(m3.b0.device), cpu_device_str)

        # restore all to first gpu using device
        m4 = self.getExportImportCopy(
            m3, map_location=torch.device('cuda:0'))
        self.assertEqual(str(m4.p0.device), 'cuda:0')
        self.assertEqual(str(m4.b0.device), 'cuda:0')

        # compute and compare the results
        input = torch.rand(2, 3).cuda(torch.cuda.device_count() - 1)
        origin_result = m(input)
        self.assertEqual(origin_result, m2(input))
        self.assertEqual(origin_result, m3(input.cpu()))
        self.assertEqual(origin_result, m4(input.cuda(0)))

    def test_trace_retains_train(self):
        class M(torch.nn.Module):
            def forward(self, x):
                return x
        m = M()
        m.eval()
        tm = torch.jit.trace(m, (torch.rand(3)))
        self.assertEqual(tm.training, m.training)

    @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA")
    def test_restore_shared_storage_on_cuda(self):
        class Foo(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                whole_tensor = torch.randn(4, 5, dtype=torch.float, device='cpu')
                self.p0 = nn.Parameter(whole_tensor.narrow(0, 0, 1))
                self.b0 = nn.Buffer(whole_tensor.narrow(0, 3, 1))

        m = Foo()
        m2 = self.getExportImportCopy(m, map_location=torch.device('cuda:0'))
        self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
        self.assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
        self.assertTrue(m2.p0.is_cuda)
        self.assertTrue(m2.b0.is_cuda)
        self.assertTrue(m2.p0.is_shared())
        self.assertTrue(m2.b0.is_shared())
        self.assertEqual(m2.b0.storage().data_ptr(), m2.p0.storage().data_ptr())

    def test_add_relu_fusion(self):
        class M(torch.nn.Module):
            def __init__(self, relu_op):
                super().__init__()
                self.relu_op = relu_op

            def forward(self, a, b, c):
                tmp = torch.add(a, b)
                x = self.relu_op(tmp)
                d = torch.add(a, c)
                return x + d
        a = torch.rand((7, 11))
        a = a * -10
        a = a + 5
        b = torch.rand((7, 11))
        c = torch.rand((7, 11))
        m = torch.jit.script(M(torch.relu))
        orig_res = m(a, b, c)
        torch._C._jit_pass_fuse_add_relu(m.graph)
        buffer = io.BytesIO()
        torch.jit.save(m, buffer)
        buffer.seek(0)
        m = torch.jit.load(buffer)
        new_res = m(a, b, c)
        FileCheck().check_not("aten::relu(") \
            .check("aten::_add_relu(") \
            .run(m.graph)
        torch.testing.assert_close(orig_res, new_res)

        # add, relu_
        a = torch.rand((7, 11))
        a = a * -10
        a = a + 5
        b = torch.rand((7, 11))
        c = torch.rand((7, 11))
        m = torch.jit.script(M(torch.relu_))
        orig_res = m(a, b, c)
        torch._C._jit_pass_fuse_add_relu(m.graph)
        buffer = io.BytesIO()
        torch.jit.save(m, buffer)
        buffer.seek(0)
        m = torch.jit.load(buffer)
        new_res = m(a, b, c)
        FileCheck().check_not("aten::relu_(") \
            .check("aten::_add_relu(") \
            .run(m.graph)
        torch.testing.assert_close(orig_res, new_res)

        class Madd_(torch.nn.Module):
            def __init__(self, relu_op):
                super().__init__()
                self.relu_op = relu_op

            def forward(self, a, b):
                x = a.add_(b)
                x = self.relu_op(x)
                return x

        # add_, relu_
        a = torch.rand((7, 11))
        a = a * -10
        a = a + 5
        b = torch.rand((7, 11))
        # Because in place add_ will overwrite a
        a_copy = a.clone()
        m = torch.jit.script(Madd_(torch.relu_))
        orig_res = m(a, b)
        torch._C._jit_pass_fuse_add_relu(m.graph)
        buffer = io.BytesIO()
        torch.jit.save(m, buffer)
        buffer.seek(0)
        m = torch.jit.load(buffer)
        new_res = m(a_copy, b)
        FileCheck().check_not("aten::add_(") \
            .check_not("aten::relu_(") \
            .check("aten::_add_relu_(") \
            .run(m.graph)
        torch.testing.assert_close(orig_res, new_res)
        # Since _add_relu_ does in-place mutation, ensure
        # a_copy is modified
        torch.testing.assert_close(orig_res, a_copy)

        class Madd_out(torch.nn.Module):
            def __init__(self, relu_op):
                super().__init__()
                self.relu_op = relu_op

            def forward(self, a, b):
                x = torch.add(a, b, out=a)
                x = self.relu_op(x)
                return x
        a = torch.rand((7, 11))
        a = a * -10
        a = a + 5
        b = torch.rand((7, 11))

        # add_out, relu_
        a = torch.rand((7, 11))
        a = a * -10
        a = a + 5
        b = torch.rand((7, 11))
        # Because in place add_ will overwrite a
        a_copy = a.clone()
        m = torch.jit.script(Madd_out(torch.relu_))
        orig_res = m(a, b)
        torch._C._jit_pass_fuse_add_relu(m.graph)
        buffer = io.BytesIO()
        torch.jit.save(m, buffer)
        buffer.seek(0)
        m = torch.jit.load(buffer)
        new_res = m(a_copy, b)
        FileCheck().check_not("aten::add(") \
            .check_not("aten::relu_(") \
            .check("aten::_add_relu(") \
            .run(m.graph)
        torch.testing.assert_close(orig_res, new_res)
        # Since _add_relu_ with out=a does in-place mutation, ensure
        # a_copy is modified
        torch.testing.assert_close(orig_res, a_copy)

    def test_repeat_interleave_script(self):
        def fn(input: torch.Tensor, repeats: torch.Tensor) -> torch.Tensor:
            output = input.repeat_interleave(repeats)
            return output
        fn_scripted = torch.jit.script(fn)

        input = torch.tensor([5, 7], dtype=torch.int64)
        repeats = torch.tensor([3, 6], dtype=torch.int64)

        output = fn(input, repeats)
        output_scripted = fn_scripted(input, repeats)
        self.assertEqual(output_scripted, output)

    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Simple executor doesn't have shape information")
    def test_peephole_optimize_shape_ops(self):
        def test_input(func, input, result):
            # if result == 2 we will trigger a bailout and
            # the unprofiled graph should return the correct result
            self.assertEqual(func(input, profile_and_replay=True), result)
            gre = func.graph_for(input)
            FileCheck().check_not("prim::If").run(gre)

        def test_dim():
            @torch.jit.script
            def func(x):
                if x.dim() == 1:
                    return 1
                else:
                    return 2

            test_input(func, torch.tensor([0.5]), 1)
            test_input(func, torch.tensor([[0.5]]), 2)
        test_dim()

        def test_size_index():
            @torch.jit.script
            def func(x):
                if x.size(0) == 1:
                    return 1
                else:
                    return 2

            test_input(func, torch.rand([1, 2]), 1)
            test_input(func, torch.rand([1, 3]), 1)

            @torch.jit.script
            def neg_index(x):
                if x.size(-2) == 1:
                    return 1
                else:
                    return 2

            test_input(neg_index, torch.rand([1, 2]), 1)
            test_input(neg_index, torch.rand([1, 3]), 1)

        if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
            test_size_index()

        def test_dtype():
            @torch.jit.script
            def func(x):
                if x.dtype == torch.float32:
                    return 1
                else:
                    return 2

            test_input(func, torch.tensor(0.5, dtype=torch.float32), 1)
            test_input(func, torch.tensor(0.5, dtype=torch.int64), 2)
        test_dtype()

        def test_is_floating_point():
            @torch.jit.script
            def func(x):
                if x.is_floating_point():
                    return 1
                else:
                    return 2

            test_input(func, torch.tensor(0.5, dtype=torch.float32), 1)
            test_input(func, torch.tensor(0.5, dtype=torch.int64), 2)
        test_is_floating_point()

        def test_device():
            @torch.jit.script
            def func_1(x):
                if x.device == torch.device('cuda:0'):
                    a = 0
                else:
                    a = 1
                return a

            @torch.jit.script
            def func_2(x):
                if x.is_cuda:
                    a = 0
                else:
                    a = 1
                return a

            test_input(func_1, torch.tensor(0.5), 1)
            test_input(func_2, torch.tensor(0.5), 1)

            if RUN_CUDA:
                test_input(func_1, torch.tensor(0.5, device="cuda:0"), 0)
                test_input(func_2, torch.tensor(0.5, device="cuda:0"), 0)

        test_device()

    def test_attrs(self):
        def foo(x):
            return (
                # x.dtype, TODO: dtype long -> instance conversion
                x.device,
                x.shape,
                x.is_cuda,
                x.is_mkldnn,
                x.is_quantized,
                x.requires_grad,
                x.T,
                x.mT,
                x.H,
                x.mH
                # x.layout TODO: layout long -> instance conversion
            )

        scripted = torch.jit.script(foo)
        x = torch.rand(3, 4)
        self.assertEqual(scripted(x), foo(x))

    def test_layout(self):
        @torch.jit.script
        def check(x, y):
            return x.layout == y.layout

        x = torch.rand(3, 4)
        y = torch.rand(3, 4)

        self.assertTrue(check(x, y))

    def test_matrix_transpose(self):
        @torch.jit.script
        def check(x):
            return torch.equal(x.mT, x.transpose(-2, -1))

        x = torch.rand(3, 4)
        self.assertTrue(check(x))

    def test_transpose(self):
        @torch.jit.script
        def check(x):
            return torch.equal(x.T, x.t())

        x = torch.rand(3, 4)
        self.assertTrue(check(x))

    def test_matrix_conj_transpose(self):
        @torch.jit.script
        def check(x):
            return torch.equal(x.mH, x.transpose(-2, -1).conj())

        x = torch.rand(3, 4)
        self.assertTrue(check(x))

        x = make_tensor((3, 4), device="cpu", dtype=torch.complex64)
        self.assertTrue(check(x))

    def test_conj_transpose(self):
        @torch.jit.script
        def check(x):
            return torch.equal(x.H, x.t().conj())

        x = torch.rand(3, 4)
        self.assertTrue(check(x))

        x = make_tensor((3, 4), device="cpu", dtype=torch.complex64)
        self.assertTrue(check(x))

    def test_T_mT_H_mH(self):
        def T(x):
            return x.T

        def mT(x):
            return x.mT

        def H(x):
            return x.H

        def mH(x):
            return x.mH

        x = torch.rand(3, 4)
        y = make_tensor((3, 4), device="cpu", dtype=torch.complex64)

        self.checkScript(T, (x, ))
        self.checkScript(mT, (x, ))
        self.checkScript(H, (x, ))
        self.checkScript(mH, (x, ))
        self.checkScript(T, (y, ))
        self.checkScript(mT, (y, ))
        self.checkScript(H, (y, ))
        self.checkScript(mH, (y, ))

    def test_nn_conv(self):
        class Mod(nn.Module):
            def __init__(self, conv):
                super().__init__()
                self.conv = conv

            def forward(self, input):
                return self.conv(input)

        inputs = [
            # Conv
            (Mod(nn.Conv1d(16, 33, 3, stride=2)), torch.randn(20, 16, 5)),
            (Mod(nn.Conv2d(16, 33, 3, stride=2)), torch.randn(20, 16, 5, 10)),
            (Mod(nn.Conv3d(16, 33, 3, stride=2)), torch.randn(20, 16, 3, 5, 4)),
            # ConvTranspose
            (Mod(nn.ConvTranspose1d(16, 33, 3, stride=2)), torch.randn(20, 16, 5)),
            (Mod(nn.ConvTranspose2d(16, 33, 3, stride=2)), torch.randn(20, 16, 5, 10)),
            (Mod(nn.ConvTranspose3d(16, 33, 3, stride=2)), torch.randn(20, 16, 3, 5, 4)),
        ]

        for m, inp in inputs:
            self.checkModule(m, (inp,))

    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, 'Not implemented for Simple or Legacy')
    def test_debug_flush_compilation_cache(self):
        def foo(x):
            return x + 2

        class Mod(nn.Module):
            def forward(self, t):
                return t + 2

        m = torch.jit.script(Mod())
        x = torch.rand(1, 10)

        with enable_profiling_mode_for_profiling_tests():
            jitted = self.checkScript(foo, (x,))
            # shouldn't throw
            states = jitted.get_debug_state()

            # after flushing there should be
            # no opt plan
            jitted._debug_flush_compilation_cache()
            with self.assertRaisesRegex(RuntimeError, "INTERNAL ASSERT FAILED"):
                states = jitted.get_debug_state()

            NUM_RUNS = 1
            with num_profiled_runs(NUM_RUNS):
                m(x)
                m(x)
                fwd = m._c._get_method("forward")
                states = m.get_debug_state()

                # after flushing there should be
                # no opt plan
                fwd._debug_flush_compilation_cache()
                with self.assertRaisesRegex(RuntimeError, "INTERNAL ASSERT FAILED"):
                    states = m.get_debug_state()

    def test_numel(self):
        @torch.jit.script
        def get_numel_script(x):
            return x.numel()

        x = torch.rand(3, 4)
        numel = get_numel_script(x)
        self.assertEqual(numel, x.numel())

    def test_element_size(self):
        @torch.jit.script
        def get_element_size_script(x):
            return x.element_size()

        x = torch.rand(3, 4)
        element_size = get_element_size_script(x)
        self.assertEqual(element_size, x.element_size())

    def test_Sequential(self):
        class Seq(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.seq = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 30))

            @torch.jit.script_method
            def forward(self, x):
                for l in self.seq:
                    x = l(x)
                return x

        m = torch.jit.script(Seq())
        assert m.graph  # ensure jit was able to compile

    def test_ModuleList(self):
        class Mod(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.model = nn.ModuleList([nn.Linear(10, 10) for _ in range(10)])
                self.model += (nn.Linear(10, 20),)
                self.model.append(nn.Linear(20, 30))
                self.model.extend([nn.Linear(30, 40), nn.Linear(40, 50)])

            def forward(self, v):
                for m in self.model:
                    v = m(v)
                return v

        m = torch.jit.script(Mod())
        assert m.graph  # ensure jit was able to compile

    def test_disabled(self):
        torch.jit._state.disable()
        try:
            def f(x, y):
                return x + y

            self.assertIs(torch.jit.trace(f, (torch.randn(2, 2), torch.randn(2, 2))), f)
            self.assertIs(torch.jit.script(f), f)

            class MyModule(torch.jit.ScriptModule):
                @torch.jit.script_method
                def method(self, x):
                    return x

            # XXX: Unfortunately ScriptModule won't simply become Module now,
            # because that requires disabling the JIT at startup time, which
            # we can't do here.
            # We need to OR these two conditions to make it work with all versions of Python.
            self.assertTrue(inspect.ismethod(MyModule.method) or inspect.isfunction(MyModule.method))
        finally:
            torch.jit._state.enable()

    def test_train_eval(self):
        class Sub(nn.Module):
            def forward(self, input):
                if self.training:
                    return input
                else:
                    return -input

        class MyModule(torch.jit.ScriptModule):
            def __init__(self, module):
                super().__init__()
                self.module = module

            @torch.jit.script_method
            def forward(self, input):
                return self.module(input) + 1

        m = MyModule(Sub())
        input = torch.rand(3, 4)
        self.assertEqual(input + 1, m(input))
        m.eval()
        self.assertEqual(-input + 1, m(input))

        # test batchnorm and dropout train/eval
        input = torch.randn(6, 10)
        batchnorm = nn.BatchNorm1d(10)
        dropout = nn.Dropout(p=0.2)

        m_batchnorm = MyModule(batchnorm)
        self.assertEqual(batchnorm(input) + 1, m_batchnorm(input))
        batchnorm.eval()
        m_batchnorm.eval()
        self.assertEqual(batchnorm(input) + 1, m_batchnorm(input))

        m_dropout = MyModule(dropout)
        dropout.eval()
        m_dropout.eval()
        self.assertEqual(dropout(input) + 1, m_dropout(input))

    def test_nn_lp_pool2d(self):
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l = torch.nn.LPPool2d(2, 3)
                self.n = torch.nn.LPPool2d(2, (7, 1))

            def forward(self, x):
                return (self.l(x),
                        self.n(x),
                        torch.nn.functional.lp_pool2d(x, float(2), 3),
                        torch.nn.functional.lp_pool2d(x, 2, 3),
                        torch.nn.functional.lp_pool2d(x, float(2), (7, 1)))

        self.checkModule(Mod(), (torch.rand(1, 3, 7, 7),))

    def test_nn_lp_pool1d(self):
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l = torch.nn.LPPool1d(2, 3)
                self.n = torch.nn.LPPool1d(2, 7)

            def forward(self, x):
                return (self.l(x),
                        self.n(x),
                        torch.nn.functional.lp_pool1d(x, float(2), 3),
                        torch.nn.functional.lp_pool1d(x, 2, 3),
                        torch.nn.functional.lp_pool1d(x, float(2), 7))

        self.checkModule(Mod(), (torch.rand(1, 3, 7),))

    def test_nn_padding_functional(self):
        class Mod(nn.Module):
            def __init__(self, *pad):
                super().__init__()
                self.pad = pad

            def forward(self, x):
                return F.pad(x, self.pad, mode='constant', value=3.5)

        inputs = [
            (Mod(1, 2), torch.randn(1, 3, 4)),  # 1D
            (Mod(1, 2, 3, 4), torch.randn(1, 3, 4)),  # 2D
            (Mod(1, 2, 3, 4, 5, 6), torch.randn(1, 3, 4)),  # 3D
        ]

        for m, inp in inputs:
            self.checkModule(m, (inp,))

    def test_nn_padding(self):
        class Mod(nn.Module):
            def __init__(self, padding):
                super().__init__()
                self.padding = padding

            def forward(self, input):
                return self.padding(input)

        inputs = [
            (Mod(nn.ConstantPad1d(2, 3.5)), torch.randn(1, 2, 4)),
            (Mod(nn.ConstantPad2d(2, 3.5)), torch.randn(1, 2, 2)),
            (Mod(nn.ConstantPad3d(3, 3.5)), torch.randn(16, 3, 10, 20, 30)),
            (Mod(nn.ReflectionPad1d(2)), torch.arange(8, dtype=torch.float).reshape(1, 2, 4)),
            (Mod(nn.ReflectionPad2d(2)), torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)),
            (Mod(nn.ReflectionPad3d(3)), torch.randn(16, 3, 8, 32, 48)),
            (Mod(nn.ReplicationPad1d(2)), torch.arange(8, dtype=torch.float).reshape(1, 2, 4)),
            (Mod(nn.ReplicationPad2d(2)), torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)),
            (Mod(nn.ReplicationPad3d(3)), torch.randn(16, 3, 8, 32, 48)),
            (Mod(nn.ZeroPad2d(2)), torch.randn(1, 1, 3, 3))
        ]

        for m, inp in inputs:
            self.checkModule(m, (inp,))

    def test_script_autograd_grad(self):
        def test_simple_grad(x, y):
            # type: (Tensor, Tensor) -> List[Optional[Tensor]]
            z = x + 2 * y + x * y
            return torch.autograd.grad((z.sum(), ), (x, y))

        def test_simple_grad_with_grad_outputs(x, y):
            # type: (Tensor, Tensor) -> List[Optional[Tensor]]
            z = x + 2 * y + x * y
            grad_outputs = torch.jit.annotate(List[Optional[torch.Tensor]], [torch.ones((2, 2)), ])
            return torch.autograd.grad((z, ), (x, y), grad_outputs)

        def test_one_output_not_requires_grad(x, y):
            # type: (Tensor, Tensor) -> List[Optional[Tensor]]
            z = 2 * y + y
            return torch.autograd.grad((z.sum(),), (x, y), allow_unused=True)

        def test_retain_graph(x, y):
            # type: (Tensor, Tensor) -> None
            z = x + 2 * y + x * y
            torch.autograd.grad((z.sum(), ), (x, y), retain_graph=True)
            torch.autograd.grad((z.sum(), ), (x, y))

        x = torch.randn(2, 2, requires_grad=True)
        y = torch.randn(2, 2, requires_grad=True)
        self.checkScript(test_simple_grad, (x, y), inputs_requires_grad=True)
        self.checkScript(test_simple_grad_with_grad_outputs, (x, y), inputs_requires_grad=True)
        self.checkScript(test_one_output_not_requires_grad, (x, y), inputs_requires_grad=True)
        self.checkScript(test_retain_graph, (x, y), inputs_requires_grad=True)

    def test_script_backward(self):
        def checkBackwardScript(fn, inputs):
            scripted_fn = torch.jit.script(fn)
            FileCheck().check("torch.autograd.backward").run(scripted_fn.code)
            recording_inputs = do_input_map(lambda t: t.detach().requires_grad_(), inputs)

            fn(*inputs)
            scripted_fn(*recording_inputs)

            for inp1, inp2 in zip(inputs, recording_inputs):
                self.assertEqual(inp1.grad, inp2.grad)

        def test_tensor_backward(input):
            # type: (Tensor) -> None
            output = torch.relu(input)
            output = output.softmax(0)
            sum_out = output.sum()
            sum_out.backward()

        def test_torch_autograd_backward(input):
            # type: (Tensor) -> None
            output = torch.relu(input)
            output = output.softmax(0)
            torch.autograd.backward(output.sum())

        def test_torch_autograd_backward_with_grad_tensors(input):
            # type: (Tensor) -> None
            output = torch.relu(input)
            output = output.softmax(0)
            grad_outputs = torch.jit.annotate(List[Optional[torch.Tensor]], [torch.ones((2, 2)), ])
            torch.autograd.backward((output,), grad_outputs)

        inp = torch.randn(2, 2, requires_grad=True)
        checkBackwardScript(test_tensor_backward, (inp,))
        checkBackwardScript(test_torch_autograd_backward, (inp,))
        checkBackwardScript(test_torch_autograd_backward_with_grad_tensors, (inp,))

    def test_script_backward_twice(self):
        def checkBackwardTwiceScript(fn, inputs, retain_graph_=False):
            class jit_profiling_executor_false:
                def __enter__(self):
                    torch._C._jit_set_profiling_executor(False)

                def __exit__(self, *args):
                    torch._C._jit_set_profiling_executor(GRAPH_EXECUTOR != ProfilingMode.LEGACY)

            with jit_profiling_executor_false(), torch.jit.optimized_execution(True):
                scripted_fn = torch.jit.script(fn, inputs)
                FileCheck().check("prim::DifferentiableGraph").run(scripted_fn.graph_for(*inputs))

                result = scripted_fn(*inputs)
                result.sum().backward(retain_graph=retain_graph_)
                if not retain_graph_:
                    self.assertRaisesRegex(RuntimeError, 'Specify retain_graph=True',
                                           lambda: result.sum().backward())
                else:
                    result.sum().backward()

        def test_script_backward_twice_with_saved_values(input1, input2):
            # type: (Tensor, Tensor) -> Tensor
            tmp1 = torch.mul(input1, input2)
            tmp2 = torch.abs(tmp1)
            if torch.equal(input1, input2):
                tmp2 = torch.acos(tmp2)
            else:
                tmp2 = torch.atan(tmp2)
            result = torch.add(tmp2, input2)
            return result

        inp1 = torch.randn(2, 2, requires_grad=True)
        inp2 = torch.randn(2, 2, requires_grad=True)
        checkBackwardTwiceScript(test_script_backward_twice_with_saved_values, (inp1, inp2), False)
        checkBackwardTwiceScript(test_script_backward_twice_with_saved_values, (inp1, inp2), True)

    def test_diff_subgraph_clones_constants(self):
        @torch.jit.script
        def f(x, y):
            return x + x + y + x + y + x + y + x + y + x

        def count_constants(graph):
            return sum(node.kind() == 'prim::Constant' for node in graph.nodes())

        graph = f.graph.copy()
        self.run_pass('cse', graph)
        self.run_pass('create_autodiff_subgraphs', graph)
        nodes = list(graph.nodes())
        self.assertEqual(count_constants(graph), 1)
        self.assertEqual(count_constants(nodes[1].g('Subgraph')), 1)

    # TODO: adapt this test to check that GraphExecutor treats them differently
    @unittest.skip("Need to be adjusted to Graph Executor")
    def test_arg_configurations(self):
        """Different arg configurations should trigger different traces"""
        x = Variable(torch.FloatTensor(4, 4).uniform_())
        x_double = Variable(x.data.double())
        x_grad = Variable(x.data.clone(), requires_grad=True)
        y = Variable(torch.randn(4))

        configurations = [
            (x,),
            (x_double,),
            (x_grad,),
            (y,),
            ([x, x],),
            ([x, y],),
        ]
        if torch.cuda.is_available():
            x_cuda = Variable(x.data.cuda())
            configurations += [
                (x_cuda,),
                ([x, x_cuda],),
                ([x_cuda, x],),
                ([[x_cuda, x]],),
            ]
            if torch.cuda.device_count() > 1:
                x_cuda_1 = Variable(x.data.cuda(1))
                configurations += [
                    (x_cuda_1,),
                    ([x_cuda, x_cuda_1],),
                ]

        @torch.jit.compile(nderivs=0)
        def fn(*args):
            in_vars, _ = torch._C._jit_flatten(args)
            return in_vars[0] + 1

        for i, config in enumerate(configurations):
            self.assertFalse(fn.has_trace_for(*config))
            fn(*config)
            self.assertTrue(fn.has_trace_for(*config))
            for unk_config in configurations[i + 1:]:
                self.assertFalse(fn.has_trace_for(*unk_config))
        self.assertEqual(fn.hits, 0)

    def test_torch_sum(self):
        def fn(x):
            return torch.sum(x)

        def fn1(x, dim: int):
            return torch.sum(x, dim)

        x = torch.randn(3, 4)
        self.checkScript(fn, (x, ))
        self.checkScript(fn1, (x, 1, ))
        self.checkScript(fn1, (x, 0, ))

    def test_cse(self):
        x = torch.tensor([0.4, 0.3], requires_grad=True)
        y = torch.tensor([0.7, 0.5], requires_grad=True)

        def fn(x, y):
            w = (x + y) * (x + y) * (x + y)
            t = torch.tanh(w) + torch.tanh(w)
            z = (x + y) * (x + y) * (x + y) + t
            return z

        g, _ = torch.jit._get_trace_graph(fn, (x, y))
        self.run_pass('cse', g)
        do_exactly = True
        FileCheck().check_count("add", 1).check_count("mul", 2, do_exactly) \
            .check_count("tanh", 1, do_exactly).check_count("add", 2, do_exactly).check_next("return")  \
            .run(str(g))

        self.assertExportImport(g, (x, y))

    def test_cse_not_introduce_aliasing(self):
        @torch.jit.script
        def tensor_alias_outputs(x):
            return x + x, x + x

        self.run_pass('cse', tensor_alias_outputs.graph)
        FileCheck().check_count("aten::add", 2).run(tensor_alias_outputs.graph)

        @torch.jit.script
        def ints_alias_outputs(x):
            # type: (int) -> Tuple[int, int]
            return x + x, x + x

        # non-aliasing types can be CSEd
        self.run_pass('cse', ints_alias_outputs.graph)
        FileCheck().check_count("aten::add", 1, exactly=True).run(ints_alias_outputs.graph)

    def test_recursive_cse(self):
        input_str = """
graph(%x : Tensor,
      %y : Tensor,
      %20 : int):
  %2 : int = prim::Constant[value=1]()
  %3 : Tensor = aten::add(%x, %y, %2)
  %4 : int = aten::add(%2, %20)
  %5 : bool = aten::Bool(%4)
  %z : int = prim::If(%5)
    # CHECK: block
    block0():
      # CHECK-NOT: aten::add
      %z.1 : int = aten::add(%2, %20)
      -> (%z.1)
    block1():
      -> (%2)
  return (%z)
"""
        graph = parse_ir(input_str)
        self.run_pass('cse', graph)
        FileCheck().run(input_str, graph)

1352    def test_pattern_based_rewrite(self):
1353        # mul(mul(mul(mul(x,y),z),x),y) --> mul(mul(mulmul(x,y,z), x), y) -->
1354        # --> mulmul(mulmul(x,y,z), x, y)
1355        input_str = """
1356graph(%x, %y, %z):
1357    # CHECK-NOT: aten::mul
1358    # CHECK: my::fused_mulmul
1359    %t = aten::mul(%x, %y)
1360    %p = aten::mul(%t, %z)
1361    # CHECK: my::fused_mulmul
1362    %u = aten::mul(%p, %x)
1363    %o = aten::mul(%u, %y)
1364    return (%o)"""
1365        graph = parse_ir(input_str)
1366        torch._C._jit_pass_custom_pattern_based_rewrite_graph("""
1367graph(%a, %b, %c):
1368  %q = aten::mul(%a, %b)
1369  %r = aten::mul(%q, %c)
1370  return (%r)""", """
1371graph(%a, %b, %c):
1372  %r = my::fused_mulmul(%a, %b, %c)
1373  return (%r)""", graph)
1374        FileCheck().run(input_str, graph)
1375
1376        # Check that overlapping matches are handled correctly
1377        # mul(mul(mul(x,y),z),x) --> mul(mulmul(x,y,z), x)
1378        input_str = """
1379graph(%x, %y, %z):
1380    # CHECK-NOT: aten::mul
1381    # CHECK: my::fused_mulmul
1382    %t = aten::mul(%x, %y)
1383    %p = aten::mul(%t, %z)
1384    # CHECK-NEXT: aten::mul
1385    %u = aten::mul(%p, %x)
1386    return (%u)"""
1387        graph = parse_ir(input_str)
1388        torch._C._jit_pass_custom_pattern_based_rewrite_graph("""
1389graph(%a, %b, %c):
1390  %q = aten::mul(%a, %b)
1391  %r = aten::mul(%q, %c)
1392  return (%r)""", """
1393graph(%a, %b, %c):
1394  %r = my::fused_mulmul(%a, %b, %c)
1395  return (%r)""", graph)
1396        FileCheck().run(input_str, graph)
1397
1398        # Check add(mul(x,y),z) --> muladd(x,y,z) replacement
1399        input_str = """
1400graph(%x, %y, %z):
1401    # CHECK-NOT: aten::mul
1402    # CHECK-NOT: aten::add
1403    %c = prim::Const[value=1]()
1404    %t = aten::mul(%x, %y)
1405    %p = aten::add(%t, %z, %c)
1406    # CHECK: my::muladd
1407    # CHECK-NEXT: return
1408    return (%p)"""
1409        graph = parse_ir(input_str)
1410        torch._C._jit_pass_custom_pattern_based_rewrite_graph("""
1411graph(%a, %b, %c, %d):
1412  %q = aten::mul(%a, %b)
1413  %r = aten::add(%q, %c, %d)
1414  return (%r)""", """
1415graph(%a, %b, %c, %d):
1416  %r = my::muladd(%a, %b, %c, %d)
1417  return (%r)""", graph)
1418        FileCheck().run(input_str, graph)
1419
1420        # Check add(mul(x,y),z) --> sub(add(x,y),z) replacement
1421        input_str = """
1422graph(%x, %y, %z):
1423    # CHECK-NOT: aten::mul
1424    %c = prim::Const[value=1]()
1425    # CHECK: aten::add
1426    %t = aten::mul(%x, %y)
1427    # CHECK-NEXT: aten::sub
1428    %p = aten::add(%t, %z, %c)
1429    # CHECK-NOT: aten::add
1430    # CHECK-NEXT: return
1431    return (%p)"""
1432        graph = parse_ir(input_str)
1433        torch._C._jit_pass_custom_pattern_based_rewrite_graph("""
1434graph(%a, %b, %c, %d):
1435  %q = aten::mul(%a, %b)
1436  %r = aten::add(%q, %c, %d)
1437  return (%r)""", """
1438graph(%a, %b, %c, %d):
1439  %q = aten::add(%a, %b, %d)
1440  %r = aten::sub(%q, %c, %d)
1441  return (%r)""", graph)
1442        FileCheck().run(input_str, graph)
1443
1444        # Check mul(x,y) --> x replacement
1445        input_str = """
1446graph(%x, %y, %z):
1447    %c = prim::Const[value=1]()
1448    # CHECK-NOT: aten::mul
1449    %t = aten::mul(%x, %y)
1450    # CHECK: aten::add(%x, %z
1451    %p = aten::add(%t, %z, %c)
1452    # CHECK-NEXT: return
1453    return (%p)"""
1454        graph = parse_ir(input_str)
1455        torch._C._jit_pass_custom_pattern_based_rewrite_graph("""
1456graph(%Pa, %Pb):
1457  %Pq = aten::mul(%Pa, %Pb)
1458  return (%Pq)""", """
1459graph(%Ra, %Rb):
1460  return (%Ra)""", graph)
1461        FileCheck().run(input_str, graph)
1462
1463    @_tmp_donotuse_dont_inline_everything
1464    def test_pattern_based_module_rewrite(self):
1465        # Check match::module behavior
1466        class Test(torch.nn.Module):
1467            def __init__(self) -> None:
1468                super().__init__()
1469                self.conv = torch.nn.Conv2d(1, 20, 5, 1)
1470                self.bn = torch.nn.BatchNorm2d(num_features=20)
1471
1472            def forward(self, x):
1473                x = self.conv(x)
1474                x = self.bn(x)
1475                return x
1476        m = torch.jit.script(Test())
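            # match::module[name=...] is intended to match the fetch of a submodule of that
            # class from %self, so the pattern below can span the conv and bn forward calls.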
1477        torch._C._jit_pass_custom_pattern_based_rewrite_graph("""
1478        graph(%self, %x):
1479                %conv = match::module[name="Conv2d"](%self)
1480                %y = prim::CallMethod[name="forward"](%conv, %x)
1481                %bn = match::module[name="BatchNorm2d"](%self)
1482                %z = prim::CallMethod[name="forward"](%bn, %y)
1483                return (%z)""", """
1484        graph(%self, %x):
1485          %z = my::matched_conv_bn(%self, %x)
1486          return (%z)""", m._c._get_method("forward").graph)
1487
1488        FileCheck().check("my::matched_conv_bn").run(m._c._get_method("forward").graph)
1489
1490    def test_pattern_based_rewrite_with_source_range_preserved(self):
1491        class TestModule1(torch.nn.Module):
1492            def forward(self, x, y, z, w):
1493                x = x + y
1494                x = x * z
1495                return w - x
1496
1497        input_pattern = """
1498        graph(%x, %y, %z, %const):
1499            %t = aten::add(%x, %y, %const)
1500            %o = aten::mul(%t, %z)
1501            return (%o)"""
1502        replacement_pattern = """
1503        graph(%x, %y, %z, %const):
1504            %o = my::add_mul(%x, %y, %z, %const)
1505            return (%o)"""
1506        scripted_model = torch.jit.script(TestModule1())
1507        graph = scripted_model.graph
1508        value_mappings = [("o", "t")]
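            # Each (replacement value, pattern value) pair lets the rewrite pass carry the
            # matched node's source range over to the newly created node (asserted below).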
1509        for node in graph.nodes():
1510            if node.kind() == "aten::add":
1511                source_range_1 = node.sourceRange()
1512        torch._C._jit_pass_custom_pattern_based_rewrite_graph(
1513            input_pattern, replacement_pattern, scripted_model.graph, value_name_pairs=value_mappings)
1514        graph = scripted_model.graph
1515        for node in graph.nodes():
1516            if node.kind() == "my::add_mul":
1517                source_range_2 = node.sourceRange()
1518        self.assertTrue(source_range_1 == source_range_2)
1519
1520        class TestModule2(torch.nn.Module):
1521            def forward(self, x, y, z, w):
1522                x = x + y
1523                x = x + z
1524                x = x * z
1525                x = x * w
1526                return x - 2
1527
1528        # Check source range preservation when two add nodes are each transformed: add -> my::add
1529        input_pattern = """
1530        graph(%x, %y, %const):
1531            %o = aten::add(%x, %y, %const)
1532            return (%o)"""
1533        replacement_pattern = """
1534        graph(%x, %y, %const):
1535            %o = my::add(%x, %y, %const)
1536            return (%o)"""
1537        scripted_model = copy.deepcopy(torch.jit.script(TestModule2()))
1538        graph_copy = scripted_model.graph.copy()
1539        value_mappings = [("o", "o")]
1540        source_range_add_1 = None
1541        for node in graph_copy.nodes():
1542            if source_range_add_1 is None and node.kind() == "aten::add":
1543                source_range_add_1 = node.sourceRange()
1544            if source_range_add_1 is not None and node.kind() == "aten::add":
1545                source_range_add_2 = node.sourceRange()
1546        torch._C._jit_pass_custom_pattern_based_rewrite_graph(
1547            input_pattern, replacement_pattern, graph_copy, value_name_pairs=value_mappings)
1548        source_range_my_add_1 = None
1549        for node in graph_copy.nodes():
1550            if source_range_my_add_1 is None and node.kind() == "my::add":
1551                source_range_my_add_1 = node.sourceRange()
1552            if source_range_my_add_1 is not None and node.kind() == "my::add":
1553                source_range_my_add_2 = node.sourceRange()
1554        self.assertTrue(source_range_add_1 == source_range_my_add_1)
1555        self.assertTrue(source_range_add_2 == source_range_my_add_2)
1556
1557        # Check source range preservation for the add-add -> double_add transform
1558        # (fusing two nodes into one)
1559        input_pattern = """
1560        graph(%x, %y, %z, %const):
1561            %t = aten::add(%x, %y, %const)
1562            %o = aten::add(%t, %z, %const)
1563            return (%o)"""
1564        replacement_pattern = """
1565        graph(%x, %y, %z, %const):
1566            %o = my::double_add(%x, %y, %z, %const)
1567            return (%o)"""
1568        scripted_model = torch.jit.script(TestModule2())
1569        graph_copy = scripted_model.graph.copy()
1570        value_mappings = [("o", "t")]
1571        source_range_1 = None
1572        source_range_2 = None
1573        for node in graph_copy.nodes():
1574            if node.kind() == "aten::add":
1575                source_range_1 = node.sourceRange()
1576                break
1577        torch._C._jit_pass_custom_pattern_based_rewrite_graph(
1578            input_pattern, replacement_pattern, graph_copy, value_name_pairs=value_mappings)
1579        for node in graph_copy.nodes():
1580            if node.kind() == "my::double_add":
1581                source_range_2 = node.sourceRange()
1582        self.assertTrue(source_range_1 == source_range_2)
1583
1584        # Check source range preservation for the mul -> add + add transform
1585        # (splitting one node into two)
1586        input_pattern = """
1587        graph(%x, %y):
1588            %t = aten::mul(%x, %y)
1589            return (%t)"""
1590        replacement_pattern = """
1591        graph(%x, %y):
1592            %t = my::add(%x, %y)
1593            %o = my::add(%t, %y)
1594            return (%o)"""
1595        scripted_model = torch.jit.script(TestModule2())
1596        graph_copy = scripted_model.graph.copy()
1597        value_mappings = [("t", "t"), ("o", "t")]
1598        source_range_mul_1 = None
1599        for node in graph_copy.nodes():
1600            if source_range_mul_1 is None and node.kind() == "aten::mul":
1601                source_range_mul_1 = node.sourceRange()
1602            if source_range_mul_1 is not None and node.kind() == "aten::mul":
1603                source_range_mul_2 = node.sourceRange()
1604        torch._C._jit_pass_custom_pattern_based_rewrite_graph(
1605            input_pattern, replacement_pattern, graph_copy, value_name_pairs=value_mappings)
1606        source_range_add_1 = None
1607        for node in graph_copy.nodes():
1608            if source_range_add_1 is None and node.kind() == "my::add":
1609                source_range_add_1 = node.sourceRange()
1610            if source_range_add_1 is not None and node.kind() == "my::add":
1611                source_range_add_2 = node.sourceRange()
1612        self.assertTrue(source_range_mul_1 == source_range_add_1)
1613        self.assertTrue(source_range_mul_2 == source_range_add_2)
1614
1615        # Check lack of source range preservation for the mul-mul -> double_mul transform (no value_name_pairs given)
1616        input_pattern = """
1617        graph(%x, %y, %z):
1618            %t = aten::mul(%x, %y)
1619            %o = aten::mul(%t, %z)
1620            return (%o)"""
1621        replacement_pattern = """
1622        graph(%x, %y, %z):
1623            %o = my::double_mul(%x, %y, %z)
1624            return (%o)"""
1625        scripted_model = torch.jit.script(TestModule2())
1626        graph_copy = scripted_model.graph.copy()
1627        for node in graph_copy.nodes():
1628            if node.kind() == "aten::mul":
1629                source_range_1 = node.sourceRange()
1630        torch._C._jit_pass_custom_pattern_based_rewrite_graph(input_pattern, replacement_pattern, graph_copy)
1631        for node in graph_copy.nodes():
1632            if node.kind() == "my::double_mul":
1633                source_range_2 = node.sourceRange()
1634        self.assertFalse(source_range_1 == source_range_2)
1635
1636    def test_expand_quantlint(self):
1637        pass
1638
1639    def test_expand_fold_quant_inputs(self):
1640        pass
1641
1642    def test_shape_analysis_broadcast(self):
1643        def broadcast(a, b):
1644            return a + b
1645
1646        x = torch.randn(3, 1, 5, requires_grad=True)
1647        y = torch.randn(4, 1, 8, 5, requires_grad=True)
1648
1649        graph = torch.jit.script(broadcast).graph
1650        torch._C._jit_pass_complete_shape_analysis(graph, (x, y), False)
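            # Broadcasting (3, 1, 5) with (4, 1, 8, 5) yields shape (4, 3, 8, 5), which shape
            # analysis should have recorded on the output.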
1651        FileCheck().check("Float(4, 3, 8, 5, strides=[120, 40, 5, 1], device=cpu)").run(str(graph))
1652
1653    def test_shape_analysis_unsqueeze_in_loop(self):
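            # aten::unsqueeze changes the rank of the loop-carried value on every iteration,
            # so shape analysis can only pin down dtype and device (FloatTensor on cpu),
            # which is what the CHECK annotations expect.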
1654        input_str = """graph(%x.1 : Tensor):
1655          %4 : bool = prim::Constant[value=1]()
1656          %1 : int = prim::Constant[value=2]()
1657          %7 : int = prim::Constant[value=0]()
1658          # CHECK: FloatTensor(requires_grad=0, device=cpu) = prim::Loop
1659          %x : Tensor = prim::Loop(%1, %4, %x.1)
1660            # CHECK: : FloatTensor(requires_grad=0, device=cpu)):
1661            block0(%i : int, %x.6 : Tensor):
1662              # CHECK: FloatTensor(requires_grad=0, device=cpu) = aten::unsqueeze
1663              %x.3 : Tensor = aten::unsqueeze(%x.6, %7)
1664              -> (%4, %x.3)
1665          return (%x)"""
1666        graph = parse_ir(input_str)
1667        torch._C._jit_pass_complete_shape_analysis(graph, (torch.zeros(2, 2, dtype=torch.float32),), False)
1668        FileCheck().run(input_str, graph)
1669
1670    def test_script_tensor_type(self):
1671        def foo(x, t: torch.dtype):
1672            return x.type(t)
1673        scr = torch.jit.script(foo)
1674        x = torch.rand(3, 4)
1675        for t in [torch.int8, torch.float64, torch.float32,
1676                  torch.bfloat16, torch.complex64, torch.complex128, torch.bool]:
1677            self.assertEqual(scr(x, t), foo(x, t))
1678
1679    def test_script_bool_literal_conversion(self):
1680        def foo(x):
1681            return torch.mul(x, True)
1682        scr = torch.jit.script(foo)
1683        x = torch.rand(3, 4)
1684        self.assertEqual(scr(x), foo(x))
1685
1686    def test_shape_analysis_masked_select(self):
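            # The number of elements masked_select returns depends on the runtime values of the
            # mask, so only a 1-D tensor of unknown length (Float(*, ...)) can be inferred.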
1687        input_str = """graph(%0 : Float(),
1688          %1 : Bool()):
1689          # CHECK: Float(*, requires_grad=0, device=cpu) = aten::masked_select
1690          %2 : Tensor = aten::masked_select(%0, %1) # test/test_jit.py:15261:0
1691          return (%2)"""
1692        graph = parse_ir(input_str)
1693        x = torch.ones(1, dtype=torch.float32)[0]
1694        mask = x.ge(0.5)
1695        torch._C._jit_pass_complete_shape_analysis(graph, (x, mask), False)
1696        FileCheck().run(input_str, graph)
1697
1698    # TODO: update verify to work with GraphExecutors
1699    @unittest.skip("verify needs to be updated to work with GraphExecutors")
1700    def test_verify(self):
1701        x = torch.tensor([0.4], requires_grad=True)
1702        y = torch.tensor([0.7], requires_grad=True)
1703
1704        @torch.jit.compile
1705        def f(x, y):
1706            z = torch.sigmoid(x * (x + y))
1707            w = torch.abs(x * x * x + y) + Variable(torch.ones(1))
1708            return z, w
1709
1710        torch.jit.verify(f, (x, y), loss_fn=lambda z, w: z * w, devices=[])
1711
1712    # TODO: adapt to a GraphExecutor test
1713    @unittest.skip("Need to instrument GraphExecutors a bit more")
1714    def test_flags(self):
1715        x, y = torch.randn(2, 2)
1716        y = Variable(torch.randn(2, 2))
1717
1718        @torch.jit.compile
1719        def fn(x, y):
1720            return (x * x + y * y + x * y).sum()
1721
1722        grads = {}
1723        for rx, ry in product((True, False), repeat=2):
1724            x.requires_grad = rx
1725            y.requires_grad = ry
1726
1727            self.assertFalse(fn.has_trace_for(x, y))
1728            out = fn(x, y)
1729
1730            self.assertFalse(fn.has_trace_for(x, y))
1731            for v, name, compute in [(x, 'x', rx), (y, 'y', ry)]:
1732                if not compute:
1733                    continue
1734                grad_v, = torch.autograd.grad(out, v, retain_graph=True)
1735                expected_grad = grads.setdefault(name, grad_v)
1736                self.assertEqual(grad_v, expected_grad)
1737            self.assertEqual(fn.has_trace_for(x, y), rx or ry)
1738
1739    def test_python_ir(self):
1740        x = torch.tensor([0.4], requires_grad=True)
1741        y = torch.tensor([0.7], requires_grad=True)
1742
1743        def doit(x, y):
1744            return torch.sigmoid(torch.tanh(x * (x + y)))
1745
1746        g, _ = torch.jit._get_trace_graph(doit, (x, y))
1747        self.run_pass('dce', g)
1748        self.run_pass('canonicalize', g)
1749        g2 = torch._C.Graph()
1750        g_to_g2 = {}
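            # maps each value in g to its clone in g2 so cloned nodes can be wired up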
1751        for node in g.inputs():
1752            g_to_g2[node] = g2.addInput()
1753        for node in g.nodes():
1754            n_ = g2.createClone(node, lambda x: g_to_g2[x])
1755            g2.appendNode(n_)
1756            for o, no in zip(node.outputs(), n_.outputs()):
1757                g_to_g2[o] = no
1758
1759        for node in g.outputs():
1760            g2.registerOutput(g_to_g2[node])
1761
1762        t_node = g2.create("prim::TensorTest").t_("a", torch.ones([2, 2]))
1763        self.assertEqual(t_node.attributeNames(), ["a"])
1764        g2.appendNode(t_node)
1765        self.assertTrue(torch.equal(torch.ones(2, 2), t_node.t("a")))
1766        for node in g.nodes():
1767            self.assertTrue(g2.findNode(node.kind()) is not None)
1768
1769    @unittest.skipIf(IS_SANDCASTLE, "gtest runs these in sandcastle")
1770    @unittest.skipIf(RUN_CUDA, "covered by test_cpp_cuda")
1771    @unittest.skipIf(not torch._C._jit_has_cpp_tests(), "Tests were not built, use BUILD_TEST=1")
1772    def test_cpp(self):
1773        from cpp.jit import tests_setup
1774        tests_setup.setup()
1775        torch._C._jit_run_cpp_tests()
1776        tests_setup.shutdown()
1777
1778    def test_batchnorm(self):
1779        x = torch.ones(2, 2, 2, 2)
1780        g, outputs, inputs = torch.jit._get_trace_graph(nn.BatchNorm2d(2), x,
1781                                                        _force_outplace=True, return_inputs=True)
1782        m = self.createFunctionFromGraph(g)
1783        self.assertEqual(outputs, m(*inputs))
1784
1785    def test_dropout(self):
1786        x = torch.ones(2, 2)
1787        with torch.random.fork_rng(devices=[]):
1788            g, outputs, inputs = torch.jit._get_trace_graph(nn.Dropout(0.6), x, return_inputs=True)
1789        with torch.random.fork_rng(devices=[]):
1790            m = self.createFunctionFromGraph(g)
1791            self.assertEqual(outputs, m(*inputs))
1792
1793    @unittest.skipIf(not RUN_CUDA, "test requires CUDA")
1794    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "skip if profiling isn't enabled")
1795    def test_native_dropout_corner_case(self):
1796        with disable_autodiff_subgraph_inlining():
1797            def t(x, p: float, t: bool):
1798                o = torch.dropout(x, p, t)
1799                return o
1800
1801            jit_t = torch.jit.script(t)
1802            x = torch.randn(5).requires_grad_()
1803            FileCheck().check("prim::DifferentiableGraph").run(jit_t.graph_for(x, 1.0, True, profile_and_replay=True))
1804
1805            for train in [True, False]:
1806                for p in [0.0, 1.0]:
1807                    for device in ["cuda", "cpu"]:
1808                        x = torch.randn(5).to(device=device).requires_grad_()
1809                        x_ref = x.detach().requires_grad_()
1810                        o = jit_t(x, p, train)
1811                        o_ref = t(x_ref, p, train)
1812                        o.sum().backward()
1813                        o_ref.sum().backward()
1814                        assert o.equal(o_ref)
1815                        assert x.grad.equal(x_ref.grad)
1816
1817    @slowTest
1818    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, 'Testing differentiable graph')
1819    def test_dropout_module_requires_grad(self):
1820        with enable_profiling_mode_for_profiling_tests():
1821            class MyModule(torch.nn.Module):
1822                def __init__(self, M):
1823                    super().__init__()
1824                    self.dropout = torch.nn.Dropout(0.5)
1825                    self.linear = torch.nn.Linear(M, M)
1826
1827                def forward(self, input):
1828                    input = self.dropout(input)
1829                    output = self.linear(input)
1830                    return output
1831
1832            def profile(func, X):
1833                with torch.autograd.profiler.profile() as prof:
1834                    func(X)
1835                return [e.name for e in prof.function_events]
1836
1837            M = 1000
1838            scripted = torch.jit.script(MyModule(M))
1839            # To reduce confusion about expected behaviors:
1840            #   requires_grad controls whether dropout is symbolically differentiated.
1841            #   training controls whether bernoulli_ is called inside symbolic differentiation of dropout.
1842            # * When requires_grad == training, the expected behaviors are obvious.
1843            # * When requires_grad=True and training=False, bernoulli_ might still show up in the graph.
1844            #   But it's in a branch that's not called. That's why we have separate checks with the autograd
1845            #   profiler to make sure it's not run.
1846            # * When requires_grad=False and training=True, bernoulli_ must be run since that's the expected
1847            #   behavior of the dropout layer in training mode. This is independent of whether the graph requires
1848            #   gradient. In fact bernoulli_ comes from autograd instead of autodiff in this case.
1849            for training in (True, False):
1850                if training:
1851                    scripted.train()
1852                else:
1853                    scripted.eval()
1854                for requires_grad in (True, False):
1855                    X = torch.randn(M, M, requires_grad=requires_grad)
1856                    if requires_grad:
1857                        FileCheck().check("aten::native_dropout").run(scripted.graph_for(X, profile_and_replay=True))
1858                    self.assertEqual(training, 'aten::bernoulli_' in profile(scripted, X))
1859
1860    @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.SIMPLE, 'Testing differentiable graph')
1861    @skipIfTorchDynamo("Torchdynamo cannot correctly handle profiler.profile calls")
1862    def test_dropout_func_requires_grad(self):
1863        def dropout_training(input):
1864            return F.dropout(input, 0.5, training=True)
1865
1866        def dropout_eval(input):
1867            return F.dropout(input, 0.5, training=False)
1868
1869        def profile(func, X):
1870            with torch.autograd.profiler.profile() as prof:
1871                func(X)
1872            return [e.name for e in prof.function_events]
1873
1874        M = 1000
1875        scripted_training = torch.jit.script(dropout_training)
1876        scripted_eval = torch.jit.script(dropout_eval)
1877        # See comments in test_dropout_module_requires_grad.
1878        with disable_autodiff_subgraph_inlining():
1879            for requires_grad in (True, False):
1880                X = torch.randn(M, M, requires_grad=requires_grad)
1881                if requires_grad:
1882                    FileCheck().check("aten::native_dropout").run(scripted_training.graph_for(X, profile_and_replay=True))
1883                self.assertIn('aten::bernoulli_', profile(scripted_training, X))
1884                self.assertNotIn('aten::bernoulli_', profile(scripted_eval, X))
1885
1886    @unittest.skipIf(not RUN_CUDA, "test_dropout_cuda require CUDA")
1887    def test_dropout_cuda(self):
1888        # Dropout AD is dispatched to _fused_dropout in the CUDA case,
1889        # which is not covered by TestJitGeneratedFunctional
1890        def _zero_rate(t):
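                # fraction of elements zeroed out; used below to compare dropout statistics
                # rather than exact values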
1891            return torch.true_divide((t == 0).sum(), t.numel())
1892
1893        x = torch.ones(1000, 1000).cuda().requires_grad_()
1894
1895        with enable_profiling_mode_for_profiling_tests():
1896            @torch.jit.script
1897            def func(x):
1898                return torch.nn.functional.dropout(x)
1899
1900            with freeze_rng_state():
1901                out_ref = torch.nn.functional.dropout(x)
1902                grad_ref = torch.autograd.grad(out_ref.sum(), x)
1903
1904            with freeze_rng_state():
1905                out = func(x)
1906                grad = torch.autograd.grad(out.sum(), x)
1907
1908            # TODO(#40882): previously we asserted exact matches between the eager and JIT results:
1909            #  self.assertEqual(out, out_ref)
1910            #  self.assertEqual(grad, grad_ref)
1911            # This test was disabled during the legacy -> profiling executor transition.
1912            # Currently the JIT fused result doesn't match the eager result exactly due to some changes merged in between.
1913            # We temporarily only check the statistical difference, but this should be reverted once the issue is fixed.
1914            self.assertEqual(_zero_rate(out), _zero_rate(out_ref), rtol=1e-3, atol=1e-4)
1915            self.assertEqual(_zero_rate(grad[0]), _zero_rate(grad_ref[0]), rtol=1e-3, atol=1e-4)
1916
1917    def test_torch_ops_overloaded(self):
1918        with self.assertRaisesRegex(RuntimeError, "failed to match any schema"):
1919            torch.ops.aten.add("a", 1)
1920        self.assertEqual("ab", torch.ops.aten.add("a", "b"))
1921        a, b = torch.rand(3, 4), torch.rand(3, 4)
1922        self.assertEqual(a + b, torch.ops.aten.add(a, b))
1923        self.assertEqual(a + 1, torch.ops.aten.add(a, 1))
1924
1925    def test_torch_ops_kwonly(self):
1926        a, b = torch.rand(3, 4), torch.rand(3, 4)
1927        with self.assertRaisesRegex(RuntimeError, "positional argument"):
1928            torch.ops.aten.add(a, b, 2)
1929        # h/t Chillee for this ambiguous case
1930        self.assertEqual(a.prod(1), torch.ops.aten.prod(a, 1))
1931
1932    def test_torch_complex(self):
1933        def fn(real, img):
1934            return torch.complex(real, img)
1935
1936        def fn_out(real, img, out):
1937            return torch.complex(real, img, out=out)
1938        self.checkScript(fn, (torch.rand(3, 4), torch.rand(3, 4), ))
1939        self.checkScript(fn, (torch.ones(5, 1, 4), torch.ones(5, 1, 4), ))
1940        self.checkScript(fn, (torch.zeros(1, 6), torch.ones(6, 1), ))
1941        self.checkScript(fn, (torch.zeros(1, 6), torch.zeros(6, 1), ))
1942        self.checkScript(fn, (torch.empty(3, 4), torch.empty(3, 4), ))
1943
1944        real = torch.tensor([1, 2], dtype=torch.float32)
1945        img = torch.tensor([3, 4], dtype=torch.float32)
1946        out = torch.empty([3, 4], dtype=torch.complex64)
1947        self.checkScript(fn_out, (real, img, out, ))
1948
1949        real = torch.tensor([5, 2], dtype=torch.float64)
1950        img = torch.tensor([3, 4], dtype=torch.float64)
1951        out = torch.empty([5, 2], dtype=torch.complex128)
1952        self.checkScript(fn_out, (real, img, out, ))
1953
1954        real = torch.ones([1, 2])
1955        img = torch.ones([1, 2])
1956        out = torch.empty([1, 2], dtype=torch.complex64)
1957        self.checkScript(fn_out, (real, img, out, ))
1958
1959        real = torch.ones([3, 8, 7])
1960        img = torch.ones([3, 8, 7])
1961        out = torch.empty([3, 8, 7], dtype=torch.complex64)
1962        self.checkScript(fn_out, (real, img, out, ))
1963
1964        real = torch.empty([3, 2, 6])
1965        img = torch.empty([3, 2, 6])
1966        out = torch.empty([3, 2, 6], dtype=torch.complex64)
1967        self.checkScript(fn_out, (real, img, out, ))
1968
1969        real = torch.zeros([1, 3])
1970        img = torch.empty([3, 1])
1971        out = torch.empty([3, 3], dtype=torch.complex64)
1972        self.checkScript(fn_out, (real, img, out, ))
1973
1974        real = torch.ones([2, 5])
1975        img = torch.empty([2, 1])
1976        out = torch.empty([2, 5], dtype=torch.complex64)
1977        self.checkScript(fn_out, (real, img, out, ))
1978
1979        real = torch.ones([2, 5])
1980        img = torch.zeros([2, 1])
1981        out = torch.empty([2, 5], dtype=torch.complex64)
1982        self.checkScript(fn_out, (real, img, out, ))
1983
1984    def test_einsum(self):
1985        def check(fn, jitted, *args):
1986            self.assertGraphContains(jitted.graph, kind='aten::einsum')
1987            self.assertEqual(fn(*args), jitted(*args))
1988
1989        def equation_format(x, y):
1990            return torch.einsum('i,j->ij', (x, y))
1991
1992        def equation_format_varargs(x, y):
1993            return torch.einsum('i,j->ij', x, y)
1994
1995        def sublist_format(x, y):
1996            return torch.einsum(x, [0], y, [1], [0, 1])
1997
1998        x = make_tensor((5,), dtype=torch.float32, device="cpu")
1999        y = make_tensor((10,), dtype=torch.float32, device="cpu")
2000
2001        for fn in [equation_format, equation_format_varargs, sublist_format]:
2002            check(fn, torch.jit.script(fn), x, y)
2003            check(fn, torch.jit.trace(fn, (x, y)), x, y)
2004
2005    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
2006    def test_python_ivalue(self):
2007        # Test that a pure Python object can be held as an IValue and that the
2008        # conversions between IValue and PyObject are correct
2009        # test for numpy object
2010        py_array = np.arange(15)
2011        ret_py_obj = torch._C._ivalue_debug_python_object(py_array)
2012        self.assertEqual(py_array, ret_py_obj)
2013
2014        # test for function object
2015        ret_py_obj = torch._C._ivalue_debug_python_object(F.relu)
2016        self.assertEqual(F.relu, ret_py_obj)
2017
2018        # test for memory management
2019        # we need to ensure IValue correctly calls incref/decref to avoid
2020        # dangling references and potential memory leaks during conversions
2021        def test_func_scope_helper(inp):
2022            # create a scope and do the conversion pyobject -> ivalue -> pyobject;
2023            # this func returns a new pyobject whose refcount is bumped by 1
2024            inp_refcount = sys.getrefcount(inp)
2025            ivalue_holder = torch._C._ivalue_debug_python_object(inp)
2026            self.assertEqual(inp_refcount + 1, sys.getrefcount(ivalue_holder))
2027            return ivalue_holder + 1
2028
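            # 2200 lies outside CPython's small-int cache, so it is a distinct object whose
            # refcount can be observed reliably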
2029        test_input = 2200
2030        before_count = sys.getrefcount(test_input)
2031        test_func_scope_helper(test_input)
2032        after_count = sys.getrefcount(test_input)
2033
2034        # after the test_func_scope_helper call, the refcount of
2035        # test_input should be equal to the original refcount;
2036        # otherwise we have either a dangling pointer or a memory leak!
2037        self.assertEqual(before_count, after_count)
2038
2039    def test_decompose_addmm(self):
2040        def does_decompose():
2041            @torch.jit.script
2042            def addmm(mat, mat1, mat2):
2043                a = mat.addmm(mat1, mat2)
2044                b = mat.addmm(mat1, mat2, alpha=1.0, beta=1.0)
2045                return a + b
2046
2047            mat = torch.randn(2, 2)
2048            mat1 = torch.randn(2, 4)
2049            mat2 = torch.randn(4, 2)
2050
2051            out_ref = addmm(mat, mat1, mat2)
2052            self.run_pass('decompose_ops', addmm.graph)
2053            out_test = addmm(mat, mat1, mat2)
2054            self.assertEqual(out_ref, out_test)
2055            FileCheck().check_not("addmm").run(str(addmm.graph))
2056
2057        def doesnt_decompose():
2058            @torch.jit.script
2059            def addmm(mat, mat1, mat2, alpha, beta):
2060                a = mat.addmm(mat1, mat2, alpha=4.20, beta=2.0)
2061                b = mat.addmm(mat1, mat2, alpha=int(alpha), beta=int(beta))
2062
2063                return a + b
2064
2065            orig = str(addmm.graph)
2066            self.run_pass('decompose_ops', addmm.graph)
2067            self.assertTrue(orig == str(addmm.graph))
2068
2069        does_decompose()
2070        doesnt_decompose()
2071
2072    @suppress_warnings
2073    def test_sparse_tensors(self):
2074        @torch.jit.ignore
2075        def get_sparse():
2076            return torch.sparse_coo_tensor((2, 3), dtype=torch.float32)
2077
2078        @torch.jit.script
2079        def test_is_sparse(input):
2080            # type: (Tensor) -> bool
2081            return input.is_sparse
2082
2083        script_out_is_sparse = test_is_sparse(get_sparse())
2084        script_out_is_dense = test_is_sparse(torch.randn(2, 3))
2085        self.assertEqual(script_out_is_sparse, True)
2086        self.assertEqual(script_out_is_dense, False)
2087
2088        def test_basic_sparse(input):
2089            output = get_sparse()
2090            return output, input
2091
2092        self.checkScript(test_basic_sparse, (get_sparse(),))
2093        self.checkScript(test_basic_sparse, (torch.tensor([1]),))
2094
2095        def test_sparse_sum(input):
2096            return torch.sparse.sum(input)
2097
2098        self.checkScript(test_sparse_sum, (get_sparse(),))
2099
2100        def test_sparse_mm(input1, input2):
2101            return torch.sparse.mm(input1, input2)
2102
2103        self.checkScript(test_sparse_mm, (get_sparse(), torch.randn(3, 4)))
2104
2105        def test_sparse_addmm(input, input1, input2):
2106            return torch.sparse.addmm(input, input1, input2)
2107
2108        def test_sparse_addmm_alpha_beta(input, input1, input2):
2109            return torch.sparse.addmm(input, input1, input2, alpha=1.3, beta=1.5)
2110
2111        self.checkScript(test_sparse_addmm, (torch.randn(2, 4), get_sparse(), torch.randn(3, 4)))
2112        self.checkScript(test_sparse_addmm_alpha_beta, (torch.randn(2, 4), get_sparse(), torch.randn(3, 4)))
2113
2114    @suppress_warnings
2115    def test_sparse_csr_tensors(self):
2116        @torch.jit.ignore
2117        def get_sparse_csr():
2118            return torch.randn(3, 3).to_sparse_csr()
2119
2120        @torch.jit.script
2121        def test_is_sparse_csr(input):
2122            # type: (Tensor) -> bool
2123            return input.is_sparse_csr
2124
2125        script_out_is_sparse_csr = test_is_sparse_csr(get_sparse_csr())
2126        script_out_is_dense_csr = test_is_sparse_csr(torch.randn(3, 3))
2127
2128        self.assertEqual(script_out_is_sparse_csr, True)
2129        self.assertEqual(script_out_is_dense_csr, False)
2130
2131    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
2132    def test_device_not_equal(self):
2133
2134        def compare_device(x: torch.device):
2135            return x != torch.device("cuda:0")
2136
2137        def compare_two_device(x: torch.device, y: torch.device):
2138            return x != y
2139
2140        self.checkScript(compare_device, (torch.device("cuda:0"),))
2141        self.checkScript(compare_two_device, (torch.device("cuda:0"), torch.device("cuda:1"), ))
2142
2143    def test_constant_prop_simple(self):
2144        @torch.jit.script
2145        def constant_prop(input_int):
2146            # type: (int) -> int
2147            a = 2 * 3
2148            b = a + 2
2149            return b - input_int
2150
2151        out_ref = constant_prop(2)
2152        self.run_pass('constant_propagation', constant_prop.graph)
2153        out_test = constant_prop(2)
2154        self.assertEqual(out_ref, out_test)
2155        graph_str = str(constant_prop.graph)
2156        self.assertTrue("aten::add" not in graph_str and "aten::mul" not in graph_str)
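            # 2 * 3 + 2 should have been folded into the single constant 8, leaving only the subtraction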
2157        const = constant_prop.graph.findNode("prim::Constant").output().toIValue()
2158        self.assertEqual(const, 8)
2159
2160    def test_constant_prop_nested(self):
2161        @torch.jit.script
2162        def constant_prop(a):
2163            b = 2 + 1
2164            if bool(a < 2):
2165                c = b + 2
2166            else:
2167                c = b - 2
2168            return c
2169        out_ref = constant_prop(torch.tensor(2))
2170        self.run_pass('constant_propagation', constant_prop.graph)
2171        out_test = constant_prop(torch.tensor(2))
2172        self.assertEqual(out_ref, out_test)
2173        if_node = constant_prop.graph.findNode("prim::If")
2174        for block in if_node.blocks():
2175            for node in block.nodes():
2176                self.assertTrue(node.kind() == "prim::Constant")
2177
2178    def test_constant_prop_print(self):
2179        @torch.jit.script
2180        def constant_prop(input_tensor):
2181            a = 2 * 3
2182            print(a)
2183            b = a + 2
2184            return b + input_tensor
2185
2186        self.run_pass('constant_propagation', constant_prop.graph)
2187        graph = constant_prop.graph
2188        print_node = graph.findNode("prim::Print")
2189        self.assertTrue(print_node.input().toIValue() == 6)
2190
2191    def test_constant_prop_rand(self):
2192        @torch.jit.script
2193        def constant_prop():
2194            a = torch.randn([3])
2195            b = a + 2
2196            return b
2197
2198        self.run_pass('constant_propagation', constant_prop.graph)
2199        self.assertTrue("aten::randn" in str(constant_prop.graph))
2200
2201    def test_constant_prop_none(self):
2202        @torch.jit.script
2203        def typed_none():
2204            # type: () -> Optional[int]
2205            return None
2206
2207        @torch.jit.script
2208        def constant_prop():
2209            a = typed_none()
2210            b = typed_none()
2211            if (a is None and b is None):
2212                a = 2
2213            else:
2214                a = 1
2215            return a
2216
2217        self.run_pass('constant_propagation', constant_prop.graph)
2218        FileCheck().check("prim::Constant").run(constant_prop.graph)
2219
2220    def test_constant_prop_if_inline(self):
2221        @torch.jit.script
2222        def constant_prop():
2223            cond = True
2224            a = 1
2225            if cond:
2226                a = 1 * 2
2227            else:
2228                a = 1 // 0
2229            return a
2230
2231        # testing that the 1 // 0 error is not thrown
2232        self.run_pass('constant_propagation', constant_prop.graph)
2233
2234    def test_constant_prop_exception(self):
2235        # checking y = a[4] does not error in constant propagation
2236        def bad_index(x):
2237            # type: (bool)
2238            y = 0
2239            if x:
2240                a = [1, 2, 3]
2241                y = a[4]
2242            return y
2243
2244        self.checkScript(bad_index, (False,))
2245
2246    def test_constant_prop_aliasing_type(self):
2247        @torch.jit.script
2248        def foo():
2249            return len([1]), len(torch.tensor([2]))
2250
2251        FileCheck().check_dag("aten::tensor").check_dag("aten::len").run(foo.graph)
2252
2253        @torch.jit.script
2254        def fn():
2255            if 1 == 1:
2256                return 1
2257            else:
2258                return 2
2259
2260        FileCheck().check_not("prim::If").run(fn.graph)
2261
2262    def test_unchecked_cast(self):
2263        def test(cond):
2264            # type: (bool)
2265            a = torch.tensor([10])
2266            if cond:
2267                b = None
2268            else:
2269                b = a
2270            if b is not None:
2271                b[0] = 5
2272            return a.int()
2273
2274        self.checkScript(test, (True,))
2275        self.checkScript(test, (False,))
2276
2277    def test_constant_prop_if_constant(self):
2278        @torch.jit.script
2279        def constant_prop(a, b):
2280            c0 = 1
2281            c1 = 1
2282            c2 = 1
2283            if bool(a):  # -> c0, c1
2284                if bool(b):  # -> c0
2285                    if 1 == 1:  # -> c0
2286                        c0 = c0 + 1
2287                        if 1 == 2:
2288                            c1 = c1 + 1
2289                            c2 = c2 + 1
2290            else:  # -> c0, c1
2291                c1 = c1 + 1
2292
2293            if 1 == 1:  # inlined
2294                c0 = c0 + 1  # dynamic
2295                c2 = c2 + 4  # set to 5
2296            return a + c0 + c1 + c2
2297
2298        graph = constant_prop.graph
2299        self.run_pass('constant_propagation', graph)
2300        ifs = graph.findAllNodes("prim::If", recurse=False)
2301        snd_if_inlined = len(ifs) == 1
2302        self.assertTrue(snd_if_inlined)
2303        first_if = ifs[0]
2304        self.assertTrue(first_if.outputsSize() == 2)
2305        second_if = first_if.findNode("prim::If", recurse=False)
2306        self.assertTrue(second_if.outputsSize() == 1)
2307        self.assertTrue(second_if.findNode("prim::If") is None)
2308
2309    def test_constant_prop_loop_constant(self):
2310        @torch.jit.script
2311        def constant_prop(cond, iter):
2312            # type: (bool, int) -> int
2313            b = 0
2314            while True:
2315                print("stays")
2316            for _ in range(2):
2317                print("stays")
2318            for _ in range(iter):
2319                print("stays")
2320            while cond:
2321                print("stays")
2322            while False:
2323                print("removed")
2324            for _i in range(0):
2325                print("removed")
2326            for _i in range(-4):
2327                print("removed")
2328            return b
2329
2330        self.run_pass('constant_propagation', constant_prop.graph)
2331        graph = canonical(constant_prop.graph)
2332        self.assertTrue(graph.count("removed") == 0)
2333        self.assertTrue(graph.count("stays") == 1)  # constant gets pooled
2334        self.assertTrue(graph.count("prim::Print") == 4)
2335
2336    def test_constant_prop_remove_output(self):
2337        @torch.jit.script
2338        def constant_prop(iter):
2339            # type: (int) -> None
2340            a = 1
2341            b = 1
2342            c = 1
2343            for i in range(iter):
2344                if 1 == 2:
2345                    a = 10
2346                if i == 5:
2347                    b = 2
2348                    c = 3
2349            print(a, b, c)
2350
2351        graph = constant_prop.graph
2352        self.run_pass('constant_propagation', graph)
2353        self.assertTrue(graph.findNode("prim::Loop").outputsSize() == 2)
2354
2355    # TODO(gmagogsfm): Refactor this test to reduce complexity.
2356    def test_constant_insertion(self):
2357        funcs_template = dedent('''
2358        def func():
2359            return {constant_constructor}
2360        ''')
2361
2362        # constants: primitives: int, double, bool, str, lists of primitives,
2363        # and tuples
2364        def check_constant(constant_constructor):
2365            scope = {}
2366            funcs_str = funcs_template.format(constant_constructor=constant_constructor)
2367            execWrapper(funcs_str, globals(), scope)
2368            cu = torch.jit.CompilationUnit(funcs_str)
2369            f_script = cu.func
2370            self.run_pass('constant_propagation', f_script.graph)
2371            FileCheck().check_count("prim::Constant", 1, exactly=True).run(f_script.graph)
2372            self.assertEqual(scope['func'](), f_script())
2373            imported = self.getExportImportCopy(f_script)
2374            self.assertEqual(imported(), f_script())
2375
2376        constants = ["None", "-.5", "0", "1", "True", "False", "''", "'a'", "'b'", "torch.tensor(1)",
2377                     "[True, False]", "[0., .5]", "[torch.tensor(4), torch.tensor(2)]", "[0, 1]", "['0', '1']",
2378                     "[True, None]", "[.5, None, .2]"]
2379
2380        for type in ["Tensor", "str", "int", "float", "bool"]:
2381            constants.append("torch.jit.annotate(List[ " + type + "], [])")
2382
2383        for constant in constants:
2384            check_constant(constant)
2385
2386        for key_type in ["str", "int", "float"]:
2387            for value_type in ["Tensor", "bool", "str", "int", "float"]:
2388                check_constant("torch.jit.annotate(Dict[ " + key_type + ", " + value_type + "], {})")
2389                check_constant("torch.jit.annotate(Dict[ " + key_type + ", Optional[" + value_type + "]], {})")
2390
2391        for i in range(len(constants)):
2392            for j in range(i + 1, len(constants)):
2393                tup_constant = constants[i] + ", " + constants[j]
2394                check_constant(tup_constant)
2395
2396        dict_constants = []
2397        for i in range(len(constants)):
2398            # check_constant constructs the second dict with another Tensor
2399            # which fails the comparison
2400            if not isinstance(eval(constants[i]), (str, int, float)):
2401                continue
2402            for j in range(len(constants)):
2403                dict_constant = "{ " + constants[i] + ": " + constants[j] + "}"
2404                check_constant(dict_constant)
2405                dict_constants.append(dict_constant)
2406        constants = constants + dict_constants
2407
2408        # testing node hashing
2409        funcs_template = dedent('''
2410        def func():
2411            print({constant_constructor})
2412        ''')
2413        single_elem_tuples = ("(" + x + ",)" for x in constants)
2414        input_arg = ", ".join(single_elem_tuples)
2415        scope = {}
2416        funcs_str = funcs_template.format(constant_constructor=input_arg)
2417        execWrapper(funcs_str, globals(), scope)
2418        cu = torch.jit.CompilationUnit(funcs_str)
2419        f_script = cu.func
2420        self.run_pass('constant_propagation', f_script.graph)
2421        # prim::None return adds one constant
2422        self.assertEqual(len(constants) + 1, str(f_script.graph).count("prim::Constant"))
2423        self.run_pass('cse', f_script.graph)
2424        # node hashing works correctly, so no CSE occurs
2425        self.assertEqual(len(constants) + 1, str(f_script.graph).count("prim::Constant"))
2426
2427        funcs_template = dedent('''
2428        def func():
2429            a = {constant_constructor}
2430            print(a)
2431            b = {constant_constructor}
2432            print(b)
2433        ''')
2434
2435        # generate dicts with built-in types (excluding torch.Tensor)
2436        xprod = itertools.product(constants, constants)
2437
2438        # test that equal tuples and dicts correctly work with node hashing
2439        for tup in ("(" + x + ",)" for x in constants):
2440            funcs_str = funcs_template.format(constant_constructor=tup)
2441            scope = {}
2442            execWrapper(funcs_str, globals(), scope)
2443            cu = torch.jit.CompilationUnit(funcs_str)
2444            f_script = cu.func
2445            self.run_pass('constant_propagation_immutable_types', f_script.graph)
2446            num_constants = str(f_script.graph).count("prim::Constant")
2447            self.run_pass('cse', f_script.graph)
2448            FileCheck().check_count("prim::Constant", num_constants, exactly=True).run(f_script.graph)
2449
2450    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
2451    def test_cuda_export_restore(self):
2452        class Sub(torch.jit.ScriptModule):
2453            def __init__(self) -> None:
2454                super().__init__()
2455                self.weight = nn.Parameter(torch.randn(3, 4))
2456
2457            @torch.jit.script_method
2458            def forward(self, thing):
2459                return self.weight + thing
2460
2461        class M(torch.jit.ScriptModule):
2462            def __init__(self) -> None:
2463                super().__init__()
2464                self.mod = Sub()
2465
2466            @torch.jit.script_method
2467            def forward(self, v):
2468                return self.mod(v)
2469        m = M()
2470        m.cuda()
2471        m2 = self.getExportImportCopy(m)
2472        m2.cuda()
2473        input = torch.rand(3, 4).cuda()
2474        self.assertEqual(m(input), m2(input))
2475
2476    @slowTest
2477    def test_export_batchnorm(self):
2478        for mode in ['eval', 'train']:
2479            for clazz in [
2480                    torch.nn.BatchNorm1d(100),
2481                    torch.nn.BatchNorm1d(100, affine=False),
2482                    torch.nn.BatchNorm2d(100),
2483                    torch.nn.BatchNorm2d(100, affine=False)]:
2484                getattr(clazz, mode)()
2485                input = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \
2486                    torch.randn(20, 100, 35, 45)
2487                traced = torch.jit.trace(clazz, (input,))
2488                imported = self.getExportImportCopy(traced)
2489                x = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \
2490                    torch.randn(20, 100, 35, 45)
2491                self.assertEqual(traced(x), imported(x))
2492
2493    def test_export_rnn(self):
2494        for clazz in [nn.RNN(10, 20, 2), nn.GRU(10, 20, 2)]:
2495            class RNNTest(torch.nn.Module):
2496                def __init__(self) -> None:
2497                    super().__init__()
2498                    self.rnn = clazz
2499
2500                def forward(self, x, lengths, h0):
2501                    packed = torch.nn.utils.rnn.pack_padded_sequence(x, lengths)
2502                    out, h = self.rnn(packed, h0)
2503                    padded_outs, _ = torch.nn.utils.rnn.pad_packed_sequence(out)
2504                    return padded_outs
2505
2506            test = RNNTest()
2507
2508            traced = torch.jit.trace(test, (torch.randn(5, 3, 10), torch.LongTensor([3, 2, 1]), torch.randn(2, 3, 20)))
2509            imported = self.getExportImportCopy(traced)
2510            # NB: We make sure to pass in a batch with a different max sequence
2511            # length to ensure that the argument stashing for pad_packed works
2512            # properly.
2513            x, lengths, h0 = torch.randn(7, 4, 10), torch.LongTensor([7, 3, 2, 1]), torch.randn(2, 4, 20)
2514            self.assertEqual(traced(x, lengths, h0), imported(x, lengths, h0))
2515
2516    def test_export_lstm(self):
2517        class LSTMTest(torch.nn.Module):
2518            def __init__(self) -> None:
2519                super().__init__()
2520                self.rnn = nn.LSTM(10, 20, 2)
2521
2522            def forward(self, x, lengths, hiddens):
2523                h0, c0 = hiddens
2524                packed = torch.nn.utils.rnn.pack_padded_sequence(x, lengths)
2525                out, (h, c) = self.rnn(packed, (h0, c0))
2526                padded_outs, _ = torch.nn.utils.rnn.pad_packed_sequence(out)
2527                return padded_outs
2528
2529        test = LSTMTest()
2530
2531        traced = torch.jit.trace(test, (torch.randn(5, 3, 10),
2532                                        torch.LongTensor([3, 2, 1]),
2533                                        (torch.randn(2, 3, 20), torch.randn(2, 3, 20))))
2534        imported = self.getExportImportCopy(traced)
2535        x, lengths, h0, c0 = \
2536            torch.randn(7, 3, 10), torch.LongTensor([7, 5, 2]), torch.randn(2, 3, 20), torch.randn(2, 3, 20)
2537        self.assertEqual(traced(x, lengths, (h0, c0)), imported(x, lengths, (h0, c0)))
2538
2539    def test_unique_state_dict(self):
2540        class MyModule(torch.nn.Module):
2541            def __init__(self) -> None:
2542                super().__init__()
2543                shared_param = torch.nn.Parameter(torch.ones(1))
2544                self.register_parameter('w1', shared_param)
2545                self.register_parameter('w2', shared_param)
2546
2547            def forward(self, input):
2548                return input + self.w1 + self.w2
2549
2550        model = MyModule()
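            # w1 and w2 refer to the same Parameter object, so the de-duplicated state dict
            # should contain exactly one entry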
2551        unittest.TestCase.assertEqual(
2552            self, len(torch.jit._unique_state_dict(model, keep_vars=False)), 1)
2553        unittest.TestCase.assertEqual(
2554            self, len(torch.jit._unique_state_dict(model, keep_vars=True)), 1)
2555
2556    def test_export_dropout(self):
2557        test = torch.nn.Dropout()
2558        test.eval()
2559
2560        traced = torch.jit.trace(test, (torch.rand(3, 4),), check_trace=False)
2561        imported = self.getExportImportCopy(traced)
2562        x = torch.randn(3, 4)
2563        self.assertEqual(traced(x), imported(x))
2564
2565    def test_pretty_printer(self):
2566        @torch.jit.script
2567        def if_test(a, b):
2568            # FIXME: use 0 instead of a.
2569            # c = 0
2570            c = a
2571            if bool(a < b):
2572                c = b
2573            else:
2574                c = a
2575            return c
2576
2577        @torch.jit.script
2578        def if_one(a, b):
2579            c = b
2580            if bool(a < b):
2581                c = a
2582            return c
2583
2584        @torch.jit.script
2585        def while_test(a, i):
2586            while bool(i < 3):
2587                a *= a
2588                i += 1
2589            return a
2590
2591        @torch.jit.script
2592        def while_if_test(a, b):
2593            c = 0
2594            while bool(a < 10):
2595                a = a + 1
2596                b = b + 1
2597                if bool(a > b):
2598                    c = 2
2599                else:
2600                    c = 3
2601            return a + 1 + c
2602
2603        @torch.jit.script
2604        def loop_use_test(y):
2605            x = y + 1
2606            z = x + 5
2607            while bool(y < 8):
2608                y += 1
2609                z = x
2610            return x, z
2611
2612        @torch.jit.ignore
2613        def python_fn(x):
2614            return x + 10
2615
2616        @torch.jit.script
2617        def python_op_name_test(y):
2618            return python_fn(y)
2619
2620        @torch.jit.script
2621        def empty_int_list_test(y):
2622            x = torch.jit.annotate(List[int], [])
2623            return x[0]
2624
2625        @torch.jit.script
2626        def empty_float_list_test(y):
2627            return [1.0, 2.0, 3.0]
2628
2629        @torch.jit.script
2630        def print_weird_test(y):
2631            print("hi\016")
2632
2633        self.assertExpected(if_test.code, "if_test")
2634        self.assertExpected(if_one.code, "if_one")
2635        self.assertExpected(while_test.code, "while_test")
2636        self.assertExpected(while_if_test.code, "while_if_test")
2637        self.assertExpected(loop_use_test.code, "loop_use_test")
2638        self.assertExpected(python_op_name_test.code, "python_op_name_test")
2639        self.assertExpected(empty_int_list_test.code, "empty_int_list_test")
2640        self.assertExpected(empty_float_list_test.code, "empty_float_list_test")
2641        self.assertExpected(print_weird_test.code, "print_weird_test")
2642
2643    def test_cu_escaped_number(self):
2644        cu = torch.jit.CompilationUnit('''
2645            def foo(a):
2646                print("hi\016")
2647        ''')
2648        self.assertExpected(cu.foo.code)
2649
2650    def test_import_method(self):
2651        with torch._jit_internal._disable_emit_hooks():
2652            class Foo(torch.jit.ScriptModule):
2653                @torch.jit.script_method
2654                def forward(self, x, y):
2655                    return 2 * x + y
2656
2657            foo = Foo()
2658            buffer = io.BytesIO()
2659            torch.jit.save(foo, buffer)
2660
2661            buffer.seek(0)
2662            foo_loaded = torch.jit.load(buffer)
2663            self.assertExpected(foo_loaded.forward.code)
2664
2665    @unittest.skip("temporarily disable the test for fwd compatibility")
2666    def test_non_ascii_string(self):
2667        class Foo(torch.jit.ScriptModule):
2668            def __init__(self) -> None:
2669                super().__init__()
2670                self.a = "Over \u0e55\u0e57 57"
2671
2672            @torch.jit.script_method
2673            def forward(self, x, y):
2674                return self.a + "hi\xA1"
2675
2676        foo = Foo()
2677        buffer = io.BytesIO()
2678        torch.jit.save(foo, buffer)
2679
2680        buffer.seek(0)
2681        foo_loaded = torch.jit.load(buffer)
2682        self.assertExpected(foo_loaded.forward.code)
2683
2684    def test_function_default_values(self):
2685        outer_var = torch.tensor(20)
2686        outer_var2 = torch.tensor(30)
2687        a = torch.tensor(0.5)
2688        b = torch.tensor(10)
2689
2690        @torch.jit.script
2691        def simple_fn(x, a=a, b=b, c=outer_var + outer_var2):
2692            return x + a + b + c
2693
2694        self.assertEqual(
2695            simple_fn(torch.ones(1)),
2696            torch.ones(1) + 0.5 + 10 + (20 + 30))
2697        self.assertEqual(
2698            simple_fn(torch.ones(1), torch.tensor(1), torch.tensor(3), torch.tensor(4)),
2699            torch.ones(1) + 1 + 3 + 4)
2700
2701        outer_c = torch.tensor(9)
2702        outer_flag = torch.tensor(False)
2703
2704        @torch.jit.script
2705        def bool_fn(x, a=outer_c, flag=outer_flag):
2706            if bool(flag):
2707                result = x
2708            else:
2709                result = x + a
2710            return result
2711
2712        self.assertEqual(bool_fn(torch.ones(1)), torch.ones(1) + 9)
2713        self.assertEqual(
2714            bool_fn(torch.ones(1), torch.tensor(1), torch.tensor(True)),
2715            torch.ones(1))
2716
2717        @torch.jit.script
2718        def none_fn(x=None):
2719            # type: (Optional[int]) -> Optional[int]
2720            return x
2721
2722        self.assertEqual(none_fn(), None)
2723        self.assertEqual(none_fn(1), 1)
2724
2725        @torch.jit.script
2726        def hints(x, a=0.5, b=10):
2727            # type: (Tensor, float, int) -> Tensor
2728            return x + a + b
2729
2730        self.assertEqual(hints(torch.ones(1)), torch.ones(1) + 0.5 + 10)
2731
2732        with self.assertRaisesRegex(RuntimeError, "Expected a default value"):
2733
2734            @torch.jit.script
2735            def hints_bad_types(x, a=10, b=0.5):  # noqa: T484
2736                # type: (Tensor, float, int) -> Tensor
2737                return x + a + b
2738        with self.assertRaisesRegex(RuntimeError, "Expected a default value"):
2739            @torch.jit.script
2740            def bad_no_optional(x=None):
2741                # type: (Dict[str, int]) -> Dict[str, int]
2742                return x
2743
2744
2745    def test_module_default_values(self):
2746        four = torch.tensor(4)
2747
2748        class Test(torch.jit.ScriptModule):
2749            @torch.jit.script_method
2750            def forward(self, input, other=four):
2751                return input + other
2752
2753        t = Test()
2754        self.assertEqual(t(torch.ones(1)), torch.ones(1) + 4)
2755
2756    def test_mutable_default_values(self):
2757        with self.assertRaisesRegex(Exception, "Mutable default parameters"):
2758            @torch.jit.script
2759            def foo(x=(1, [])):
2760                # type: (Tuple[int, List[Tensor]])
2761                return x
2762
2763        class Test(torch.nn.Module):
2764            def forward(self, input=[]):  # noqa: B006
2765                return input
2766
2767        with self.assertRaisesRegex(Exception, "Mutable default parameters"):
2768            torch.jit.script(Test())
2769
2770    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
2771    def test_warnings(self):
2772        import warnings
2773
2774        def fn(x):
2775            if bool(x < 2):
2776                warnings.warn("x is less than 2")
2777            return x
2778
2779        class M(torch.nn.Module):
2780            def forward(self, x):
2781                if bool(x < 2):
2782                    warnings.warn("x is less than 2")
2783                return x
2784
2785
2786        scripted_mod = torch.jit.script(M())
2787        scripted_fn = torch.jit.script(fn)
2788
2789        with warnings.catch_warnings(record=True) as warns:
2790            fn(torch.ones(1))
2791
2792        with warnings.catch_warnings(record=True) as script_warns:
2793            scripted_fn(torch.ones(1))
2794
2795        with warnings.catch_warnings(record=True) as script_mod_warns:
2796            scripted_mod(torch.ones(1))
2797
2798        self.assertEqual(str(warns[0]), str(script_warns[0]))
2799        self.assertEqual(len(script_mod_warns), 1)
2800        self.assertEqual(str(warns[0].message), str(script_mod_warns[0].message))
2801
2802    def test_no_erroneous_warnings(self):
2803        import warnings
2804
2805        def fn(x):
2806            if bool(x > 0):
2807                warnings.warn('This should NOT be printed')
2808                x += 1
2809            return x
2810
2811        with warnings.catch_warnings(record=True) as warns:
2812            fn_script = torch.jit.script(fn)
2813            fn_script(torch.tensor(0))
2814        warns = [str(w.message) for w in warns]
2815        self.assertEqual(len(warns), 0)
2816
2817    @unittest.skipIf(True, "TODO: re-enable with https://github.com/pytorch/pytorch/pull/29339")
2818    def test_torch_load_error(self):
2819        class J(torch.jit.ScriptModule):
2820            @torch.jit.script_method
2821            def forward(self, input):
2822                return input + 100
2823
2824        j = J()
2825        with TemporaryFileName() as fname:
2826            j.save(fname)
2827            with self.assertRaisesRegex(RuntimeError, "is a zip"):
2828                torch.load(fname)
2829
2830    def test_torch_load_zipfile_check(self):
2831        @torch.jit.script
2832        def fn(x):
2833            return x + 10
2834
2835        with TemporaryFileName() as fname:
2836            fn.save(fname)
2837            with open(fname, 'rb') as f:
2838                self.assertTrue(torch.serialization._is_zipfile(f))
2839
2840    def test_python_bindings(self):
2841        lstm_cell = torch.jit.script(LSTMCellS)
2842
2843        def lstm(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
2844            for i in range(x.size(0)):
2845                hx, cx = lstm_cell(x[i], hx, cx, w_ih, w_hh, b_ih, b_hh)
2846            return hx
2847
2848        slstm = torch.jit.script(lstm)
2849
2850        inputs = get_lstm_inputs('cpu', training=True, seq_length=10)
2851        slstm(*inputs).sum().backward()
2852        global fw_graph
2853        fw_graph = slstm.graph_for(*inputs)
2854        nodes = list(fw_graph.nodes())
2855        tested_blocks = False
2856        for node in nodes:
2857            for output in node.outputs():
2858                self.assertTrue(hasattr(output, 'type'))
2859                self.assertTrue(output.type() is not None)
2860            for input in node.inputs():
2861                self.assertTrue(hasattr(input, 'type'))
2862                self.assertTrue(input.type() is not None)
2863            for block in node.blocks():
2864                tested_blocks = True
2865                self.assertTrue(hasattr(block, 'inputs'))
2866                self.assertTrue(hasattr(block, 'outputs'))
2867                for output in block.outputs():
2868                    self.assertTrue(hasattr(output, 'type'))
2869                    self.assertTrue(output.type() is not None)
2870                for input in block.inputs():
2871                    self.assertTrue(hasattr(input, 'type'))
2872                    self.assertTrue(input.type() is not None)
2873                self.assertTrue(hasattr(block, 'returnNode'))
2874                self.assertTrue(type(block.returnNode()) == torch._C.Node)
2875                self.assertTrue(hasattr(block, 'paramNode'))
2876                self.assertTrue(type(block.paramNode()) == torch._C.Node)
2877        self.assertTrue(tested_blocks)
2878
2879    def test_export_opnames(self):
2880        class Foo(torch.jit.ScriptModule):
2881            def one(self, x, y):
2882                # type: (Tensor, Tensor) -> Tensor
2883                return x + y
2884
2885            def two(self, x):
2886                # type: (Tensor) -> Tensor
2887                return 2 * x
2888
2889            @torch.jit.script_method
2890            def forward(self, x):
2891                # type: (Tensor) -> Tensor
2892                return self.one(self.two(x), x)
2893
2894        class Bar(torch.jit.ScriptModule):
2895            def __init__(self) -> None:
2896                super().__init__()
2897                self.sub = Foo()
2898
2899            @torch.jit.script_method
2900            def forward(self, x):
2901                # type: (Tensor) -> Tensor
2902                return self.sub.forward(x)
2903
2904        bar = Bar()
2905        ops = torch.jit.export_opnames(bar)
2906        expected = ['aten::add.Tensor', 'aten::mul.Scalar']
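        # `x + y` on two tensors should lower to aten::add.Tensor and `2 * x` to
        # aten::mul.Scalar; other ops may also appear, hence the subset check below.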
2907        self.assertTrue(set(expected).issubset(set(ops)))
2908
2909    def test_pytorch_jit_env_off(self):
2910        import subprocess
2911        env = os.environ.copy()
2912        env['PYTORCH_JIT'] = '0'
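        # PYTORCH_JIT=0 disables TorchScript compilation; importing torch should
        # still succeed in that mode.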
2913        try:
2914            subprocess.check_output([sys.executable, '-c', 'import torch'], env=env)
2915        except subprocess.CalledProcessError as e:
2916            raise RuntimeError("Could not 'import torch' with PYTORCH_JIT=0") from e
2917
2918    def test_print_op_module(self):
2919        # Issue #19351: python2 and python3 go through different paths.
2920        # python2 returns '<module 'torch.ops' (built-in)>'
2921        # python3 uses __file__ and returns
2922        # '<module 'torch.ops' from '/scratch/ailzhang/pytorch/torch/_ops.py'>'
2923        s = str(torch.ops)
2924        self.assertRegex(s, r'ops')
2925
2926    def test_print_classes_module(self):
2927        s = str(torch.classes)
2928        self.assertRegex(s, r'classes')
2929
2930    def test_print_torch_ops_modules(self):
2931        s = str(torch._ops.ops.quantized)
2932        self.assertRegex(s, r'torch.ops')
2933        s = str(torch._ops.ops.atan)
2934        self.assertRegex(s, r'torch.ops')
2935
2936    def test_hide_source_ranges_context_manager(self):
2937        @torch.jit.script
2938        def foo(x):
2939            return torch.add(x, x)
2940
2941        graph = foo.graph
2942        source_range_regex = "# .*\\.py"
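        # Source ranges show up in graph dumps as trailing "# <file>.py:<line>:<col>" comments.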
2943        self.assertRegex(graph.__repr__(), source_range_regex)
2944        with torch.jit._hide_source_ranges():
2945            self.assertNotRegex(graph.__repr__(), source_range_regex)
2946            self.assertRegex(graph.str(print_source_ranges=True), source_range_regex)
2947        self.assertRegex(graph.__repr__(), source_range_regex)
2948
2949
2950class TestFrontend(JitTestCase):
2951
2952    def test_instancing_error(self):
2953        @torch.jit.ignore
2954        class MyScriptClass:
2955            def unscriptable(self):
2956                return "a" + 200
2957
2958
2959        class TestModule(torch.nn.Module):
2960            def forward(self, x):
2961                return MyScriptClass()
2962
2963        with self.assertRaises(torch.jit.frontend.FrontendError) as cm:
2964            torch.jit.script(TestModule())
2965
2966        checker = FileCheck()
2967        checker.check("Cannot instantiate class")
2968        checker.check("def forward")
2969        checker.run(str(cm.exception))
2970
2971    def test_dictionary_as_example_inputs_for_jit_trace(self):
2972        class TestModule_v1(torch.nn.Module):
2973            def forward(self, key2=None, key3=None, key4=None, key5=None, key1=None, key6=None):
2974                return key1 + key2 + key3
2975
2976        class TestModule_v2(torch.nn.Module):
2977            def forward(self, x, y):
2978                return x + y
2979
2980        def test_func(x, y):
2981            return x + y
2982        model_1 = TestModule_v1()
2983        model_2 = TestModule_v2()
2984        value1 = torch.ones(1)
2985        value2 = torch.ones(1)
2986        value3 = torch.ones(1)
2987        example_input_dict = {'key1': value1, 'key2': value2, 'key3': value3}
2988        example_input_dict_func = {'x': value1, 'y': value2}
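        # The dict keys must match the parameter names of the traced callable;
        # unknown keys leave required arguments missing (exercised below).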
2989        traced_model_1 = torch.jit.trace(model_1, example_kwarg_inputs=example_input_dict, strict=False)
2990        traced_model_1_m = torch.jit.trace_module(
2991            model_1, {'forward': example_input_dict}, example_inputs_is_kwarg=True, strict=False)
2992        traced_model_2 = torch.jit.trace(model_2, example_kwarg_inputs={'x': torch.rand([2]), 'y': torch.rand([2])})
2993        traced_func = torch.jit.trace(test_func, example_kwarg_inputs=example_input_dict_func, strict=False)
2994        res_1 = traced_model_1(**example_input_dict)
2995        res_1_m = traced_model_1_m(**example_input_dict)
2996        self.assertEqual(res_1, 3 * torch.ones(1))
2997        self.assertEqual(res_1_m, 3 * torch.ones(1))
2998        res_func = traced_func(**example_input_dict_func)
2999        self.assertEqual(res_func, 2 * torch.ones(1))
3000        with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'x'."):
3001            res_2 = traced_model_2(**{'z': torch.rand([2]), 'y': torch.rand([2])})  # noqa: PIE804
3002        with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'y'."):
3003            res_2 = traced_model_2(**{'x': torch.rand([2]), 'z': torch.rand([2])})  # noqa: PIE804
3004
3005
3006class TestScript(JitTestCase):
3007
3008    # Tests that calling torch.jit.script repeatedly on a function is allowed.
3009    def test_repeated_script_on_function(self):
3010        @torch.jit.script
3011        @torch.jit.script
3012        def fn(x):
3013            return x
3014
3015        torch.jit.script(torch.jit.script(fn))
3016
3017    def test_pretty_print_function(self):
3018        @torch.jit.script
3019        def foo(x):
3020            return torch.nn.functional.interpolate(x)
3021
3022        FileCheck().check("interpolate").run(foo.code)
3023
3024    def test_inlined_graph(self):
3025        """
3026        Check that the `inlined_graph` property correctly returns an inlined
3027        graph, both through function calls and method calls.
3028        """
3029        @torch.jit.script
3030        def foo(x):
3031            return torch.add(x, x)
3032
3033        class MyNestedMod(torch.nn.Module):
3034            def forward(self, x):
3035                return torch.sub(x, x)
3036
3037
3038        class MyMod(torch.nn.Module):
3039            def __init__(self) -> None:
3040                super().__init__()
3041                self.nested = MyNestedMod()
3042
3043            def forward(self, x):
3044                x = self.nested(x)  # sub
3045                x = foo(x)  # add
3046                return torch.mul(x, x)
3047
3048        m = torch.jit.script(MyMod())
3049        FileCheck().check("aten::sub") \
3050            .check("aten::add") \
3051            .check("aten::mul") \
3052            .run(m.inlined_graph)
3053
3054    def test_static_method_on_module(self):
3055        """
3056        Check that the `@staticmethod` annotation on a method of a module works.
3057        """
3058        class MyCell(torch.nn.Module):
3059            @staticmethod
3060            def do_it(x, h):
3061                new_h = torch.tanh(x + h)
3062                return new_h, new_h
3063
3064            def forward(self, x, h):
3065                return self.do_it(x, h)
3066
3067        my_cell = torch.jit.script(MyCell())
3068        x = torch.rand(3, 4)
3069        h = torch.rand(3, 4)
3070        jitted_cell = my_cell(x, h)
3071        non_jitted_cell = MyCell().do_it(x, h)
3072
3073        self.assertEqual(jitted_cell, non_jitted_cell)
3074
3075    def test_code_with_constants(self):
3076        """
3077        Check that the `code_with_constants` property correctly returns graph CONSTANTS in the
3078        CONSTANTS.cN format used in the output of the `code` property.
3079        """
3080        @torch.jit.script
3081        def foo(x=torch.ones(1)):
3082            return x
3083
3084        class Moddy(torch.nn.Module):
3085            def forward(self, x):
3086                return foo()
3087
3088        m = torch.jit.script(Moddy())
3089        src, CONSTANTS = m.code_with_constants
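        # CONSTANTS acts like a namespace: CONSTANTS.c0, CONSTANTS.c1, ... hold the
        # values that `src` references as CONSTANTS.cN.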
3090
3091        self.assertEqual(CONSTANTS.c0, torch.ones(1))
3092        self.assertEqual(src, m.code)
3093
3094    def test_code_with_constants_restore(self):
3095        """
3096        Check that the `code_with_constants` property still works correctly after a save() + load() round trip.
3097        """
3098        @torch.jit.script
3099        def foo(x=torch.ones(1)):
3100            return x
3101
3102        class Moddy(torch.nn.Module):
3103            def forward(self, x):
3104                return foo()
3105
3106        m = torch.jit.script(Moddy())
3107        src, CONSTANTS = m.code_with_constants
3108        eic = self.getExportImportCopy(m)
3109
3110        src_eic, CONSTANTS_eic = eic.code_with_constants
3111
3112        self.assertEqual(src, src_eic)
3113        self.assertEqual(CONSTANTS.c0, CONSTANTS_eic.c0)
3114
3115
3116    def test_oneline_func(self):
3117        def fn(x): return x  # noqa: E704
3118
3119        self.checkScript(fn, (torch.ones(2, 2), ))
3120
3121    def test_request_bailout(self):
3122        with enable_profiling_mode_for_profiling_tests():
3123
3124            def fct_loop(x):
3125                for i in range(3):
3126                    x = torch.cat((x, x), 0)
3127                return x
3128
3129            x = torch.ones(2, 3, 4, dtype=torch.float32)
3130            expected = fct_loop(x)
3131            jitted = torch.jit.script(fct_loop)
3132            # profile
3133            jitted(x)
3134            # optimize
3135            jitted(x)
3136            dstate = jitted.get_debug_state()
3137            eplan = get_execution_plan(dstate)
3138            num_bailouts = eplan.code.num_bailouts()
3139
3140            for i in range(0, num_bailouts):
3141                eplan.code.request_bailout(i)
3142                self.assertEqual(jitted(x), expected)
3143
3144    @unittest.skip("bailouts are being deprecated")
3145    def test_dominated_bailout(self):
3146        with enable_profiling_mode_for_profiling_tests():
3147            # functional dominated guard
3148            @torch.jit.script
3149            def foo(x):
3150                dim = x.dim()
3151                if dim == 0:
3152                    y = int(x)
3153                else:
3154                    y = x.size()[dim - 1]
3155                return y
3156
3157            x = torch.zeros(2)
3158            self.assertEqual(foo(x), 2)
3159            self.assertEqual(foo(x), 2)
3160            g = torch.jit.last_executed_optimized_graph()
3161            g_s = str(g)
3162            g_s = g_s[0:g_s.find("return")]
3163            FileCheck().check_count("prim::BailOut[", 1, exactly=True).run(g_s)
3164
3165            # dominated guard of non-functional value
3166            @torch.jit.script
3167            def foo(x):
3168                dim = x.dim()
3169                x.add_(3)
3170                if dim == 0:
3171                    return 0
3172                else:
3173                    return x.size()[dim - 1]
3174
3175            x = torch.zeros(2)
3176            self.assertEqual(foo(x), 2)
3177            self.assertEqual(foo(x), 2)
3178            g = torch.jit.last_executed_optimized_graph()
3179            FileCheck().check("prim::BailOut[").check("aten::add_").check_next("prim::BailOut[").check("return").run(g)
3180
3181            with torch.enable_grad():
3182                @torch.jit.ignore
3183                def disable_grad():
3184                    torch.set_grad_enabled(False)
3185
3186                @torch.jit.ignore
3187                def enable_grad():
3188                    torch.set_grad_enabled(True)
3189
3190                @torch.jit.script
3191                def foo(x):
3192                    x = x + 1
3193                    dim = x.dim()
3194                    disable_grad()
3195                    if dim == 0:
3196                        y = int(x)
3197                    else:
3198                        y = x.size()[dim - 1]
3199                    enable_grad()
3200                    return y
3201
3202                x = torch.zeros(2, requires_grad=True)
3203                self.assertEqual(foo(x), 2)
3204                self.assertEqual(foo(x), 2)
3205                g = torch.jit.last_executed_optimized_graph()
3206                # there should still be a BailOut after the disable_grad call
3207                FileCheck().check("disable_grad").check("BailOut[").check("BailoutTemplate").run(g)
3208
3209    @skipIfTorchDynamo("Torchdynamo cannot correctly handle profiler.profile calls")
3210    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "skip if profiling isn't enabled")
3211    def test_profiling_merge(self):
3212        @torch.jit.script
3213        def test_not_const(x):
3214            if x.size(0) == 1:
3215                return 1
3216            else:
3217                return 2
3218
3219        with enable_profiling_mode_for_profiling_tests():
3220            with num_profiled_runs(2):
3221                test_not_const(torch.rand([1, 2]))
3222                test_not_const(torch.rand([2, 2]))
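                # After profiling shapes [1, 2] and [2, 2], dim 0 should be merged
                # into a dynamic '*' in the recorded profiled_type.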
3223
3224                graph_str = torch.jit.last_executed_optimized_graph()
3225                FileCheck().check("profiled_type=Float(*, 2, strides=[2, 1], requires_grad=0, device=cpu").run(graph_str)
3226                FileCheck().check_not("profiled_type=Float(1, 2, strides=[2, 1], requires_grad=0, device=cpu").run(graph_str)
3227
3228
3229    def test_nested_bailouts(self):
3230        @torch.jit.script
3231        def fct_loop(x):
3232            for i in range(3):
3233                x = torch.cat((x, x), 0)
3234            return x
3235
3236        x = torch.ones(2, 3, 4, dtype=torch.float32)
3237        out = fct_loop(x)
3238        jit_trace = torch.jit.trace(fct_loop, x)
3239        out_trace = jit_trace(x)
3240
3241    def test_no_self_arg_ignore_function(self):
3242        class MyModule(nn.Module):
3243            @torch.jit.ignore  # noqa: B902
3244            def call_np():  # noqa: B902
3245                # type: () -> int
3246                return np.random.choice(2, p=[.95, .05])
3247
3248            def forward(self):
3249                return self.call_np()
3250
3251        with self.assertRaisesRegex(Exception, "does not have a self argument"):
3252            torch.jit.script(MyModule())
3253
3254    def test_loop_liveness(self):
3255        with enable_profiling_mode_for_profiling_tests():
3256            @torch.jit.script
3257            def f(i):
3258                # type: (int) -> Tensor
3259                l = []
3260                for n in [2, 1]:
3261                    l.append(torch.zeros(n, i))
3262
3263                return l[0]
3264
3265            f(2)
3266            f(1)
3267
3268    def test_bailout_loop_carried_deps_name_clash(self):
3269        with enable_profiling_mode_for_profiling_tests():
3270            NUM_ITERATIONS = 10
3271
3272            @torch.jit.script
3273            def fct_loop(z, size):
3274                # type: (int, int) -> Tuple[Tensor, List[int]]
3275                counters = torch.jit.annotate(List[int], [])
3276                j = 0
3277                y = torch.ones(2)
3278                for i in range(size):
3279                    counters.append(i + j)
3280                    y = torch.cat((y, torch.ones(z)), 0)
3281                    j = j + 1
3282                return y, counters
3283
3284            inputs = [1, 2, 3, 4]
3285            expected = [x * 2 for x in range(NUM_ITERATIONS)]
3286            for inp in inputs:
3287                results = fct_loop(inp, NUM_ITERATIONS)
3288                self.assertEqual(results[1], expected)
3289
3290    def test_bailout_loop_counter_transition(self):
3291        with enable_profiling_mode_for_profiling_tests():
3292            NUM_ITERATIONS = 10
3293
3294            @torch.jit.script
3295            def fct_loop(z, size):
3296                # type: (int, int) -> Tuple[Tensor, List[int]]
3297                counters = torch.jit.annotate(List[int], [])
3298                y = torch.ones(2)
3299                for i in range(size):
3300                    counters.append(i)
3301                    y = torch.cat((y, torch.ones(z)), 0)
3302                return y, counters
3303
3304            inputs = [1, 2, 3, 4]
3305            expected = list(range(NUM_ITERATIONS))
3306            for inp in inputs:
3307                results = fct_loop(inp, NUM_ITERATIONS)
3308                self.assertEqual(results[1], expected)
3309
3310    def test_ignored_method_binding(self):
3311        class Bar(torch.nn.Module):
3312            def __init__(self) -> None:
3313                super().__init__()
3314                self.x : int = 0
3315
3316            @torch.jit.export
3317            def setx(self, x : int):
3318                self.x = x
3319
3320            @torch.jit.export
3321            def getx(self):
3322                return self.x
3323
3324            @torch.jit.ignore
3325            def ignored_getx(self):
3326                return self.x
3327
3328        b = Bar()
3329        b.setx(123)
3330        sb = torch.jit.script(b)
3331        self.assertEqual(sb.getx(), 123)
3332        self.assertEqual(sb.ignored_getx(), 123)
3333
3334        sb.setx(456)
3335        self.assertEqual(sb.getx(), 456)
3336        self.assertEqual(sb.ignored_getx(), 456)
3337
3338    def test_set_attribute_through_optional(self):
3339        class A(torch.nn.Module):
3340            __annotations__ = {"x": Optional[torch.Tensor]}
3341
3342            def __init__(self) -> None:
3343                super().__init__()
3344                self.x = None
3345
3346            @torch.jit.ignore
3347            def foo(self):
3348                if self.x is None:
3349                    self.x = torch.tensor([3])
3350                return self.x
3351
3352            def forward(self, x):
3353                a = self.foo()
3354                return x + 1
3355
3356        m = torch.jit.script(A())
3357        self.assertEqual(m.x, None)
3358        m(torch.rand(1))
3359        self.assertEqual(m.x, torch.tensor([3]))
3360
3361    def test_mutate_constant(self):
3362        class M(torch.jit.ScriptModule):
3363            __constants__ = ["foo"]
3364
3365            def __init__(self, foo):
3366                super().__init__()
3367                self.foo = foo
3368
3369        m = M(5)
3370        # m has a constant attribute, but we can't
3371        # assign to it
3372        with self.assertRaises(RuntimeError):
3373            m.foo = 6
3374
3375    def test_class_attribute(self):
3376        class M(torch.jit.ScriptModule):
3377            FOO = 0
3378
3379            def __init__(self) -> None:
3380                super().__init__()
3381                self.foo = self.FOO
3382        m = M()
3383        self.assertEqual(m.foo, M.FOO)
3384
3385    def test_class_attribute_in_script(self):
3386        class M(torch.jit.ScriptModule):
3387            FOO = 0
3388
3389            @torch.jit.script_method
3390            def forward(self):
3391                return self.FOO
3392        with self.assertRaises(RuntimeError):
3393            M()
3394
3395    def test_not_initialized_err(self):
3396        class M(torch.jit.ScriptModule):
3397            def __init__(self) -> None:
3398                self.foo = torch.rand(2, 3)
3399        with self.assertRaises(RuntimeError):
3400            M()
3401
3402    def test_attribute_in_init(self):
3403        class M(torch.jit.ScriptModule):
3404            def __init__(self) -> None:
3405                super().__init__()
3406                self.foo = torch.jit.Attribute(0.1, float)
3407                # we should be able to use self.foo as a float here
3408                assert 0.0 < self.foo
3409        M()
3410
3411    def test_scriptable_fn_as_attr(self):
3412        class M(torch.nn.Module):
3413            def __init__(self, fn):
3414                super().__init__()
3415                self.fn = fn
3416
3417            def forward(self, x):
3418                return self.fn(x)
3419
3420        m = M(torch.sigmoid)
3421        inp = torch.rand(2, 3)
3422        self.checkModule(m, (inp, ))
3423
3424    def test_sequence_parsing(self):
3425        tests = [
3426            ("return [x, x,]", True),
3427            ("return [x x]", "expected ]"),
3428            ("return x, x,", True),
3429            ("return bar(x, x,)", True),
3430            ("return bar()", "Argument x not provided"),
3431            ("for a, b, in x, x,:\n        pass", "List of iterables"),
3432            ("a, b, = x, x,\n    return a + b", True)
3433        ]
3434        for exp, result in tests:
3435            cu = torch.jit.CompilationUnit()
3436            full = f"""
3437def bar(x, y):
3438    return x + y
3439def foo(x):
3440    {exp}
3441            """
3442            if isinstance(result, str):
3443                with self.assertRaisesRegex(RuntimeError, result):
3444                    cu.define(full)
3445            else:
3446                cu.define(full)
3447
3448    def test_namedtuple_python(self):
3449        global MyTuple, MyMod  # see [local resolution in python]
3450        MyTuple = namedtuple('MyTuple', ['a'])
3451
3452        @torch.jit.unused
3453        def fn():
3454            # type: () -> MyTuple
3455            return MyTuple(1)
3456
3457        # Only check compilation
3458        @torch.jit.script
3459        def fn2():
3460            # type: () -> MyTuple
3461            return fn()
3462
3463        FileCheck().check("NamedTuple").run(fn2.graph)
3464
3465        class MyMod(torch.nn.Module):
3466            @torch.jit.unused
3467            def fn(self):
3468                # type: () -> MyTuple
3469                return MyTuple(1)
3470
3471            def forward(self, x):
3472                if 1 == 1:
3473                    return MyTuple(torch.rand(2, 3))
3474                else:
3475                    return self.fn()
3476
3477        # shouldn't throw a type error
3478        torch.jit.script(MyMod())
3479
3480    def test_unused_decorator(self):
3481        class MyMod(torch.nn.Module):
3482            @torch.jit.unused
3483            @torch.no_grad()
3484            def fn(self, x):
3485                # type: (Tensor) -> int
3486                return next(x)  # invalid, but should be ignored
3487
3488            def forward(self, x):
3489                return self.fn(x)
3490
3491        torch.jit.script(MyMod())
3492
3493    @_inline_everything
3494    def test_lazy_script(self):
3495        def untraceable(x):
3496            if x.ndim > 2:
3497                print("hello")
3498            else:
3499                print("goodbye")
3500            return x + 2
3501
3502        # Non-working example
3503        def fn(x):
3504            return untraceable(x)
3505
3506        with self.capture_stdout():
3507            traced_bad = torch.jit.trace(fn, [torch.ones(2, 2)])
3508
3509        FileCheck().check_not("goodbye").check_not("hello").run(traced_bad.graph)
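        # Plain tracing records neither the Python print call nor the data-dependent
        # branch, so neither message appears in the traced graph.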
3510
3511        # Working example
3512        untraceable = torch.jit.script_if_tracing(untraceable)
3513
3514        def fn2(x):
3515            return untraceable(x)
3516
3517        with self.capture_stdout():
3518            traced = torch.jit.trace(fn2, [torch.ones(2, 2)])
3519
3520        FileCheck().check("goodbye").run(traced.graph)
3521
3522        def foo(x: int):
3523            return x + 1
3524
3525        @torch.jit.script_if_tracing
3526        def fee(x: int = 2):
3527            return foo(1) + x
3528
3529        # test directly compiling function
3530        fee_compiled = torch.jit.script(fee)
3531        self.assertEqual(fee_compiled(), fee())
3532
3533        # test compiling it within another function
3534        @torch.jit.script
3535        def hum():
3536            return fee(x=3)
3537
3538        self.assertEqual(hum(), 5)
3539
3540    def test_big_int_literals(self):
3541        def ok():
3542            # signed 64 bit max
3543            a = 9223372036854775807
3544            return a
3545
3546        def toobig():
3547            a = 9223372036854775808
3548            return a
3549
3550        def waytoobig():
3551            a = 99999999999999999999
3552            return a
3553
3554        self.checkScript(ok, [])
3555
3556        with self.assertRaisesRegex(RuntimeError, "out of range"):
3557            torch.jit.script(toobig)
3558
3559        with self.assertRaisesRegex(RuntimeError, "out of range"):
3560            torch.jit.script(waytoobig)
3561
3562    def test_hex_literals(self):
3563        def test1():
3564            return 0xaaaaaa
3565
3566        def test2():
3567            return 0XAAAAAA  # upper-case hex literal form
3568
3569        def test3():
3570            return -0xaaaaaa
3571
3572        self.checkScript(test1, [])
3573        self.checkScript(test2, [])
3574        self.checkScript(test3, [])
3575
3576        def ok():
3577            a = 0x7FFFFFFFFFFFFFFF
3578            return a
3579
3580        def toobig():
3581            a = 0xFFFFFFFFFFFFFFFF
3582            return a
3583
3584        def waytoobig():
3585            a = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF
3586            return a
3587
3588        self.checkScript(ok, [])
3589
3590        with self.assertRaisesRegex(RuntimeError, "out of range"):
3591            torch.jit.script(toobig)
3592
3593        with self.assertRaisesRegex(RuntimeError, "out of range"):
3594            torch.jit.script(waytoobig)
3595
3596    def test_big_float_literals(self):
3597        def ok():
3598            # Python interprets this as inf
3599            a = 1.2E400
3600            return a
3601
3602        def check(fn):
3603            self.assertTrue(fn() == ok())
3604
3605        # checkScript doesn't work since assertEqual doesn't consider
3606        # `inf` == `inf`
3607        check(torch.jit.script(ok))
3608
3609        cu = torch.jit.CompilationUnit()
3610        cu.define(dedent(inspect.getsource(ok)))
3611        check(cu.ok)
3612
3613    def _test_device_type(self, dest):
3614        def fn(x):
3615            # type: (Device) -> Tuple[str, Optional[int]]
3616            return x.type, x.index
3617
3618        device = torch.ones(2).to(dest).device
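        # e.g. torch.device('cpu') yields ('cpu', None); an indexed device such as
        # 'cuda:0' would yield ('cuda', 0).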
3619        self.checkScript(fn, [device])
3620
3621    def test_device_type(self):
3622        self._test_device_type('cpu')
3623
3624    @unittest.skipIf(not RUN_CUDA, "Requires CUDA")
3625    def test_device_type_cuda(self):
3626        self._test_device_type('cuda')
3627
3628    def test_string_device_implicit_conversion(self):
3629        @torch.jit.script
3630        def fn(x: torch.device):
3631            return x
3632
3633        self.assertEqual(fn("cpu"), torch.device("cpu"))
3634
3635        with self.assertRaisesRegex(RuntimeError, "Expected one of"):
3636            fn("invalid_device")
3637
3638    def test_eval_python(self):
3639        def _test(m):
3640            self.assertTrue(m(torch.ones(2, 2)))
3641            self.assertTrue(m.training)
3642            self.assertTrue(m._c.getattr('training'))
3643
3644            m.eval()
3645
3646            self.assertFalse(m.training)
3647            self.assertFalse(m._c.getattr('training'))
3648            self.assertFalse(m(torch.ones(2, 2)))
3649
3650            buffer = io.BytesIO()
3651            torch.jit.save(m, buffer)
3652            buffer.seek(0)
3653
3654            loaded = torch.jit.load(buffer)
3655
3656            self.assertFalse(loaded.training)
3657            self.assertFalse(loaded._c.getattr('training'))
3658
3659        class M(nn.Module):
3660            def forward(self, x):
3661                return self.training
3662
3663        class OldM(torch.jit.ScriptModule):
3664            @torch.jit.script_method
3665            def forward(self, x):
3666                return self.training
3667
3668        _test(torch.jit.script(M()))
3669        _test(OldM())
3670
3671    def test_inherit_method(self):
3672        class A(torch.jit.ScriptModule):
3673            @torch.jit.script_method
3674            def forward(self, x):
3675                return x + self.bar(x)
3676
3677        class B(A):
3678            @torch.jit.script_method
3679            def bar(self, x):
3680                return x * x
3681
3682        with self.assertRaisesRegex(RuntimeError, 'attribute'):
3683            A()  # cannot use because bar is not defined
3684
3685        v = torch.rand(3, 4)
3686        b = B()
3687        self.assertEqual(b(v), v + v * v)
3688
3689        class C(torch.jit.ScriptModule):
3690            @torch.jit.script_method
3691            def bar(self, x):
3692                return x
3693
3694        class D(C, B):
3695            def __init__(self) -> None:
3696                super().__init__()
3697
3698        self.assertEqual(D()(v), v + v)
3699
3700    def test_tensor_subclasses(self):
3701        def check_subclass(x, tensor):
3702            template = dedent("""
3703                def func(input: {}) -> {}:
3704                    return torch.zeros((input.shape[0], 1), dtype=input.dtype)
3705                """)
3706
3707            self._check_code(template.format(x, x), "func", [tensor])
3708
3709        check_subclass("torch.LongTensor", torch.LongTensor([[1, 2], [3, 4]]))
3710        check_subclass("torch.DoubleTensor", torch.DoubleTensor([[1.2, 2.3], [3.4, 4.5]]))
3711        check_subclass("torch.IntTensor", torch.IntTensor([[1, 2], [3, 4]]))
3712        check_subclass("torch.BoolTensor", torch.BoolTensor([[False, True], [True, False]]))
3713
3714        def check_subclass_warn(input: torch.LongTensor) -> torch.LongTensor:
3715            return torch.zeros((input.shape[0], 1), dtype=input.dtype)
3716
3717        with warnings.catch_warnings(record=True) as warns:
3718            scripted = torch.jit.script(check_subclass_warn)
3719        FileCheck().check("TorchScript will treat type annotations of Tensor").run(str(warns[0]))
3720
3721    def test_first_class_module(self):
3722        class Foo(torch.jit.ScriptModule):
3723            def __init__(self) -> None:
3724                super().__init__()
3725                self.foo = nn.Parameter(torch.rand(3, 4))
3726
3727            @torch.jit.script_method
3728            def forward(self, input):
3729                self.foo = input
3730                return self.foo
3731        foo = Foo()
3732        input = torch.rand(3, 4)
3733        foo.forward(input)
3734        self.assertEqual(input, foo.foo)
3735
3736    @_tmp_donotuse_dont_inline_everything
3737    def test_first_class_calls(self):
3738        @torch.jit.script
3739        class Foo:
3740            def __init__(self, x):
3741                self.bar = x
3742
3743            def stuff(self, x):
3744                return self.bar + x
3745
3746        @torch.jit.script
3747        def foo(x):
3748            return x * x + Foo(x).stuff(2 * x)
3749
3750        @torch.jit.script
3751        def bar(x):
3752            return foo(x) * foo(x)
3753
3754        x = torch.rand(3, 4)
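        # foo(x) = x*x + Foo(x).stuff(2*x) = x*x + (x + 2*x) = x*x + 3*x,
        # so bar(x) = foo(x) * foo(x) = (x*x + 3*x) ** 2.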
3755        self.assertEqual(bar(x), (x * x + 3 * x) * (x * x + 3 * x))
3756
3757    def test_static_methods(self):
3758        class M(nn.Module):
3759            @staticmethod
3760            def my_method(x):
3761                return x + 100
3762
3763            def forward(self, x):
3764                return x + M.my_method(x)
3765
3766        class N(nn.Module):
3767            @staticmethod
3768            def my_method(x):
3769                return x * 100
3770
3771            def forward(self, x):
3772                return x - M.my_method(x) + N.my_method(x)
3773
3774        self.checkModule(M(), (torch.ones(2, 2),))
3775
3776        self.checkModule(N(), (torch.ones(2, 2),))
3777
3778    def test_invalid_prefix_annotation(self):
3779        with self.assertRaisesRegex(RuntimeError, "annotation prefix in line"):
3780            with self.capture_stdout() as captured:
3781                @torch.jit.script
3782                def invalid_prefix_annotation1(a):
3783                    #type: (Int) -> Int # noqa: E265
3784                    return a + 2
3785
3786        with self.assertRaisesRegex(RuntimeError, "annotation prefix in line"):
3787            with self.capture_stdout() as captured:
3788                @torch.jit.script
3789                def invalid_prefix_annotation2(a):
3790                    #type   : (Int) -> Int # noqa: E265
3791                    return a + 2
3792
3793        with self.assertRaisesRegex(RuntimeError, "annotation prefix in line"):
3794            with self.capture_stdout() as captured:
3795                @torch.jit.script
3796                def invalid_prefix_annotation3(a):
3797                    #     type: (Int) -> Int
3798                    return a + 2
3799
3800    def test_builtin_function_attributes(self):
3801        class Add(nn.Module):
3802            def __init__(self) -> None:
3803                super().__init__()
3804                self.add = torch.add
3805
3806            def forward(self, input):
3807                return self.add(input, input)
3808
3809        self.checkModule(Add(), [torch.randn(2, 2)])
3810
3811    def test_pybind_type_comparisons(self):
3812        @torch.jit.script
3813        def f():
3814            return None
3815
3816        node = list(f.graph.nodes())[0]
3817        t = node.outputsAt(0).type()
3818        self.assertIsNotNone(t)
3819
3820    @unittest.skipIf(IS_WINDOWS, 'TODO: need to fix the test case')
3821    def test_unmatched_type_annotation(self):
3822        message1 = re.escape("Number of type annotations (2) did not match the number of function parameters (1):")
3823        message2 = 'def invalid2\\(a\\):\n\\s*~+\\.*\\s+<--- HERE\n\\s+# type: \\(Int, Int\\) -> Int\n\\s+return a \\+ 2'
3824        message3 = 'def invalid4\\(a\\):\n\\s*~+\\.*\\s+<--- HERE\n\\s+# type: \\(Int, Int\\) -> Int\n\\s+return a \\+ 2'
3825        with self.assertRaisesRegex(RuntimeError, message1):
3826            @torch.jit.script
3827            def invalid1(a):
3828                # type: (Int, Int) -> Int
3829                return a + 2
3830
3831        with self.assertRaisesRegex(RuntimeError, message2):
3832            @torch.jit.script
3833            def invalid2(a):
3834                # type: (Int, Int) -> Int
3835                return a + 2
3836
3837        with self.assertRaisesRegex(RuntimeError, message1):
3838            def invalid3(a):
3839                # type: (Int, Int) -> Int
3840                return a + 2
3841            torch.jit.script(invalid3)
3842
3843        with self.assertRaisesRegex(RuntimeError, message3):
3844            def invalid4(a):
3845                # type: (Int, Int) -> Int
3846                return a + 2
3847            torch.jit.script(invalid4)
3848
3849    def test_calls_in_type_annotations(self):
3850        with self.assertRaisesRegex(RuntimeError, "Type annotation should not contain calls"):
3851            def spooky(a):
3852                # type: print("Hello") -> Tensor # noqa: F723
3853                return a + 2
3854            print(torch.__file__)
3855            torch.jit.annotations.get_signature(spooky, None, 1, True)
3856
3857    def test_is_optional(self):
3858        ann = Union[List[int], List[float]]
3859        torch._jit_internal.is_optional(ann)
3860
3861    def test_interpreter_fuzz(self):
3862        import builtins
3863        # This test generates random tree-like programs to fuzz test
3864        # that the interpreter does not have a bug in its stack manipulation
3865        # code. An assert in that code ensures individual operators are
3866        # not reordered.
3867        templates = [
3868            "torch.rand(3, 4)",
3869            "({} + {})",
3870            "-{}",
3871            "({} * {})",
3872            "torch.tanh({})",
3873            "VAR {}",
3874        ]
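        # A generated program might contain lines such as
        #   v0 = torch.rand(3, 4)
        #   v1 = (v0 + torch.tanh(v0))
        # followed by a tuple return of all defined variables.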
3875
3876        def gen_code():
3877            src_lines = ['def f():']
3878            exprs = []
3879            n_variables = 0
3880
3881            def get_expr(idx):
3882                elem = exprs[idx]
3883                exprs[idx] = exprs[-1]
3884                exprs.pop()
3885                return elem
3886
3887            def select_expr_or_var():
3888                idx = random.randrange(0, len(exprs) + n_variables)
3889                if idx < len(exprs):
3890                    return get_expr(idx)
3891                else:
3892                    return f'v{idx - len(exprs)}'
3893
3894            for i in range(50):
3895                n = None
3896                while n is None or n > len(exprs) + n_variables:
3897                    template = random.choice(templates)
3898                    n = template.count('{}')
3899
3900                if 'VAR' in template:
3901                    src_lines.append(f'  v{n_variables} = {select_expr_or_var()}')
3902                    n_variables += 1
3903                else:
3904                    exprs.append(template.format(*(select_expr_or_var() for _ in range(n))))
3905
3906            src_lines.append('  return ({})\n'.format(''.join(f'v{i},' for i in range(n_variables))))
3907            return '\n'.join(src_lines)
3908
3909        for i in range(100):
3910            g = {'torch': torch}
3911            code = gen_code()
3912            builtins.exec(code, g, None)
3913            cu = torch.jit.CompilationUnit(code)
3914            with freeze_rng_state():
3915                o1 = g['f']()
3916            with freeze_rng_state():
3917                o2 = cu.f()
3918            self.assertEqual(o1, o2)
3919
3920    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
3921    def test_cpp_module_iterator(self):
3922        a = nn.Module()
3923        a.name = 'a'
3924        a.p = nn.Parameter(torch.rand(3, 4))
3925        a.foo = nn.Module()
3926        a.foo.name = 'foo'
3927        a.foo.b = nn.Buffer(torch.rand(1, 1))
3928        a.foo.bar = nn.Module()
3929        a.foo.bar.name = 'bar'
3930        a.foo.bar.an_int = 4
3931        a.another = nn.Module()
3932        a.another.name = 'another'
3933        sa = torch.jit.script(a)
3934        result = torch._C._jit_debug_module_iterators(sa._c)
3935
3936        def replace(e):
3937            if e is a.p:
3938                return 'P'
3939            elif e is a.foo.b:
3940                return 'B'
3941            elif isinstance(e, torch._C.ScriptModule):
3942                return e.getattr('name')
3943
3944            return e
3945        for v in result.values():
3946            for i in range(len(v)):
3947                if isinstance(v[i], tuple):
3948                    n, v2 = v[i]
3949                    v[i] = (n, replace(v2))
3950                else:
3951                    v[i] = replace(v[i])
3952            # module type creation is not deterministic, so we have to sort
3953            # the result
3954            v.sort()
3955        expected = {'buffers': [],
3956                    'buffers_r': ['B'],
3957                    'children': ['another', 'foo'],
3958                    'modules': ['a', 'another', 'bar', 'foo'],
3959                    'named_attributes': [('_is_full_backward_hook', None),
3960                                         ('another', 'another'),
3961                                         ('foo', 'foo'),
3962                                         ('name', 'a'),
3963                                         ('p', 'P'),
3964                                         ('training', True)],
3965                    'named_attributes_r': [('_is_full_backward_hook', None),
3966                                           ('another', 'another'),
3967                                           ('another._is_full_backward_hook', None),
3968                                           ('another.name', 'another'),
3969                                           ('another.training', True),
3970                                           ('foo', 'foo'),
3971                                           ('foo._is_full_backward_hook', None),
3972                                           ('foo.b', 'B'),
3973                                           ('foo.bar', 'bar'),
3974                                           ('foo.bar._is_full_backward_hook', None),
3975                                           ('foo.bar.an_int', 4),
3976                                           ('foo.bar.name', 'bar'),
3977                                           ('foo.bar.training', True),
3978                                           ('foo.name', 'foo'),
3979                                           ('foo.training', True),
3980                                           ('name', 'a'),
3981                                           ('p', 'P'),
3982                                           ('training', True)],
3983                    'named_buffers': [],
3984                    'named_buffers_r': [('foo.b', 'B')],
3985                    'named_children': [('another', 'another'), ('foo', 'foo')],
3986                    'named_modules': [('', 'a'),
3987                                      ('another', 'another'),
3988                                      ('foo', 'foo'),
3989                                      ('foo.bar', 'bar')],
3990                    'named_parameters': [('p', 'P')],
3991                    'named_parameters_r': [('p', 'P')],
3992                    'parameters': ['P'],
3993                    'parameters_r': ['P']}
3994        self.assertEqual(expected, result)
3995
3996    def test_parameter_order(self):
3997        m = nn.Module()
3998        for i, name in enumerate(string.ascii_letters):
3999            setattr(m, name, nn.Parameter(torch.tensor([float(i)])))
4000        ms = torch.jit.script(m)
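        # Scripting should preserve the registration (insertion) order of parameters.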
4001        print(torch.cat(list(m.parameters())))
4002        print(torch.cat(list(ms.parameters())))
4003        self.assertEqual(list(m.parameters()), list(ms.parameters()))
4004
4005    def test_python_op_builtins(self):
4006        @torch.jit.unused
4007        def fn(x):
4008            # type: (List[int]) -> int
4009            return sum(x)
4010
4011        @torch.jit.script
4012        def script_fn(x):
4013            # type: (List[int]) -> int
4014            return fn(x)
4015
4016    def test_submodule_twice(self):
4017        @torch.jit.script
4018        def foo(x):
4019            return x * x
4020
4021        class What(torch.jit.ScriptModule):
4022            def __init__(self, x):
4023                super().__init__()
4024                self.foo = x
4025        a = What(foo)
4026        c = What(foo)
4027
4028    def test_training_param(self):
4029        class What(torch.jit.ScriptModule):
4030            @torch.jit.script_method
4031            def forward(self, x):
4032                # type: (int) -> int
4033                if self.training:
4034                    r = x
4035                else:
4036                    r = x + 4
4037                # check double use of training
4038                if self.training:
4039                    r = r + 1
4040                return r
4041
4042        w = What()
4043        self.assertEqual(4, w(3))
4044        w.train(False)
4045        self.assertEqual(7, w(3))
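        # `training` is a module flag rather than a buffer or parameter, so it
        # should not appear in the state dict.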
4046        self.assertFalse("training" in w.state_dict())
4047
4048    def test_class_as_attribute(self):
4049        @torch.jit.script
4050        class Foo321:
4051            def __init__(self) -> None:
4052                self.x = 3
4053
4054        class FooBar1234(torch.nn.Module):
4055            def __init__(self) -> None:
4056                super().__init__()
4057                self.f = Foo321()
4058
4059            def forward(self, x):
4060                return x + self.f.x
4061
4062        scripted = torch.jit.script(FooBar1234())
4063        eic = self.getExportImportCopy(scripted)
4064        x = torch.rand(3, 4)
4065        self.assertEqual(scripted(x), eic(x))
4066
4067    def test_module_str(self):
4068        class Foo(torch.nn.Module):
4069            def forward(self, x):
4070                return torch.relu(x)
4071
4072        f = torch.jit.script(Foo())
4073
4074        str_f = str(f._c)
4075        self.assertTrue(str_f.startswith('ScriptObject'))
4076        self.assertTrue('__torch__.' in str_f)
4077        self.assertTrue('.Foo' in str_f)
4078
4079    def test_jitter_bug(self):
4080        @torch.jit.script
4081        def fn2(input, kernel_size):
4082            # type: (Tensor, List[int]) -> Tensor
4083            if kernel_size[0] > 1:
4084                _stride = [2]
4085            else:
4086                _stride = kernel_size
4087            print(_stride, kernel_size)
4088            return input
4089
4090        @torch.jit.script
4091        def fn(input):
4092            # type: (Tensor) -> Tensor
4093            return fn2(input, [1])
4094
4095    def test_parser_kwargonly(self):
4096        cu = torch.jit.CompilationUnit('''
4097            def foo(x, *, y) -> Tuple[Tensor, Tensor]:
4098                return x, x
4099            def bar(x):
4100                return foo(x, y=x)
4101        ''')
4102        self.assertTrue('*' in str(cu.foo.schema))
4103        with self.assertRaisesRegex(RuntimeError, "not provided"):
4104            torch.jit.CompilationUnit('''
4105                def foo(x, *, y) -> Tuple[Tensor, Tensor]:
4106                    return x, x
4107                def bar(x):
4108                    return foo(x, x)
4109            ''')
4110
4111    def test_annoying_doubles(self):
4112        mod = types.ModuleType("temp")
4113        mod.inf = float("inf")
4114        mod.ninf = float("-inf")
4115        mod.nan = float("nan")
4116
4117        with torch._jit_internal._disable_emit_hooks():
4118            class Foo(torch.jit.ScriptModule):
4119                @torch.jit.script_method
4120                def forward(self):
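                    # pi, infinities, NaN and a value near the smallest normal double
                    # are chosen to stress exact floating-point serialization.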
4121                    return math.pi, 0.1, mod.inf, mod.ninf, 2.225073858507201e-308, mod.nan
4122
4123            foo = Foo()
4124            buffer = io.BytesIO()
4125            torch.jit.save(foo, buffer)
4126
4127            buffer.seek(0)
4128            foo_loaded = torch.jit.load(buffer)
4129
4130            r = foo()
4131            r2 = foo_loaded()
4132            # use a precise assert; we are checking floating-point details
4133            self.assertTrue(r[:-1] == r2[:-1])
4134            self.assertTrue(math.isnan(r[-1]) and math.isnan(r2[-1]))
4135
4136    def test_type_annotate(self):
4137
4138        def foo(a):
4139            return torch.jit.annotate(torch.Tensor, a)
4140
4141        self.checkScript(foo, (torch.rand(3),))
4142
4143        def bar():
4144            a = torch.jit.annotate(List[int], [])
4145            for _ in range(10):
4146                a.append(4)
4147            return a
4148
4149        self.checkScript(bar, ())
4150
4151        def baz(a):
4152            return torch.jit.annotate(float, a)
4153        self.checkScript(baz, (torch.rand(()),))
4154
4155        # test annotating None types
4156        def annotate_none():
4157            return torch.jit.annotate(Optional[torch.Tensor], None)
4158
4159        self.checkScript(annotate_none, ())
4160
4161
4162    def test_robust_op_resolution(self):
4163        neg = torch.add  # misleading name to make sure we resolve by the function object, not the name
4164
4165        def stuff(x):
4166            return neg(x, x)
4167
4168        a = (torch.rand(3),)
4169        self.checkScript(stuff, a)
4170
4171    def test_nested_aug_assign(self):
4172        @torch.jit.script
4173        class SomeClass:
4174            def __init__(self) -> None:
4175                self.num = 99
4176
4177            def __iadd__(self, x):
4178                # type: (int)
4179                self.num += x
4180                return self
4181
4182            def __eq__(self, other):
4183                # type: (SomeClass) -> bool
4184                return self.num == other.num
4185
4186        @torch.jit.script
4187        class SomeOutOfPlaceClass:
4188            def __init__(self) -> None:
4189                self.num = 99
4190
4191            def __add__(self, x):
4192                # type: (int)
4193                self.num = x
4194                return self
4195
4196            def __eq__(self, other):
4197                # type: (SomeClass) -> bool
4198                return self.num == other.num
4199
4200        class Child(nn.Module):
4201            def __init__(self) -> None:
4202                super().__init__()
4203                self.x = 2
4204                self.o = SomeClass()
4205                self.oop = SomeOutOfPlaceClass()
4206                self.list = [1, 2, 3]
4207
4208        class A(nn.Module):
4209            def __init__(self) -> None:
4210                super().__init__()
4211                self.child = Child()
4212
4213            def forward(self):
4214                self.child.x += 1
4215                self.child.o += 5
4216                self.child.oop += 5
4217                some_list = [1, 2]
4218                self.child.list += some_list
4219                self.child.list *= 2
4220                return self.child.x, self.child.o, self.child.list, self.child.oop
4221
4222        a = A()
4223        sa = torch.jit.script(A())
4224        eager_result = a()
4225        script_result = sa()
4226        self.assertEqual(eager_result, script_result)
4227        self.assertEqual(a.child.x, sa.child.x)
4228        self.assertEqual(a.child.o, sa.child.o)
4229        self.assertEqual(a.child.list, sa.child.list)
4230
4231        @torch.jit.script
4232        class SomeNonAddableClass:
4233            def __init__(self) -> None:
4234                self.num = 99
4235
4236            def __eq__(self, other):
4237                # type: (SomeClass) -> bool
4238                return self.num == other.num
4239
4240        # with self.assertRaisesRegex(RuntimeError, "")
4241        class A(nn.Module):
4242            def __init__(self) -> None:
4243                super().__init__()
4244                self.x = SomeNonAddableClass()
4245
4246            def forward(self):
4247                self.x += SomeNonAddableClass()
4248                return self.x
4249
4250        with self.assertRaisesRegex(RuntimeError, "Cannot emit inplace op"):
4251            torch.jit.script(A())
4252
4253    def test_var_aug_assign(self):
4254        @torch.jit.script
4255        class SomeNonAddableClass:
4256            def __init__(self) -> None:
4257                self.num = 99
4258
4259            def __eq__(self, other):
4260                # type: (SomeNonAddableClass) -> bool
4261                return self.num == other.num
4262
4263        with self.assertRaisesRegex(RuntimeError, "Cannot emit inplace op"):
4264            @torch.jit.script
4265            def fn():
4266                a = SomeNonAddableClass()
4267                a += SomeNonAddableClass()
4268                return a
4269
4270        @torch.jit.script
4271        class SomeClass:
4272            def __init__(self) -> None:
4273                self.num = 99
4274
4275            def __iadd__(self, x):
4276                # type: (int)
4277                self.num += x
4278                return self
4279
4280            def __eq__(self, other):
4281                # type: (SomeClass) -> bool
4282                return self.num == other.num
4283
4284        @torch.jit.script
4285        class SomeOutOfPlaceClass:
4286            def __init__(self) -> None:
4287                self.num = 99
4288
4289            def __add__(self, x):
4290                # type: (int)
4291                self.num = x
4292                return self
4293
4294            def __eq__(self, other):
4295                # type: (SomeClass) -> bool
4296                return self.num == other.num
4297
4298        def fn2():
4299            a = SomeClass()
4300            a_copy = a
4301            a += 20
4302            assert a is a_copy
4303            b = SomeOutOfPlaceClass()
4304            b_copy = b
4305            b += 99
4306            assert b is b_copy
4307            c = [1, 2, 3]
4308            c_copy = c
4309            c *= 2
4310            assert c is c_copy
4311            c += [4, 5, 6]
4312            d = torch.ones(2, 2)
4313            d_copy = d
4314            d += torch.ones(2, 2)
4315            assert d is d_copy
4316            return a, b, c, d
4317
4318        self.checkScript(fn2, [])
4319
4320    def test_nested_list_construct(self):
4321        def foo():
4322            return [[4]] + [[4, 5]]
4323        self.checkScript(foo, ())
4324
4325    def test_file_line_error(self):
4326        def foobar(xyz):
4327            return torch.blargh(xyz)
4328
4329        _, lineno = inspect.getsourcelines(foobar)
4330        with self.assertRaisesRegex(RuntimeError, f'test_jit.py", line {lineno + 1}'):
4331            scripted = torch.jit.script(foobar)
4332
4333    def test_file_line_error_class_defn(self):
4334        class FooBar:
4335            def baz(self, xyz):
4336                return torch.blargh(xyz)
4337
4338        _, lineno = inspect.getsourcelines(FooBar)
4339        with self.assertRaisesRegex(RuntimeError, f'test_jit.py", line {lineno + 2}'):
4340            torch.jit.script(FooBar)
4341
4342    def test_file_line_graph(self):
4343        def foobar(xyz):
4344            return torch.neg(xyz)
4345
4346        scripted = torch.jit.script(foobar)
4347
4348        _, lineno = inspect.getsourcelines(foobar)
4349        fc = FileCheck().check(f'test_jit.py:{lineno + 1}:19')
4350        fc.run(scripted.graph)
4351        fc.run(str(scripted.graph))
4352
4353    def test_file_line_save_load(self):
4354        class Scripted(torch.jit.ScriptModule):
4355            @torch.jit.script_method
4356            def forward(self, xyz):
4357                return torch.neg(xyz)
4358
4359        scripted = Scripted()
4360
4361        # NB: not using getExportImportCopy because that takes a different
4362        # code path that calls CompilationUnit._import rather than
4363        # going through the full save/load pathway
4364        buffer = scripted.save_to_buffer()
4365        bytesio = io.BytesIO(buffer)
4366        scripted = torch.jit.load(bytesio)
4367
4368        _, lineno = inspect.getsourcelines(Scripted)
4369        fc = FileCheck().check(f':{lineno + 3}')
4370        fc.run(scripted.graph)
4371        fc.run(str(scripted.graph))
4372
4373    def test_file_line_string(self):
4374        scripted = torch.jit.CompilationUnit('''
4375def foo(xyz):
4376    return torch.neg(xyz)
4377        ''')
4378
4379        fc = FileCheck().check('<string>:3:11')
4380        fc.run(scripted.foo.graph)
4381        fc.run(str(scripted.foo.graph))
4382
4383    @skipIfCrossRef
4384    def test_file_line_trace(self):
4385        def foobar(xyz):
4386            return torch.neg(xyz)
4387
4388        scripted = torch.jit.trace(foobar, (torch.rand(3, 4)))
4389
4390        _, lineno = inspect.getsourcelines(foobar)
4391        fc = FileCheck().check(f'test_jit.py:{lineno + 1}:0')
4392        fc.run(scripted.graph)
4393        fc.run(str(scripted.graph))
4394
4395    def test_serialized_source_ranges(self):
4396
4397        class FooTest(torch.jit.ScriptModule):
4398            @torch.jit.script_method
4399            def forward(self, x, w):
4400                return torch.mm(x, w.t())
4401
4402        ft = FooTest()
4403        loaded = self.getExportImportCopy(ft)
4404        _, lineno = inspect.getsourcelines(FooTest)
4405
4406        with self.assertRaisesRegex(RuntimeError, f'test_jit.py", line {lineno + 3}'):
4407            loaded(torch.rand(3, 4), torch.rand(30, 40))
4408
4409    def test_serialized_source_ranges_graph(self):
4410
4411        class FooTest3(torch.jit.ScriptModule):
4412            @torch.jit.script_method
4413            def forward(self, x, w):
4414                return torch.mm(x, w.t())
4415
4416        ft = FooTest3()
4417        loaded = self.getExportImportCopy(ft)
4418        _, lineno = inspect.getsourcelines(FooTest3)
4419
4420        fc = FileCheck().check(f'test_jit.py:{lineno + 3}')
4421        fc.run(loaded.graph)
4422
4423    def test_serialized_source_ranges2(self):
4424
4425        class FooTest2(torch.jit.ScriptModule):
4426            @torch.jit.script_method
4427            def forward(self):
4428                raise RuntimeError('foo')
4429
4430        _, lineno = inspect.getsourcelines(FooTest2)
4431
4432        with self.assertRaisesRegex(torch.jit.Error, f'test_jit.py", line {lineno + 3}'):
4433            ft = FooTest2()
4434            loaded = self.getExportImportCopy(ft)
4435            loaded()
4436
4437    def test_serialized_source_ranges_dont_jitter(self):
4438        class FooTest3(torch.jit.ScriptModule):
4439            @torch.jit.script_method
4440            def forward(self, lim):
4441                first = 1
4442                second = 1
4443                i = 1
4444                somenum = 5
4445                dontmutateme = 3
4446                third = 0
4447                while bool(i < lim):
4448                    third = first + second
4449                    first = second
4450                    second = third
4451                    j = 0
4452                    while j < 10:
4453                        somenum = somenum * 2
4454                        j = j + 1
4455                    i = i + j
4456                    i = i + dontmutateme
4457
4458                st = second + third
4459                fs = first + second
4460                return third, st, fs
4461
4462        ft3 = FooTest3()
4463
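        # Serialize the module to an in-memory zip archive and unpickle its single
        # archive/code/*.debug_pkl entry so the recorded source ranges can be
        # compared across save/load round trips.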
4464        def debug_records_from_mod(self, mod):
4465            buffer = io.BytesIO()
4466            torch.jit.save(mod, buffer)
4467            buffer.seek(0)
4468            archive = zipfile.ZipFile(buffer)
4469            files = filter(lambda x: x.startswith('archive/code/'), archive.namelist())
4470            debug_files = list(filter(lambda f: f.endswith('.debug_pkl'), files))
4471            self.assertEqual(len(debug_files), 1)
4472            debug_file = archive.open(debug_files[0])
4473            return pickle.load(debug_file), buffer
4474
4475        records1, buffer = debug_records_from_mod(self, ft3)
4476
4477        buffer.seek(0)
4478        loaded = torch.jit.load(buffer)
4479        records2, buffer = debug_records_from_mod(self, loaded)
4480
4481        buffer.seek(0)
4482        loaded2 = torch.jit.load(buffer)
4483        records3, _ = debug_records_from_mod(self, loaded2)
4484
4485        self.assertEqual(records1, records2)
4486        self.assertEqual(records2, records3)
4487
4488    def test_serialized_source_ranges_no_dups(self):
4489        class FooTest3(torch.jit.ScriptModule):
4490            @torch.jit.script_method
4491            def forward(self, lim):
4492                first = 1
4493                second = 1
4494                i = 1
4495                somenum = 5
4496                dontmutateme = 3
4497                third = 0
4498                while bool(i < lim):
4499                    third = first + second
4500                    first = second
4501                    second = third
4502                    j = 0
4503                    while j < 10:
4504                        somenum = somenum * 2
4505                        j = j + 1
4506                    i = i + j
4507                    i = i + dontmutateme
4508
4509                st = second + third
4510                fs = first + second
4511                return third, st, fs
4512
4513        ft3 = FooTest3()
4514
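        # Each archive/code/*.debug_pkl entry unpickles to a tuple whose third element
        # is the list of (offset, source_range_tag, source_range) records that the loop
        # below checks for adjacent duplicates.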
4515        def debug_records_from_mod(mod):
4516            buffer = io.BytesIO()
4517            torch.jit.save(mod, buffer)
4518            buffer.seek(0)
4519            archive = zipfile.ZipFile(buffer)
4520            files = list(filter(lambda x: x.startswith('archive/code/'), archive.namelist()))
4521            debug_files = filter(lambda f: f.endswith('.debug_pkl'), files)
4522            debug_files = (archive.open(f) for f in debug_files)
4523            debug_files = (pickle.load(f) for f in debug_files)
4524            debug_files = (f[2] for f in debug_files)
4525            return list(debug_files)
4526
4527        debug_files = debug_records_from_mod(ft3)
4528        for debug_file in debug_files:
4529            for i in range(len(debug_file) - 1):
4530                offset, source_range_tag, source_range = debug_file[i]
4531                offset2, source_range_tag2, source_range2 = debug_file[i + 1]
4532                self.assertNotEqual(source_range, source_range2)
4533
4534    def test_circular_dependency(self):
4535        """
4536        https://github.com/pytorch/pytorch/issues/25871
4537        """
4538        class A(torch.jit.ScriptModule):
4539            @torch.jit.script_method
4540            def forward(self, x):
4541                return x
4542
4543        class B(torch.jit.ScriptModule):
4544            def __init__(self) -> None:
4545                super().__init__()
4546                self.foo = torch.nn.ModuleList([A()])
4547
4548            @torch.jit.script_method
4549            def forward(self, x):
4550                for f in self.foo:
4551                    x = f(x)
4552                return x
4553
4554        class C(torch.jit.ScriptModule):
4555            def __init__(self) -> None:
4556                super().__init__()
4557                self.foo = torch.nn.Sequential(B())
4558
4559            @torch.jit.script_method
4560            def forward(self, x):
4561                for f in self.foo:
4562                    x = f(x)
4563                return x
4564        self.getExportImportCopy(C())
4565
4566    def test_serialize_long_lines(self):
4567        class OrderModuleLong(torch.nn.Module):
4568            def forward(self, long_arg_name: List[torch.Tensor]):
4569                return [(long_arg_name[1],), (long_arg_name[0].argmax(),)]
4570        src = str(torch.jit.script(OrderModuleLong()).code)
4571        # make sure long_arg_name[1] does not get reordered after the argmax
4572        FileCheck().check("long_arg_name[1]").check("argmax").run(src)
4573
4574    def test_tensor_shape(self):
4575        x = torch.empty(34, 56, 78)
4576
4577        def f(x):
4578            return x.shape
4579
4580        self.checkScript(f, (x,))
4581
4582
4583    def test_block_input_grad_in_loop(self):
4584
4585        x = torch.randn(3, 3, requires_grad=False)
4586        y = torch.randn(3, 3, requires_grad=True)
4587
4588        def grad_in_loop(x, y):
4589            for i in range(100):
4590                x = y @ x
4591            return x
4592
4593        scripted = torch.jit.script(grad_in_loop)
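        # Navigate into the body block of the prim::Loop and check that the loop-carried
        # value corresponding to x is marked as requiring grad.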
4594        outer = scripted.graph_for(x, y)
4595        loop = outer.findNode("prim::Loop")
4596        loop_block = next(loop.blocks())
4597        param_node = loop_block.paramNode()
4598        x_value = list(param_node.outputs())[1]
4599        self.assertTrue(x_value.requires_grad())
4600
4601    def test_tensor_grad(self):
4602        x = torch.randn(3, 4, requires_grad=True)
4603        y = torch.randn(3, 4, requires_grad=False)
4604
4605        def f_requires_grad(x):
4606            return x.requires_grad
4607
4608        self.checkScript(f_requires_grad, (x,))
4609        self.checkScript(f_requires_grad, (y,))
4610
4611        def f_grad(x):
4612            return x.grad
4613
4614        x.sum().backward()
4615        self.checkScript(f_grad, (x,))
4616        self.checkScript(f_grad, (y,))
4617
4618    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "shape analysis is only enabled in Legacy")
4619    def test_prim_grad_undefined(self):
4620
4621        x = torch.ones(2)
4622
4623        def f_grad(x):
4624            return x.grad
4625
4626        scripted = self.checkScript(f_grad, (x,))
4627        g = scripted.graph_for(x)
4628
4629        prim_grad_node = g.findNode("prim::grad")
4630        self.assertTrue(next(prim_grad_node.outputs()).type().undefined() is None)
4631
4632    def test_tensor_data(self):
4633        x = torch.randn(3, 4, requires_grad=True)
4634        y = torch.randn(4, 5)
4635
4636        def f_data(x):
4637            return x.data
4638
4639        scripted_f_data = torch.jit.script(f_data)
4640
4641        scripted_x = scripted_f_data(x)
4642        self.assertEqual(scripted_x, f_data(x))
4643        self.assertEqual(scripted_x.requires_grad, False)
4644
4645        scripted_y = scripted_f_data(y)
4646        self.assertEqual(scripted_y, f_data(y))
4647        self.assertEqual(scripted_y.requires_grad, False)
4648
4649    def test_tensor_dtype(self):
4650        x_byte = torch.empty(34, 56, 78, dtype=torch.uint8)
4651        x_long = torch.empty(34, 56, 78, dtype=torch.long)
4652        x_float32 = torch.empty(34, 56, 78, dtype=torch.float32)
4653
4654        @torch.jit.script
4655        def byte(x):
4656            return x.dtype == torch.uint8
4657
4658        @torch.jit.script
4659        def long(x):
4660            return x.dtype == torch.long
4661
4662        @torch.jit.script
4663        def float32(x):
4664            return x.dtype == torch.float32
4665
4666        self.assertTrue(byte(x_byte))
4667        self.assertFalse(byte(x_long))
4668        self.assertFalse(byte(x_float32))
4669        self.assertFalse(long(x_byte))
4670        self.assertTrue(long(x_long))
4671        self.assertFalse(long(x_float32))
4672        self.assertFalse(float32(x_byte))
4673        self.assertFalse(float32(x_long))
4674        self.assertTrue(float32(x_float32))
4675
4676    @unittest.skipIf(not RUN_CUDA, "device tests require CUDA")
4677    def test_tensor_device(self):
4678        cpu = torch.empty(34, 56, 78, device='cpu')
4679        gpu = torch.empty(34, 56, 78, device='cuda')
4680
4681        @torch.jit.script
4682        def same_device(x, y):
4683            return x.device == y.device
4684
4685        self.assertTrue(same_device(cpu, cpu))
4686        self.assertTrue(same_device(gpu, gpu))
4687        self.assertFalse(same_device(cpu, gpu))
4688
4689    @unittest.skipIf(not RUN_CUDA, "device tests require CUDA")
4690    def test_tensor_to_device(self):
4691        def to_device(x):
4692            return x.to(device="cuda").to(device=torch.device("cpu"))
4693
4694        self.checkScript(to_device, (torch.ones(3, 4),))
4695
4696    def test_tensor_to_cpu(self):
4697        def to_cpu(x):
4698            return x.cpu()
4699
4700        x = torch.ones(3, 4)
4701        script_fn = torch.jit.script(to_cpu)
4702        self.assertEqual(to_cpu(x).device, script_fn(x).device)
4703        self.checkScript(to_cpu, (x,))
4704
4705    @unittest.skipIf(not RUN_CUDA, "device tests require CUDA")
4706    def test_tensor_to_cuda(self):
4707        def to_cuda(x):
4708            return x.cuda()
4709
4710        x = torch.ones(3, 4)
4711        script_fn = torch.jit.script(to_cuda)
4712        self.assertEqual(to_cuda(x).device, script_fn(x).device)
4713        self.checkScript(to_cuda, (x,))
4714
4715    def test_generic_list_errors(self):
4716        with self.assertRaisesRegex(RuntimeError, "previously matched to type"):
4717            @torch.jit.script
4718            def foo(x):
4719                return [[x]] + [[1]]
4720
4721    def test_script_cu(self):
4722        cu = torch.jit.CompilationUnit('''
4723            def foo(a):
4724                b = a
4725                return b
4726        ''')
4727        a = Variable(torch.rand(1))
4728        self.assertEqual(a, cu.foo(a))
4729
4730    # Because the compilation unit ingests Python strings,
4731    # to use an escape sequence you must escape the backslash (\\n = \n).
4732    def test_string_cu(self):
4733        cu = torch.jit.CompilationUnit('''
4734            def foo(a):
4735                print(a, """a\\n\tb\\n""", 2, "a\
4736a")
4737                return a
4738        ''')
4739        FileCheck().check("aa").check("a\\n\\tb\\n").run(str(cu.foo.graph))
4740
4741    def test_function_compilation_caching(self):
4742        def fun():
4743            return 1 + 2
4744
4745        fun_compiled = torch.jit.script(fun)
4746        # the Python wrapper around the script function is a different object,
4747        # but the underlying script function graph is the same
4748        self.assertIs(fun_compiled.graph, torch.jit.script(fun).graph)
4749
4750        def fun():
4751            return 3 + 4
4752
4753        num_ref_counts = sys.getrefcount(fun)
4754
4755        # caching doesn't get tripped up by same qualname
4756        fun_compiled_2 = torch.jit.script(fun)
4757        self.assertIsNot(fun_compiled, fun_compiled_2)
4758        self.assertEqual(fun_compiled_2(), 7)
4759
4760        # caching doesn't increase the refcount of the function (it holds a weak reference)
4761        self.assertEqual(sys.getrefcount(fun), num_ref_counts)
4762
4763    def test_string_ops(self):
4764        def foo():
4765            a = "a" + "b"
4766            return a + a, "ab" == "b", "ab" != "b", "ab" == "ab", "ab" != "ab"
4767
4768        self.checkScript(foo, ())
4769
4770    def test_string_sorted(self):
4771        def foo(strs: List[str]):
4772            return sorted(strs)
4773
4774        FileCheck() \
4775            .check("graph") \
4776            .check_next("str[] = aten::sorted") \
4777            .check_next("return") \
4778            .run(str(torch.jit.script(foo).graph))
4779
4780        inputs = ["str3", "str2", "str1"]
4781        self.checkScript(foo, (inputs,))
4782
4783    def test_string_sort(self):
4784        def foo(strs: List[str]):
4785            strs.sort()
4786            return strs
4787
4788        inputs = ["str3", "str2", "str1"]
4789        self.checkScript(foo, (inputs,))
4790
4791    def test_tuple_sorted(self):
4792        def foo(tups: List[Tuple[int, int]]):
4793            return sorted(tups)
4794
4795        inputs = [(1, 2), (0, 2), (1, 3)]
4796        self.checkScript(foo, (inputs,))
4797
4798    def test_tuple_sort(self):
4799        def foo(tups: List[Tuple[int, int]]):
4800            tups.sort()
4801            return tups
4802
4803        inputs = [(1, 2), (0, 2), (1, 3)]
4804        self.checkScript(foo, (inputs,))
4805
4806    def test_tuple_sort_reverse(self):
4807        def foo(tups: List[Tuple[int, int]]):
4808            tups.sort(reverse=True)
4809            return tups
4810
4811        inputs = [(1, 2), (0, 2), (1, 3)]
4812        self.checkScript(foo, (inputs,))
4813
4814    def test_tuple_unsortable_element_type(self):
4815        @torch.jit.script
4816        def foo():
4817            tups = [({1: 2}, {2: 3})]
4818            tups.sort()
4819            return tups
4820
4821        with self.assertRaisesRegexWithHighlight(RuntimeError, "are not sortable", "tups.sort"):
4822            foo()
4823
4824    def test_tuple_unsortable_diff_type(self):
4825        @torch.jit.script
4826        def foo(inputs: List[Any]):
4827            inputs.sort()
4828            return inputs
4829
4830        inputs = [(1, 2), ("foo", "bar")]
4831        with self.assertRaisesRegexWithHighlight(RuntimeError, "Only values of same type can be compared", "inputs.sort"):
4832            foo(inputs)
4833
4834    def test_tuple_nested_sort(self):
4835        def foo(inputs: List[Tuple[int, Tuple[int, str]]]):
4836            inputs.sort()
4837            return inputs
4838
4839        inputs = [(1, (2, "foo")), (1, (2, "bar")), (1, (0, "bar"))]
4840        self.checkScript(foo, (inputs,))
4841
4842    def test_tuple_unsortable_nested_diff_type(self):
4843        @torch.jit.script
4844        def foo(inputs: List[Any]):
4845            inputs.sort()
4846            return inputs
4847
4848        inputs = [(1, (2, 3)), (2, ("foo", "bar"))]
4849        with self.assertRaisesRegexWithHighlight(RuntimeError, "Only values of same type can be compared", "inputs.sort"):
4850            foo(inputs)
4851
4852    def test_string_new_line(self):
4853        with self.assertRaisesRegex(RuntimeError, "expected a valid token*"):
4854            torch.jit.CompilationUnit('''
4855            def test_while(a):
4856                print("
4857                    a")
4858                return a
4859            ''')
4860
4861    def test_string_single_escape(self):
4862        with self.assertRaisesRegex(RuntimeError, "expected a valid token*"):
4863            torch.jit.CompilationUnit('''
4864            def test_while(a):
4865                print("\\")
4866                return a
4867            ''')
4868
4869    def test_script_annotation(self):
4870        @torch.jit.script
4871        def foo(a):
4872            return a + a + a
4873        s = Variable(torch.rand(2))
4874        self.assertEqual(s + s + s, foo(s))
4875
4876    def test_torch_pow(self):
4877        def func(a, b):
4878            return pow(a, b)
4879
4880        def func2(a, b, c, d):
4881            return pow(pow(c + a, b), d)
4882
4883        def func3(a : int, b : float):
4884            # type: (int, float) -> float
4885            return pow(a, b)
4886
4887        def func4():
4888            # type: () -> float
4889            return pow(2, -2)
4890
4891        def func5(x, y):
4892            return pow(x.item(), y.item())
4893
4894        def func6(a : int, b : int):
4895            # type: (int, int) -> float
4896            return pow(a, b)
4897
4898        a = torch.rand(1)
4899        b = torch.rand(1)
4900        c = torch.rand(1)
4901        d = torch.rand(1)
4902        self.checkScript(func, (a, b))
4903        self.checkScript(func2, (a, b, c, d))
4904        self.checkScript(func3, (4, -0.5))
4905        self.checkScript(func4, ())
4906        self.checkScript(func6, (2, 4))
4907
4908        inputs = [torch.tensor(2), torch.tensor(-2), torch.tensor(.5), torch.tensor(.2)]
4909        for x in inputs:
4910            for y in inputs:
4911                if x < 0:
4912                    continue
4913                else:
4914                    self.checkScript(func5, (x, y))
4915
4916    @unittest.skipIf(not RUN_CUDA, "device tests require CUDA")
4917    def test_pow_scalar_backward_cuda(self):
4918        # see that scalar exponent works with cuda base (#19253)
4919        with enable_profiling_mode_for_profiling_tests():
4920            for dtype in [torch.float, torch.double]:
4921                @torch.jit.script
4922                def func(a, b):
4923                    # type: (Tensor, float) -> Tensor
4924                    return (a * 2) ** b
4925
4926                a = torch.rand(1, requires_grad=True, device='cuda', dtype=dtype)
4927                func(a, 1, profile_and_replay=True).backward()
4928
4929                @torch.jit.script
4930                def func(a, b):
4931                    # type: (float, Tensor) -> Tensor
4932                    return a ** (b * 2 + 1)
4933
4934                a = torch.rand(1, requires_grad=True, device='cuda', dtype=dtype)
4935                func(2, a, profile_and_replay=True).backward()
4936
4937    def _check_code(self, code_str, fn_name, inputs):
4938        scope = {}
4939        exec(code_str, globals(), scope)
4940        cu = torch.jit.CompilationUnit(code_str)
4941        self.assertEqual(cu.func(*inputs), scope[fn_name](*inputs))
4942
4943    @unittest.skipIf(not RUN_CUDA, 'no CUDA')
4944    def test_scriptmodule_releases_tensors_cuda(self):
4945        with enable_profiling_mode_for_profiling_tests():
4946            @torch.jit.script
4947            def fn(x, y):
4948                return x.sigmoid() * y.tanh()
4949
4950            def test(backward=False):
4951                x = torch.randn(3, 3, dtype=torch.double, device='cuda', requires_grad=True)
4952                y = torch.randn(3, 3, dtype=torch.double, device='cuda', requires_grad=True)
4953                out = fn(x, y, profile_and_replay=True)
4954                if backward:
4955                    out.sum().backward()
4956
4957            with self.assertLeaksNoCudaTensors():
4958                test()
4959                test()
4960                test()
4961
4962            if GRAPH_EXECUTOR != ProfilingMode.SIMPLE:
4963                with self.assertLeaksNoCudaTensors():
4964                    test(backward=True)
4965                    test(backward=True)
4966                    test(backward=True)
4967
4968    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
4969    def test_index(self):
4970        def consec(size, start=0):
4971            numel = torch.tensor(size).prod().item()
4972            return torch.arange(numel).view(size)
4973
4974        def consec_list(size):
4975            return list(range(size))
4976
4977        def random_string(size):
4978            letters = string.ascii_lowercase
4979            return "".join(random.choice(letters) for i in range(size))
4980
4981        def check_indexing(indexing, tensor):
4982            template = dedent("""
4983            def func(x):
4984                return x{}
4985            """)
4986
4987            self._check_code(template.format(indexing), "func", [tensor])
4988
4989        def check_dynamic_indexing(indexing, tensor, value1, value2):
4990            value1 = torch.tensor(value1)
4991            value2 = torch.tensor(value2)
4992
4993            template = dedent("""
4994            def func(x, value1, value2):
4995                i = int(value1)
4996                j = int(value2)
4997                return x{}
4998            """)
4999
5000            self._check_code(template.format(indexing), "func", [tensor, value1, value2])
5001
5002        # TorchScript assumes type Tensor by default, so we need this explicit
5003        # declaration.
5004        def check_indexing_list_int(indexing, list):
5005            template = dedent("""
5006            def func(x):
5007                # type: (List[int]) -> Any
5008                return x{}
5009            """)
5010
5011            self._check_code(template.format(indexing), "func", [list])
5012
5013        def check_indexing_str(indexing, str):
5014            template = dedent("""
5015            def func(x):
5016                # type: (str) -> Any
5017                return x{}
5018            """)
5019
5020            self._check_code(template.format(indexing), "func", [str])
5021
5022        # basic slices
5023        check_indexing('[0]', consec((3, 3)))
5024        check_indexing('[1]', consec((3, 3), 10))
5025        check_indexing('[2]', consec((3, 3), 19))
5026        check_indexing('[2]', consec((3,)))
5027        check_indexing('[-1]', consec((3, 3), 19))
5028        check_indexing('[0:2]', consec((3, 3, 3)))
5029        check_indexing('[1:-1]', consec((3, 3, 3)))
5030        check_indexing('[-3:-1]', consec((6, 3)))
5031        check_indexing('[1:]', consec((3, 3)))
5032        check_indexing('[:1]', consec((3, 3)))
5033        check_indexing('[:]', consec((3, 2)))
5034
5035        # multi-dim: indexes
5036        check_indexing('[0, 1]', consec((3, 3)))
5037        check_indexing('[0, 1]', consec((3, 3, 2)))
5038        check_indexing('[1, 0, 2]', consec((3, 3, 3)))
5039        check_indexing('[2, -1]', consec((3, 3)))
5040
5041        # multi-dim: mixed slicing and indexing
5042        check_indexing('[0, 1:2]', consec((3, 3)))
5043        check_indexing('[0, :1]', consec((3, 3, 2)))
5044        check_indexing('[1, 2:]', consec((3, 3, 3)))
5045        check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3)))
5046        check_indexing('[1:, -1, 0]', consec((3, 3, 3, 3)))
5047        check_indexing('[-1, 2:, 1:2]', consec((3, 3, 3, 3)))
5048        check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3)))
5049        check_indexing('[-1, :, 0, 2]', consec((3, 3, 3, 3)))
5050
5051        # zero-sized slices
5052        check_indexing('[0:0]', consec((2, 2)))
5053        check_indexing('[0:0, 1]', consec((3, 3)))
5054
5055        # trivial expression usage
5056        check_indexing('[1+1]', consec((3, 3)))
5057        check_indexing('[1:(0 + 2)]', consec((3, 3, 3)))
5058
5059        # None for new dimensions
5060        check_indexing('[None, 0]', consec((3, 3)))
5061        check_indexing('[1, None]', consec((3, 3), 10))
5062        check_indexing('[None, None, 2]', consec((3, 3), 19))
5063        check_indexing('[None, 2, None]', consec((3,)))
5064        check_indexing('[0:2, None]', consec((3, 3, 3)))
5065        check_indexing('[None, 1:-1]', consec((3, 3, 3)))
5066        check_indexing('[None, -3:-1, None]', consec((6, 3)))
5067        check_indexing('[-1, None, 2:, None, 1:2]', consec((3, 3, 3, 3)))
5068        check_indexing('[None, -1, None, 2:, None, 1:2, None]', consec((3, 3, 3, 3)))
5069
5070        # dynamic expression usage
5071        check_dynamic_indexing("[i + j]", consec((3, 3)), 0, 1)
5072        check_dynamic_indexing("[i:j, i]", consec((3, 3, 2)), 0, 2)
5073
5074        # positive striding
5075        check_indexing_list_int('[0]', consec_list(6))
5076        check_indexing_list_int('[1]', consec_list(7))
5077        check_indexing_list_int('[2]', consec_list(8))
5078        check_indexing_list_int('[2]', consec_list(9))
5079        check_indexing_list_int('[-1]', consec_list(10))
5080        check_indexing_list_int('[0:2]', consec_list(11))
5081        check_indexing_list_int('[1:-1]', consec_list(12))
5082        check_indexing_list_int('[-3:-1]', consec_list(13))
5083        check_indexing_list_int('[1:]', consec_list(15))
5084        check_indexing_list_int('[:1]', consec_list(16))
5085        check_indexing_list_int('[:]', consec_list(17))
5086        check_indexing_list_int('[::]', consec_list(0))
5087        check_indexing_list_int('[1000::]', consec_list(0))
5088        check_indexing_list_int('[:1000:]', consec_list(0))
5089
5090        # negative striding
5091        check_indexing_list_int('[::-1]', consec_list(7))
5092        check_indexing_list_int('[:3:-1]', consec_list(7))
5093        check_indexing_list_int('[3::-1]', consec_list(7))
5094        check_indexing_list_int('[1000::-1]', consec_list(7))
5095        check_indexing_list_int('[3:0:-1]', consec_list(7))
5096        check_indexing_list_int('[3:-1000:-1]', consec_list(7))
5097        check_indexing_list_int('[0:0:-1]', consec_list(7))
5098        check_indexing_list_int('[0:-1000:-1]', consec_list(7))
5099
5100        # only step is specified
5101        check_indexing_list_int('[::-1]', consec_list(0))
5102        check_indexing_list_int('[::-1]', consec_list(7))
5103        check_indexing_list_int('[::-2]', consec_list(7))
5104        check_indexing_list_int('[::2]', consec_list(7))
5105        check_indexing_list_int('[::42]', consec_list(7))
5106        check_indexing_list_int('[::-42]', consec_list(7))
5107        check_indexing_list_int('[::42]', consec_list(0))
5108        check_indexing_list_int('[::-42]', consec_list(0))
5109        check_indexing_list_int('[::9223372036854775807]', consec_list(42))
5110        check_indexing_list_int('[::-9223372036854775807]', consec_list(42))
5111        with self.assertRaisesRegex(RuntimeError, "out of bounds"):
5112            check_indexing_list_int('[::-9223372036854775808]', consec_list(42))
5113        with self.assertRaisesRegex(RuntimeError, "should have non-zero step"):
5114            check_indexing_list_int('[::0]', consec_list(42))
5115
5116        # striding strings
5117        check_indexing_str('[0]', random_string(6))
5118        check_indexing_str('[1]', random_string(7))
5119        check_indexing_str('[2]', random_string(8))
5120        check_indexing_str('[2]', random_string(9))
5121        check_indexing_str('[-1]', random_string(10))
5122        check_indexing_str('[0:2]', random_string(11))
5123        check_indexing_str('[1:-1]', random_string(12))
5124        check_indexing_str('[-3:-1]', random_string(13))
5125        check_indexing_str('[1:]', random_string(15))
5126        check_indexing_str('[:1]', random_string(16))
5127        check_indexing_str('[:]', random_string(17))
5128        check_indexing_str('[::]', random_string(0))
5129        check_indexing_str('[1000::]', random_string(0))
5130        check_indexing_str('[:1000:]', random_string(0))
5131
5132        check_indexing_str('[::-1]', random_string(7))
5133        check_indexing_str('[:3:-1]', random_string(7))
5134        check_indexing_str('[3::-1]', random_string(7))
5135        check_indexing_str('[1000::-1]', random_string(7))
5136        check_indexing_str('[3:0:-1]', random_string(7))
5137        check_indexing_str('[3:-1000:-1]', random_string(7))
5138        check_indexing_str('[0:0:-1]', random_string(7))
5139        check_indexing_str('[0:-1000:-1]', random_string(7))
5140
5141        check_indexing_str('[::-1]', random_string(0))
5142        check_indexing_str('[::-1]', random_string(7))
5143        check_indexing_str('[::-2]', random_string(7))
5144        check_indexing_str('[::2]', random_string(7))
5145        check_indexing_str('[::42]', random_string(7))
5146        check_indexing_str('[::-42]', random_string(7))
5147        check_indexing_str('[::42]', random_string(0))
5148        check_indexing_str('[::-42]', random_string(0))
5149        check_indexing_str('[::9223372036854775807]', random_string(42))
5150        check_indexing_str('[::-9223372036854775807]', random_string(42))
5151        with self.assertRaisesRegex(RuntimeError, "out of bounds"):
5152            check_indexing_str('[::-9223372036854775808]', random_string(42))
5153        with self.assertRaisesRegex(RuntimeError, "should have non-zero step"):
5154            check_indexing_str('[::0]', random_string(42))
5155
5156    def test_module_copy_with_attributes(self):
5157        class Vocabulary(torch.jit.ScriptModule):
5158            def __init__(self, vocab_list):
5159                super().__init__()
5160                self._vocab = torch.jit.Attribute(vocab_list, List[str])
5161                self.some_idx = torch.jit.Attribute(2, int)
5162                self.idx = torch.jit.Attribute(
5163                    {word: i for i, word in enumerate(vocab_list)}, Dict[str, int]
5164                )
5165
5166            @torch.jit.script_method
5167            def lookup_indices_1d(self, values):
5168                # type: (List[str]) -> List[int]
5169                result = torch.jit.annotate(List[int], [])
5170                # Direct list iteration not supported
5171                for i in range(len(values)):
5172                    value = values[i]
5173                    result.append(self.idx.get(value, self.some_idx))
5174                return result
5175
5176            @torch.jit.script_method
5177            def forward(self, values):
5178                # type: (List[List[str]]) -> List[List[int]]
5179                result = torch.jit.annotate(List[List[int]], [])
5180                # Direct list iteration not supported
5181                for i in range(len(values)):
5182                    result.append(self.lookup_indices_1d(values[i]))
5183                return result
5184
5185        v = Vocabulary(list('uabcdefg'))
5186        v.__copy__()
5187
5188    def test_tuple_to_opt_list(self):
5189        @torch.jit.script
5190        def foo(x):
5191            # type: (Optional[List[int]]) -> int
5192            return 1
5193
5194        @torch.jit.script
5195        def tuple_call():
5196            return foo((1, 2))
5197
5198    def test_keyword(self):
5199        @torch.jit.script
5200        def func(x):
5201            return torch.sum(x, dim=0)
5202
5203        x = torch.rand(10, dtype=torch.float, requires_grad=True)
5204        y = func(x)
5205        y2 = torch.sum(x, dim=0)
5206        self.assertEqual(y, y2)
5207
5208    def test_constant_pooling_none(self):
5209        @torch.jit.script
5210        def typed_nones(a=None, b=None, c=None):
5211            # type: (Optional[int], Optional[bool], Optional[Tensor]) -> Tuple[Optional[int], Optional[bool], Optional[Tensor]]
5212            return a, b, c
5213
5214        @torch.jit.script
5215        def test(a):
5216            # type: (bool) -> None
5217            if a:
5218                print(typed_nones())
5219            else:
5220                print(typed_nones())
5221
5222        graph_str = str(test.graph)
5223        self.assertTrue(graph_str.count("NoneType = prim::Constant") == 1)
5224
5225    def test_constant_pooling_same_identity(self):
5226        def foo():
5227            a = torch.tensor([4])
5228            b = (a,)
5229            index = len(a) - 1
5230            c = b[index]
5231            d = b[index]
5232            return c, d
5233
5234        foo_script = torch.jit.script(foo)
5235        self.run_pass('constant_propagation', foo_script.graph)
5236        self.run_pass('constant_pooling', foo_script.graph)
5237        # even though c and d escape the scope, we are still able to
5238        # pool them into one constant because they are the same object
5239        FileCheck().check_count("prim::Constant", 1, exactly=True).run(foo_script.graph)
5240        self.assertEqual(foo(), foo_script())
5241
5242    def test_constant_pooling_introduce_aliasing(self):
5243        @torch.jit.script
5244        def foo():
5245            a = torch.tensor(1)
5246            b = torch.tensor(1)
5247            return a, b
5248
5249        self.run_pass('constant_propagation', foo.graph)
5250        self.run_pass('constant_pooling', foo.graph)
5251        # don't pool the constants because doing so would introduce an observable change in alias relationships
5252        a, b = foo()
5253        self.assertIsNot(a, b)
5254
5255    def test_literal(self):
5256        def func1(a, b):
5257            c = a, b
5258            d, e = c
5259            return d + e
5260
5261        def func2(a, b):
5262            c = a, (a, b)
5263            d, e = c
5264            f, g = e
5265            return d + f + g
5266
5267        def func3(a, b):
5268            # type: (float, float) -> float
5269            c = 0., (0., 0.)
5270            x = True
5271            while x:
5272                x = False
5273                c = a, (a, b)
5274            d, e = c
5275            f, g = e
5276            return d + f + g
5277
5278        a = torch.rand(1, requires_grad=True)
5279        b = torch.rand(1, requires_grad=True)
5280        self.checkScript(func1, (a, b), optimize=True)
5281        self.checkScript(func2, (a, b), optimize=True)
5282        self.checkScript(func3, (a.item(), b.item()), optimize=True)
5283
5284    def test_expand(self):
5285        @torch.jit.script
5286        def func(x, y):
5287            return x + y
5288
5289        x = torch.rand(2, 3, dtype=torch.float, requires_grad=True)
5290        y = torch.rand(3, dtype=torch.float, requires_grad=True)
5291        out = func(x, y)
5292        self.assertEqual(func(x, y), x + y)
5293
5294        grad = torch.randn(2, 3, dtype=torch.float)
5295        out.backward(grad)
5296        self.assertEqual(x.grad, grad)
5297        self.assertEqual(y.grad, grad.sum(dim=0))
5298
5299    def test_sum(self):
5300        @torch.jit.script
5301        def func(x):
5302            return x.sum(dim=[4])
5303
5304        @torch.jit.script
5305        def func2(x):
5306            return x.sum(dim=4)
5307
5308        # test that shape analysis is written correctly for sum with OptionalIntArrayRef[1] dim argument
5309        self.run_pass('constant_propagation', func.graph)
5310        self.run_pass('constant_propagation', func2.graph)
5311        g = _propagate_shapes(func.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
5312        g2 = _propagate_shapes(func2.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
5313
5314    def test_cat(self):
5315        with enable_profiling_mode_for_profiling_tests():
5316            @torch.jit.script
5317            def func(x):
5318                return torch.cat((x, x), dim=0)
5319
5320            x = torch.rand(10, dtype=torch.float, requires_grad=True)
5321            self.assertEqual(func(x, profile_and_replay=True), torch.cat((x, x), dim=0))
5322
5323            @torch.jit.script
5324            def func2(x, y):
5325                return torch.cat((x, x), y)
5326
5327            with disable_autodiff_subgraph_inlining():
5328                for sizes in ((2, 2), (0, 2)):
5329                    x = torch.rand(sizes).requires_grad_()
5330                    y = torch.tensor(1)
5331
5332                    output = func2(x, y, profile_and_replay=True)
5333                    output_ref = torch.cat((x, x), y)
5334                    self.assertEqual(output, output_ref)
5335
5336                    if GRAPH_EXECUTOR != ProfilingMode.SIMPLE:
5337                        self.assertAutodiffNode(func2.graph_for(x, y), True, ['aten::cat'], [])
5338
5339                        grad = torch.autograd.grad(output.sum(), x)
5340                        grad_ref = torch.autograd.grad(output_ref.sum(), x)
5341                        self.assertEqual(grad, grad_ref)
5342
5343    def test_cat_lifts(self):
5344        @torch.jit.script
5345        def foo(x):
5346            return torch.cat([x, x], dim=1)
5347
5348        @torch.jit.script
5349        def foo2(x):
5350            return torch.cat([], dim=1)
5351
5352        @torch.jit.script
5353        def foo3(x):
5354            return torch.cat([x], dim=1)
5355
5356        for g in [foo.graph, foo2.graph, foo3.graph]:
5357            FileCheck().check("int =").check("ListConstruct").check("aten::cat").run(str(g))
5358
5359    def test_stack(self):
5360        with enable_profiling_mode_for_profiling_tests():
5361            @torch.jit.script
5362            def func(x):
5363                return torch.stack((x, x), dim=1)
5364            x = torch.rand(10, 10)
5365            self.assertEqual(func(x, profile_and_replay=True), torch.stack((x, x), dim=1))
5366
5367            @torch.jit.script
5368            def func2(x, y):
5369                return torch.stack((x, y), dim=0)
5370
5371            with disable_autodiff_subgraph_inlining():
5372                x = torch.randn([2, 2]).requires_grad_()
5373                y = torch.randn([2, 2]).requires_grad_()
5374
5375                output = func2(x, y, profile_and_replay=True)
5376                output_ref = torch.stack((x, y), 0)
5377                self.assertEqual(output, output_ref)
5378                if GRAPH_EXECUTOR != ProfilingMode.SIMPLE:
5379                    self.assertAutodiffNode(func2.graph_for(x, y), True, ['aten::stack'], [])
5380
5381                    grads = torch.autograd.grad(output.sum(), (x, y))
5382                    grads_ref = torch.autograd.grad(output_ref.sum(), (x, y))
5383                    self.assertEqual(grads, grads_ref)
5384
5385    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY,
5386                     "Profiling executor will be using different heuristics for constructing differentiable graphs")
5387    def test_unbind(self):
5388        with enable_profiling_mode_for_profiling_tests():
5389            @torch.jit.script
5390            def func(x, y):
5391                # type: (Tensor, int) -> List[Tensor]
5392                return torch.unbind(x, y)
5393
5394            with disable_autodiff_subgraph_inlining():
5395                x = torch.rand([2, 2]).requires_grad_()
5396                y = 0
5397                outputs = func(x, y, profile_and_replay=True)
5398                outputs_ref = torch.unbind(x, dim=y)
5399                self.assertEqual(outputs, outputs_ref)
5400                self.assertAutodiffNode(func.graph_for(x, y), True, [], [])
5401
5402                grad = torch.autograd.grad(_sum_of_list(outputs), x)
5403                grad_ref = torch.autograd.grad(_sum_of_list(outputs_ref), x)
5404                self.assertEqual(grad, grad_ref)
5405
5406
5407    @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.PROFILING,
5408                     "Profiling executor fails to recognize that tensors in a list require gradients")
5409    def test_meshgrid(self):
5410        with enable_profiling_mode_for_profiling_tests():
5411            @torch.jit.script
5412            def func(a):
5413                # type: (List[Tensor]) -> List[Tensor]
5414                return torch.meshgrid(a)
5415            with disable_autodiff_subgraph_inlining():
5416                a = torch.tensor([1.0, 2, 3]).requires_grad_()
5417                b = torch.tensor([1.0, 2, 3, 4]).requires_grad_()
5418                inputs = [a, b]
5419
5420                outputs_ref = torch.meshgrid(inputs)
5421                outputs = func(inputs, profile_and_replay=True)
5422                self.assertEqual(outputs, outputs_ref)
5423
5424                if GRAPH_EXECUTOR != ProfilingMode.SIMPLE:
5425                    self.assertAutodiffNode(func.graph_for(inputs), True, [], [])
5426
5427                    grads = torch.autograd.grad(_sum_of_list(outputs), inputs)
5428                    grads_ref = torch.autograd.grad(_sum_of_list(outputs_ref), inputs)
5429                    self.assertEqual(grads, grads_ref)
5430
5431    def test_tensor_len(self):
5432        def func(x):
5433            return len(x)
5434
5435        self.checkScript(func, [torch.ones(4, 5, 6)])
5436
5437    def test_func_call(self):
5438        def add(a, b):
5439            return a + b
5440
5441        def mul(a, x):
5442            return a * x
5443
5444        def func(alpha, beta, x, y):
5445            return add(mul(alpha, x), mul(beta, y))
5446
5447        alpha = torch.rand(1, dtype=torch.float, requires_grad=True)
5448        beta = torch.rand(1, dtype=torch.float, requires_grad=True)
5449        x = torch.rand(3, dtype=torch.float, requires_grad=True)
5450        y = torch.rand(3, dtype=torch.float, requires_grad=True)
5451
5452        # NOTE: cannot optimize yet because broadcasts are not inserted before the fuser runs
5453        self.checkScript(func, [alpha, beta, x, y], optimize=False)
5454
5455    @unittest.skip("bailouts are being deprecated")
5456    def test_profiling_graph_executor(self):
5457        @torch.jit.script
5458        def def_in_one_branch(x, z):
5459            # type: (Tensor, bool) -> float
5460            y = x
5461            if z is False:
5462                y = x + 1
5463
5464            return y.sum()
5465
5466        a = torch.rand(2, 3)
5467
5468        with enable_profiling_mode_for_profiling_tests():
5469            # check that prim::profile nodes are inserted
5470            profiled_graph_str = str(def_in_one_branch.graph_for(a, True))
5471            FileCheck().check_count("prim::profile", 4).run(profiled_graph_str)
5472            # this call is optimized for
5473            # the given shape of (2, 3)
5474            def_in_one_branch(a, False)
5475            # change shape to (3)
5476            # so we go down a bailout path
5477            a = torch.ones(3)
5478            # check that prim::BailOut nodes are inserted
5479            bailout_graph_str = str(def_in_one_branch.graph_for(a, True))
5480            FileCheck().check_count("prim::BailOut", 3).run(bailout_graph_str)
5481            # this triggers all 3 bailouts
5482            self.assertEqual(def_in_one_branch(a, False), 6.0)
5483            # this triggers 2 bailouts
5484            self.assertEqual(def_in_one_branch(a, True), 3.0)
5485
5486    @unittest.skip("bailouts are being deprecated")
5487    def test_maxpool_guard_elimination(self):
5488        @torch.jit.script
5489        def my_maxpool(x):
5490            return F.max_pool1d(x, kernel_size=[1]) + torch.ones([32, 32, 32])
5491
5492        a = torch.rand(32, 32, 32)
5493
5494        with enable_profiling_mode_for_profiling_tests():
5495            my_maxpool(a)
5496            bailout_graph_str = str(my_maxpool.graph_for(a))
5497            FileCheck().check_count("prim::BailOut", 1).run(bailout_graph_str)
5498
5499    @unittest.skip("bailouts are being deprecated")
5500    def test_slice_guard_elimination(self):
5501        @torch.jit.script
5502        def my_slice(x):
5503            return x[0:16:2] + x[0:16:2]
5504
5505        a = torch.rand(32, 4)
5506
5507        with enable_profiling_mode_for_profiling_tests():
5508            my_slice(a)
5509            bailout_graph_str = str(my_slice.graph_for(a))
5510            FileCheck().check_count("prim::BailOut", 1).run(bailout_graph_str)
5511
5512    @unittest.skip("bailouts are being deprecated")
5513    def test_unsqueeze_guard_elimination(self):
5514        @torch.jit.script
5515        def my_unsqueeze(x):
5516            return torch.unsqueeze(x, 0) + torch.unsqueeze(x, 0)
5517
5518        a = torch.rand(32, 4)
5519
5520        with enable_profiling_mode_for_profiling_tests():
5521            my_unsqueeze(a)
5522            bailout_graph_str = str(my_unsqueeze.graph_for(a))
5523            FileCheck().check_count("prim::BailOut", 2).run(bailout_graph_str)
5524
5525    def test_resize_input_ops(self):
5526        # resize_ and resize_as_ resize the input tensor. Because our shape analysis
5527        # is flow-invariant, we set any Tensor that can alias a resized Tensor
5528        # to the base Tensor type, without size information.
5529
5530        # testing that value which is an input of a graph gets handled
5531        def out_op_graph_input():
5532            @torch.jit.script
5533            def test(x, y, z):
5534                torch.mul(x, y, out=z)
5535                return z
5536
5537            graph = _propagate_shapes(test.graph,
5538                                      (torch.zeros(2, 1), torch.zeros(1, 2), torch.zeros(1, 1, 1)), False)
5539            self.assertTrue(next(graph.outputs()).type() == TensorType.get())
5540        out_op_graph_input()
5541
5542        def test_resize():
5543            @torch.jit.script
5544            def test(x):
5545                after_resize_alias = torch.zeros([2])
5546                for _i in range(5):
5547                    b = x + 1
5548                    f = [1]
5549                    before_resize_alias = b.sub_(1)
5550                    # for i in range(10):
5551                    f.append(1)
5552                    b.resize_(f)
5553                    after_resize_alias = b.add_(1)
5554                return after_resize_alias
5555
5556            self.run_pass('constant_propagation', test.graph)
5557            g = _propagate_shapes(test.graph, (torch.zeros(1, 1),), False)
5558            resize_node = g.findNode("aten::resize_")
5559            # the first input and the output of b.resize_ are both b
5560            self.assertTrue(next(resize_node.inputs()).type() == TensorType.get())
5561            self.assertTrue(next(resize_node.outputs()).type() == TensorType.get())
5562
5563            # correctly propagates to b's alias set
5564            before_resize = g.findNode("aten::sub_")
5565            self.assertTrue(next(before_resize.outputs()).type() == TensorType.get())
5566
5567            after_resize = g.findNode("aten::add_")
5568            self.assertTrue(next(after_resize.outputs()).type() == TensorType.get())
5569
5570        test_resize()
5571
5572        def test_resize_as():
5573            @torch.jit.script
5574            def test(x):
5575                b = torch.zeros([2, 2])
5576                b.resize_as_(x)
5577                return b
5578
5579            g = test.graph
5580            self.run_pass('constant_propagation', g)
5581            g = _propagate_shapes(test.graph, (torch.zeros(1, 1),), False)
5582
5583            # x doesn't alias a resized tensor, so it shouldn't be set to the base Tensor type
5584            self.assertTrue(next(g.inputs()).type() != TensorType.get())
5585            # return is resized
5586            self.assertTrue(next(g.outputs()).type() == TensorType.get())
5587
5588        test_resize_as()
5589
5590    def test_uninitialized(self):
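        # Hand-written IR that adds a prim::Uninitialized int to a constant: the function
        # should survive an export/import round trip, but executing it should fail at
        # runtime with "expected int".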
5591        graph_str = """graph():
5592          %1 : int = prim::Uninitialized()
5593          %2 : int = prim::Constant[value=1]()
5594          %3 : int = aten::add(%1, %2)
5595          return (%3)
5596        """
5597        g = parse_ir(graph_str)
5598        m = self.createFunctionFromGraph(g)
5599        self.getExportImportCopy(m)
5600        with self.assertRaisesRegex(RuntimeError, "expected int"):
5601            m()
5602
5603
5604    @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't use requires_grad information")
5605    @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.PROFILING, "Peeling is now disabled")
5606    def test_requires_grad_loop(self):
5607        @torch.jit.script
5608        def test(x, y, z):
5609            # type: (Tensor, Tensor, int) -> Tensor
5610            for _ in range(z):
5611                x = y
5612            return x
5613
5614        # x requires grad, y does not
5615        # testing that requires_grad analysis handles the loop correctly: the value
5616        # entering the loop (x) requires grad, the value produced by the loop body does not,
5617        # and the output of the loop node is conservatively marked as requiring grad
5618
5619        inps = (torch.tensor(1.0, requires_grad=True), torch.tensor(1), 10)
5620        test(*inps, profile_and_replay=True)
5621
5622        graph = test.graph_for(*inps)
5623        loop = graph.findNode("prim::Loop")
5624        loop_body = next(loop.blocks())
5625        loop_inputs = list(loop_body.inputs())
5626        loop_outputs = list(loop_body.outputs())
5627
5628        if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
5629            # TODO: simplify this test as it's very sensitive
5630            # the optimized graph will have 3 loops
5631            # the original loop is peeled
5632            # peeled loop also gets unrolled
5633            index_of_x_in_peeled_unrolled_loop = -2
5634            self.assertTrue(loop_inputs[index_of_x_in_peeled_unrolled_loop].requires_grad())
5635            bailouts_in_outer_block = graph.findAllNodes("prim::BailOut", False)
5636            last_bailout_index_on_loops_output = -1
5637            self.assertFalse(bailouts_in_outer_block[last_bailout_index_on_loops_output].output().requires_grad())
5638        else:
5639            self.assertTrue(loop_inputs[1].requires_grad())
5640            self.assertTrue(loop.output().requires_grad())
5641            self.assertFalse(loop_outputs[1].requires_grad())
5642
5643    def test_view_shape_prop(self):
5644        cu = torch.jit.CompilationUnit('''
5645        def test_view_shape_prop(a):
5646            return a.view(size=[-1])
5647        ''')
5648        inputs = [torch.zeros(10, 10)]
5649        outputs = torch.zeros(100)
5650
5651        real_outs = cu.test_view_shape_prop(*inputs)
5652        self.assertEqual(real_outs, outputs)
5653
5654    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
5655    def test_view_listconstruct_shape_prop(self):
5656        def fn(x):
5657            B = x.size(0)
5658            C = x.size(1)
5659            T = x.size(2)
5660            return x.view(T, B, C)
5661
5662        x = torch.randn(3, 1, 5, requires_grad=True)
5663        fn = torch.jit.script(fn)
5664        graph = _propagate_shapes(fn.graph, (x,), False)
5665        self.assertTrue(next(graph.outputs()).type().scalarType() == 'Float')
5666
5667    def test_shape_prop_promotion(self):
5668        @torch.jit.script
5669        def fn(x, y):
5670            return x + y
5671
5672        x, y = torch.rand(3, 4, dtype=torch.float), torch.rand(3, 4, dtype=torch.double)
5673        graph = _propagate_shapes(fn.graph, (x, y), False)
5674        FileCheck().check('Double(*, *, device=cpu) = aten::add').run(graph)
5675
5676    def test_shape_prop_promote_scalar_arg(self):
5677        @torch.jit.script
5678        def fn(x):
5679            return math.pi + x
5680
5681        x = torch.zeros(3, 4, dtype=torch.long)
5682        graph = _propagate_shapes(fn.graph, (x,), False)
5683        default = torch.get_default_dtype()
5684        if default == torch.float:
5685            FileCheck().check('Float(*, *, requires_grad=0, device=cpu) = aten::add').run(graph)
5686        else:
5687            FileCheck().check('Double(*, *, requires_grad=0, device=cpu) = aten::add').run(graph)
5688
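    # Illustrative sketch (not part of the original tests): the two
    # shape-propagation tests above rely on standard type-promotion rules,
    # which can be checked in eager mode with torch.result_type. Float + Double
    # promotes to Double, and an integral tensor combined with a Python float
    # promotes to the default floating dtype. The name `_sketch_type_promotion`
    # is ours.
    def _sketch_type_promotion(self):
        import math
        assert torch.result_type(torch.rand(3, 4, dtype=torch.float),
                                 torch.rand(3, 4, dtype=torch.double)) == torch.double
        assert torch.result_type(torch.zeros(3, 4, dtype=torch.long),
                                 math.pi) == torch.get_default_dtype()
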
5689    def test_integral_shape_inference(self):
5690        cu = torch.jit.CompilationUnit('''
5691        def test_integral_shape_inference(a):
5692            return a * a
5693        ''')
5694        inputs = [torch.ones(10, 10, dtype=torch.long)]
5695        outputs = torch.ones(10, 10, dtype=torch.long)
5696
5697        self.assertEqual(cu.test_integral_shape_inference(*inputs), outputs)
5698
5699    @unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser')
5700    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle")
5701    @enable_cpu_fuser
5702    def test_batchnorm_fuser_cpu(self):
5703        code = '''
5704            graph(%3 : Tensor,
5705                  %7 : Tensor,
5706                  %12 : Float(*, *),
5707                  %13 : Tensor,
5708                  %25 : Tensor):
5709                %23 : int = prim::Constant[value=1]()
5710                %22 : float = prim::Constant[value=1e-05]()
5711                %26 : Tensor = aten::sqrt(%25)
5712                %24 : Tensor = aten::add(%26, %22, %23)
5713                %20 : Tensor = aten::reciprocal(%24)
5714                %norm_invstd : Tensor = aten::mul(%20, %23)
5715                %15 : Tensor = aten::sub(%12, %13, %23)
5716                %11 : Tensor = aten::mul(%15, %norm_invstd)
5717                %8 : Tensor = aten::mul(%11, %7)
5718                %5 : Tensor = aten::add(%8, %3, %23)
5719                %1 : Float(*, *) = aten::relu(%5)
5720                return (%1)
5721        '''
5722
5723        graph = parse_ir(code)
5724        inputs = 5 * [torch.rand(26, 2048, dtype=torch.float)]
5725        code = torch._C._jit_fuser_get_fused_kernel_code(graph, inputs)
5726        FileCheck().check('sqrtf').run(code)
5727
5728    @slowTest
5729    @unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser')
5730    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle")
5731    @enable_cpu_fuser
5732    def test_fuser_double_float_codegen(self):
5733        fns = ['log', 'log10', 'log1p', 'log2', 'lgamma', 'exp', 'expm1', 'erf',
5734               'erfc', 'cos', 'acos', 'cosh', 'sin', 'asin', 'sinh', 'tan',
5735               'atan', 'tanh', 'sqrt', 'ceil', 'floor', 'round', 'trunc',
5736               'frac']
5737
5738        def lookup_c_equivalent_fn(aten_fn):
5739            return aten_fn
5740
5741        def test_dispatch(op, expects, dtype, binary=False):
5742            if dtype == torch.double:
5743                dtype_str = 'Double'
5744            elif dtype == torch.float:
5745                dtype_str = 'Float'
5746            else:
5747                raise RuntimeError('Unknown dtype')
5748
5749            if binary:
5750                code = f'''
5751                    graph(%3 : Tensor, %4 : Tensor):
5752                        %2 : {dtype_str}(*, *) = aten::{op}(%3, %4)
5753                        %1 : {dtype_str}(*, *) = aten::relu(%2)
5754                        return (%1)
5755                '''
5756            else:
5757                code = f'''
5758                    graph(%3 : Tensor):
5759                        %2 : {dtype_str}(*, *) = aten::{op}(%3)
5760                        %1 : {dtype_str}(*, *) = aten::relu(%2)
5761                        return (%1)
5762                '''
5763
5764            graph = parse_ir(code)
5765            inputs = (2 if binary else 1) * [torch.rand(26, 2048, dtype=dtype)]
5766            code = torch._C._jit_fuser_get_fused_kernel_code(graph, inputs)
5767            FileCheck().check(expects).run(code)
5768
5769        for fn in fns:
5770            test_dispatch(fn, lookup_c_equivalent_fn(fn) + '(', torch.double)
5771            test_dispatch(fn, lookup_c_equivalent_fn(fn) + 'f(', torch.float)
5772
5773        # 'min' and 'max' were previously tested here, but the fuser now emits
5774        # ternary expressions for them instead of fmin() and fmax()
5775        binary_fns = ['pow']
5776        for fn in binary_fns:
5777            test_dispatch(fn, lookup_c_equivalent_fn(fn) + '(', torch.double, binary=True)
5778            test_dispatch(fn, lookup_c_equivalent_fn(fn) + 'f(', torch.float, binary=True)
5779
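    # Illustrative sketch (not part of the original tests): the dispatch checks
    # above use the C convention that a double-precision math function ("log")
    # gains an "f" suffix in its single-precision form ("logf"). A hypothetical
    # helper spelling out the expected kernel-code substrings; the name
    # `_sketch_expected_c_math_name` is ours.
    def _sketch_expected_c_math_name(self, aten_fn, dtype):
        # e.g. ("log", torch.double) -> "log(", ("log", torch.float) -> "logf("
        suffix = '' if dtype == torch.double else 'f'
        return aten_fn + suffix + '('
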
5780    @unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser')
5781    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle")
5782    @enable_cpu_fuser
5783    def test_fuser_double_literal_precision(self):
5784        code = '''
5785        graph(%2 : Float(*, *)):
5786            %4 : int = prim::Constant[value=1]()
5787            %3 : float = prim::Constant[value=1.282549830161864]()
5788            %5 : Float(*, *) = aten::add(%2, %3, %4)
5789            %1 : Float(*, *) = aten::relu(%5)
5790            return (%1)
5791        '''
5792
5793        graph = parse_ir(code)
5794        code = torch._C._jit_fuser_get_fused_kernel_code(graph, [torch.rand(3, 4)])
5795        FileCheck().check('1.282549830161864').run(code)
5796
5797    def test_fuser_multiple_blocks(self):
5798        cu = torch.jit.CompilationUnit('''
5799        def test_fuser_multiple_blocks(this, that, theother, meme):
5800            i = 0
5801            while i < 20:
5802                this = torch.cat([this, meme], dim=0)
5803                that = torch.cat([that, meme], dim=0)
5804                theother = torch.cat([theother, meme], dim=0)
5805                i = i + 1
5806            return this, that, theother
5807        ''')
5808
5809        inputs = [torch.ones(0, 10, 10)] * 3
5810        inputs += [torch.ones(1, 10, 10)]
5811        outputs = [torch.ones(20, 10, 10)] * 3
5812
5813        self.assertEqual(cu.test_fuser_multiple_blocks(*inputs), outputs)
5814
5815    @unittest.skip("RuntimeError: VariableType::ID() not implemented")
5816    def test_cast(self):
5817        script = '''
5818        def to_int(x):
5819            return int(x)
5820        '''
5821        x = Variable(torch.FloatTensor([1.1, 2.3]), requires_grad=True)
5822        out = Variable(torch.IntTensor([1, 2]), requires_grad=True)
5823        self.checkScript(script, [x], optimize=True, outputs=[out], func='to_int')
5824
5825    def test_str_cast(self):
5826        @torch.jit.script
5827        def to_str(x):
5828            # type: (int) -> str
5829            return str((x, x))
5830
5831        self.assertEqual("(1, 1)", to_str(1))
5832
5833    def test_int_cast(self):
5834        @torch.jit.script
5835        def to_int(x):
5836            # type: (str) -> int
5837            return int(x)
5838
5839        self.assertEqual(5, to_int('5'))
5840        self.assertEqual(-5, to_int('-5'))
5841        self.assertEqual(2147483647, to_int('2147483647'))
5842        self.assertEqual(-2147483648, to_int('-2147483648'))
5843
5844        with self.assertRaisesRegex(RuntimeError, "invalid literal for int()"):
5845            to_int('0x20')
5846
5847        with self.assertRaisesRegex(RuntimeError, "invalid literal for int()"):
5848            to_int('0b0001')
5849
5850    def test_python_frontend(self):
5851        def fn(x, y, z):
5852            q = None
5853            q = x + y - z.sigmoid()
5854            print(q)
5855            w = -z
5856            if not x and not y and z:
5857                m = x if not z else y
5858            while x < y > z:
5859                q = x
5860            assert 1 == 1, "hello"
5861            return x
5862
5863        ast = torch.jit.frontend.get_jit_def(fn, fn.__name__)
5864        self.assertExpected(str(ast))
5865
5866    def test_python_frontend_source_range(self):
5867        def fn():
5868            raise Exception("hello")  # noqa: TRY002
5869        ast = torch.jit.frontend.get_jit_def(fn, fn.__name__)
5870        FileCheck().check("SourceRange at:") \
5871                   .check("def fn():") \
5872                   .check("~~~~~~~~~") \
5873                   .check('raise Exception("hello")') \
5874                   .check('~~~~~~~~~~~~~~~~~ <--- HERE') \
5875                   .run(str(ast.range()))
5876
5877    def test_python_frontend_py3(self):
5878        def fn():
5879            raise Exception("hello")  # noqa: TRY002
5880        ast = torch.jit.frontend.get_jit_def(fn, fn.__name__)
5881        self.assertExpected(str(ast))
5882
5883    def _make_scalar_vars(self, arr, dtype):
5884        return [torch.tensor(val, dtype=dtype) for val in arr]
5885
5886
5887    def test_string_print(self):
5888        def func(a):
5889            print(a, "a" 'b' '''c''' """d""", 2, 1.5)
5890            return a
5891
5892        inputs = self._make_scalar_vars([1], torch.int64)
5893        self.checkScript(func, inputs, capture_output=True)
5894
5895    def test_while(self):
5896        def func(a, b, max):
5897            while bool(a < max):
5898                a = a + 1
5899                b = b + 1
5900            c = a + b
5901            return c
5902
5903        inputs = self._make_scalar_vars([1, 1, 10], torch.int64)
5904        self.checkScript(func, inputs, optimize=True)
5905
5906    def test_fibb(self):
5907        def func(lim):
5908            first = 1
5909            second = 1
5910            i = 1
5911            somenum = 5
5912            dontmutateme = 3
5913            third = 0
5914            while bool(i < lim):
5915                third = first + second
5916                first = second
5917                second = third
5918                j = 0
5919                while j < 10:
5920                    somenum = somenum * 2
5921                    j = j + 1
5922                i = i + j
5923                i = i + dontmutateme
5924
5925            st = second + third
5926            fs = first + second
5927            return third, st, fs
5928
5929        inputs = self._make_scalar_vars([10], torch.int64)
5930        self.checkScript(func, inputs, optimize=True)
5931
5932    def test_fibb_totally_better(self):
5933        def fib(x):
5934            # type: (int) -> int
5935            prev = 1
5936            v = 1
5937            for i in range(0, x):
5938                save = v
5939                v = v + prev
5940                prev = save
5941            return v
5942
5943        self.checkScript(fib, (10,))
5944
5945    def test_if(self):
5946        def func(a, b):
5947            # type: (int, int) -> int
5948            d = 3
5949            if bool(a > 10):
5950                a = 3 + d
5951            else:
5952                b = 3 + d
5953                d = 4
5954            c = a + b
5955            return c
5956
5957        inputs = self._make_scalar_vars([1, -1], torch.int64)
5958        self.checkScript(func, inputs, optimize=True)
5959
5960    def test_if_for_in_range(self):
5961        def func(a, b):
5962            # type: (int, int) -> int
5963            d = 3
5964            for _ in range(20):
5965                if bool(a > 10):
5966                    a = 3 + d
5967                else:
5968                    b = 3 + d
5969                    d = 4
5970                c = a + b
5971            return d
5972        inputs = self._make_scalar_vars([1, -1], torch.int64)
5973        self.checkScript(func, inputs, optimize=True)
5974
5975    def test_if_noelse(self):
5976        def func(a, b):
5977            if bool(a > 10):
5978                a = 3 + b
5979            c = a + b
5980            return c
5981
5982        inputs = self._make_scalar_vars([-1, 1], torch.int64)
5983        self.checkScript(func, inputs, optimize=True)
5984
5985    def test_if_is_none_dispatch(self):
5986
5987        @torch.jit.script
5988        def test_lhs_none_rhs_none():
5989            # LHS, RHS both alwaysNone, dispatch always_none_branch
5990            # only emit one prim::Constant
5991            if None is None:
5992                return 1
5993            elif None is not None:
5994                return 2
5995            else:
5996                return 3
5997
5998        self.assertTrue(str(test_lhs_none_rhs_none.graph).count(': int = prim::Constant') == 1)
5999
6000        @torch.jit.script
6001        def test_lhs_opt_rhs_none(lhs=None):
6002            # type: (Optional[Tensor]) -> int
6003            # LHS maybeNone: emit normal if stmt that contains 3 constants
6004            if lhs is not None:
6005                return 2
6006            elif lhs is None:
6007                return 1
6008            else:
6009                return 3
6010
6011        self.assertTrue(str(test_lhs_opt_rhs_none.graph).count(': int = prim::Constant') == 3)
6012
6013        @torch.jit.script
6014        def test_lhs_none_rhs_opt(rhs=None):
6015            # type: (Optional[Tensor]) -> int
6016            # RHS maybeNone, emit normal if stmt that contains 3 constants
6017            if None is rhs:
6018                return 1
6019            elif None is not rhs:
6020                return 2
6021            else:
6022                return 3
6023
6024        self.assertTrue(str(test_lhs_none_rhs_opt.graph).count(': int = prim::Constant') == 3)
6025
6026        @torch.jit.script
6027        def test_lhs_never_rhs_none(lhs):
6028            # LHS neverNone, RHS alwaysNone dispatch never_none_branch
6029            # only emit one prim::Constant
6030            if lhs is None:
6031                return 1
6032            elif lhs is not None:
6033                return 2
6034            else:
6035                return 3
6036
6037        self.assertTrue(str(test_lhs_never_rhs_none.graph).count(': int = prim::Constant') == 1)
6038
6039        @torch.jit.script
6040        def test_lhs_none_rhs_never(rhs):
6041            # LHS alwaysNone, RHS neverNone dispatch never_none_branch
6042            # only emit one prim::Constant
6043            if None is rhs:
6044                return 1
6045            elif None is not rhs:
6046                return 2
6047            else:
6048                return 3
6049
6050        self.assertTrue(str(test_lhs_none_rhs_never.graph).count(': int = prim::Constant') == 1)
6051
6052        @torch.jit.script
6053        def test_bool_arith_and(lhs):
6054            if lhs is None and lhs is not None:
6055                return 1
6056            else:
6057                return 2
6058        self.assertEqual(test_bool_arith_and(torch.zeros(3)), 2)
6059        self.assertTrue(str(test_bool_arith_and.graph).count('if') == 0)
6060
6061        @torch.jit.script
6062        def test_bool_arith_or(lhs):
6063            if lhs is None or lhs is not None:
6064                return 1
6065            else:
6066                return 2
6067        self.assertEqual(test_bool_arith_or(torch.zeros(3)), 1)
6068        self.assertTrue(str(test_bool_arith_or.graph).count('if') == 0)
6069
6070
6071        @torch.jit.script
6072        def test_bool_arith_not(lhs):
6073            if lhs is not None:
6074                return 1
6075            else:
6076                return 2
6077        self.assertEqual(test_bool_arith_not(torch.zeros(3)), 1)
6078        self.assertTrue(str(test_bool_arith_not.graph).count('if') == 0)
6079
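    # Illustrative sketch (not part of the original tests): a minimal,
    # standalone version of the dispatch behaviour exercised above. When both
    # sides of an `is None` comparison are statically known, the untaken branch
    # is dropped, so only the constant from the live path remains in the graph.
    # The name `_sketch_static_none_folding` is ours.
    def _sketch_static_none_folding(self):
        @torch.jit.script
        def always_one():
            if None is None:
                return 1
            else:
                return 2
        assert always_one() == 1
        # Only the `1` survives as an int prim::Constant; the dead branch is gone.
        assert str(always_one.graph).count(': int = prim::Constant') == 1
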
6080    def test_conditional_casting(self):
6081        def test_bool_cast_tensor(x):
6082            if x:
6083                return 1
6084            else:
6085                return 0
6086
6087        for make_one_dim in [True, False]:
6088            for inp_val in [0.1, 0.0, -0.0, -0.1, -1, 0, 1]:
6089                inp_val = [inp_val] if make_one_dim else inp_val
6090                self.checkScript(test_bool_cast_tensor, (torch.tensor(inp_val),))
6091
6092        self.checkScriptRaisesRegex(test_bool_cast_tensor, (torch.tensor([1, 1]),), Exception,
6093                                    "Boolean value of Tensor with more than one value")
6094
6095        def test_not_cast(x):
6096            if not x:
6097                return 1
6098            else:
6099                return 0
6100
6101        self.checkScript(test_not_cast, (torch.tensor(1),))
6102        self.checkScript(test_not_cast, (torch.tensor(0),))
6103
6104        with self.assertRaisesRegex(RuntimeError, r"Could not cast value of type Tuple\[Tensor, Tensor\]"):  # noqa: W605
6105            @torch.jit.script
6106            def test_mult(x, y):
6107                return not (x, y)
6108
6109        def test_cast_int(x):
6110            # type: (int) -> int
6111            if x:
6112                return 1
6113            else:
6114                return 0
6115        self.checkScript(test_cast_int, (1,))
6116        self.checkScript(test_cast_int, (0,))
6117        self.checkScript(test_cast_int, (-1,))
6118
6119        def test_cast_float(x):
6120            # type: (float) -> int
6121            if x:
6122                return 1
6123            else:
6124                return 0
6125        self.checkScript(test_cast_float, (1.,))
6126        self.checkScript(test_cast_float, (0.,))
6127        self.checkScript(test_cast_float, (-1.,))
6128
6129        with self.assertRaisesRegex(RuntimeError, r"Could not cast value of type Tuple\[int, int\] to bool"):  # noqa: W605
6130
6131            @torch.jit.script
6132            def test_bad_conditional(x):
6133                if (1, 2):  # noqa: F634
6134                    return
6135                else:
6136                    return 0
6137
6138    def test_while_nonexistent_value(self):
6139        with self.assertRaisesRegex(RuntimeError, "undefined value x"):
6140            torch.jit.CompilationUnit('''
6141            def test_while(a, b):
6142                while bool(a < 10):
6143                    a = a + x
6144                    b = b + 1
6145                return a + b
6146            ''')
6147
6148    def test_while_nonexistent_cond_value(self):
6149        with self.assertRaisesRegex(RuntimeError, "undefined value x"):
6150            torch.jit.CompilationUnit('''
6151            def test_while(a, b):
6152                while a < x:
6153                    a = a + 1
6154                    b = b + 1
6155                return a + b
6156            ''')
6157
6158        @torch.jit.script
6159        def test_ternary(x):
6160            # type: (Optional[int]) -> int
6161            x = x if x is not None else 2
6162            return x
6163
6164        @torch.jit.script
6165        def test_not_none(x):
6166            # type: (Optional[int]) -> None
6167            if x is not None:
6168                print(x + 1)
6169
6170        @torch.jit.script
6171        def test_and(x, y):
6172            # type: (Optional[int], Optional[int]) -> None
6173            if x is not None and y is not None:
6174                print(x + y)
6175
6176        @torch.jit.script
6177        def test_not(x, y):
6178            # type: (Optional[int], Optional[int]) -> None
6179            if not (x is not None and y is not None):
6180                pass
6181            else:
6182                print(x + y)
6183
6184        @torch.jit.script
6185        def test_bool_expression(x):
6186            # type: (Optional[int]) -> None
6187            if x is not None and x < 2:
6188                print(x + 1)
6189
6190        @torch.jit.script
6191        def test_nested_bool_expression(x, y):
6192            # type: (Optional[int], Optional[int]) -> int
6193            if x is not None and x < 2 and y is not None:
6194                x = x + y
6195            else:
6196                x = 5
6197            return x + 2
6198
6199        @torch.jit.script
6200        def test_or(x, y):
6201            # type: (Optional[int], Optional[int]) -> None
6202            if y is None or x is None:
6203                pass
6204            else:
6205                print(x + y)
6206
6207        # backwards compatibility
6208        @torch.jit.script
6209        def test_manual_unwrap_opt(x):
6210            # type: (Optional[int]) -> int
6211            if x is None:
6212                x = 1
6213            else:
6214                x = torch.jit._unwrap_optional(x)
6215            return x  # noqa: T484
6216
6217        with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
6218            @torch.jit.script
6219            def or_error(x, y):
6220                # type: (Optional[int], Optional[int]) -> None
6221                if x is None or y is None:
6222                    print(x + y)  # noqa: T484
6223
6224        with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
6225            @torch.jit.script
6226            def and_error(x, y):
6227                # type: (Optional[int], Optional[int]) -> None
6228                if x is None and y is None:
6229                    pass
6230                else:
6231                    print(x + y)  # noqa: T484
6232
6233        with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
6234            @torch.jit.script
6235            def named_var(x):
6236                # type: (Optional[int]) -> None
6237                x_none = x is not None
6238                if x_none:
6239                    print(x + 1)  # noqa: T484
6240
6241        with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
6242            @torch.jit.script
6243            def named_var_and(x, y):
6244                # type: (Optional[int], Optional[int]) -> None
6245                x_none = x is not None
6246                if y is not None and x_none:
6247                    print(x + y)  # noqa: T484
6248
6249    def test_assertion_optional_refinement(self):
6250        @torch.jit.script
6251        def test(x, y):
6252            # type: (Optional[int], Optional[int]) -> int
6253            assert x is not None and y is not None
6254            return x + y
6255
6256        self.assertEqual(test(2, 2), 4)
6257        with self.assertRaisesRegex(Exception, ""):
6258            test(1, None)
6259
6260    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "the current version of Profiler doesn't profile/specialize Optionals")
6261    def test_optional_tensor(self):
6262        @torch.jit.script
6263        def fn(x, y):
6264            # type: (Optional[Tensor], int) -> int
6265            if x is None:
6266                return y
6267            else:
6268                return 0
6269
6270        res = fn(None, 1)
6271        self.assertEqual(res, 1)
6272        g = torch.jit.last_executed_optimized_graph()
6273        first_input = next(g.inputs())
6274        # check if input is disconnected
6275        self.assertEqual(first_input.type().kind(), 'OptionalType')
6276        self.assertEqual(first_input.uses(), [])
6277        t = torch.ones(1)
6278        res = fn(t, 1)
6279        self.assertEqual(res, 0)
6280        g = torch.jit.last_executed_optimized_graph()
6281        self.assertEqual(next(g.inputs()).type().kind(), 'TensorType')
6282
6283        @torch.jit.script
6284        def fn(x, y, b):
6285            # type: (Optional[Tensor], Tensor, bool) -> Tensor
6286            if b:
6287                res = y
6288            else:
6289                res = torch.jit._unwrap_optional(x)
6290            return res
6291
6292        t2 = torch.zeros(1)
6293        res = fn(t, t2, True)
6294        self.assertEqual(res, t2)
6295        with self.assertRaisesRegex(RuntimeError, "Unwrapping null optional"):
6296            res = fn(None, t2, False)
6297        res = fn(None, t2, True)
6298        g = torch.jit.last_executed_optimized_graph()
6299        self.assertIn(next(g.outputs()).type().str(), ("Tensor", "Tensor(requires_grad=1)"))
6300
6301    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "the current version of Profiler doesn't profile/specialize Optionals")
6302    def test_optional_list(self):
6303        @torch.jit.script
6304        def fn(x, y):
6305            # type: (Optional[List[int]], int) -> int
6306            if x is None:
6307                return y
6308            else:
6309                res = 0
6310                for d in x:
6311                    res += d
6312                return res
6313
6314        res = fn(None, 1)
6315        self.assertEqual(res, 1)
6316        g = torch.jit.last_executed_optimized_graph()
6317        first_input = next(g.inputs())
6318        # check if input is disconnected
6319        self.assertEqual(first_input.type().kind(), 'OptionalType')
6320        self.assertEqual(first_input.uses(), [])
6321        l = [2, 3]
6322        res = fn(l, 1)
6323        self.assertEqual(res, 5)
6324        g = torch.jit.last_executed_optimized_graph()
6325        self.assertEqual(next(g.inputs()).type().kind(), 'ListType')
6326
6327        @torch.jit.script
6328        def fn(x, y, b):
6329            # type: (Optional[List[int]], List[int], bool) -> List[int]
6330            if b:
6331                l = torch.jit._unwrap_optional(x)
6332            else:
6333                l = y
6334            return l
6335
6336        l2 = [0, 1]
6337        res = fn(l, l2, True)
6338        self.assertEqual(res, l)
6339        with self.assertRaisesRegex(RuntimeError, "Unwrapping null optional"):
6340            res = fn(None, l2, True)
6341        res = fn(None, l2, False)
6342        g = torch.jit.last_executed_optimized_graph()
6343        self.assertEqual(next(g.outputs()).type().str(), "int[]")
6344
6345    def test_alias_covariant_type_containers(self):
6346        @torch.jit.script
6347        def foo(x):
6348            # type: (bool)
6349            if x:
6350                a = (None,)
6351            else:
6352                a = ([],)
6353            return a
6354
6355        @torch.jit.script
6356        def foo2(x, li):
6357            # type: (bool, Tuple[Optional[List[Tensor]]])
6358            if x:
6359                li = (None,)
6360            return li
6361
6362    def test_while_write_outer_then_read(self):
6363        def func(a, b):
6364            while bool(a < 10):
6365                a = a + 1
6366                b = a + 1
6367            return a + b
6368
6369        inputs = self._make_scalar_vars([42, 1337], torch.int64)
6370        self.checkScript(func, inputs, optimize=True)
6371
6372    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
6373    def test_while_nest_if(self):
6374        def func(a, b):
6375            # type: (int, int) -> int
6376            c = 0
6377            while a < 10:
6378                a = a + 1
6379                b = b + 1
6380                if a > b:
6381                    c = -a
6382                else:
6383                    c = -b
6384            return c + 1
6385
6386        inputs = self._make_scalar_vars([-1234, 4321], torch.int64)
6387        self.checkScript(func, inputs, optimize=True)
6388
6389    def test_divmod(self):
6390        def func_int(a, b):
6391            # type: (int, int) -> Tuple[int, int]
6392            return divmod(a, b)
6393
6394        def func_float(a, b):
6395            # type: (float, float) -> Tuple[float, float]
6396            return divmod(a, b)
6397
6398        def func_int_float(a, b):
6399            # type: (int, float) -> Tuple[float, float]
6400            return divmod(a, b)
6401
6402        def func_float_int(a, b):
6403            # type: (float, int) -> Tuple[float, float]
6404            return divmod(a, b)
6405
6406        def divmod_test_iterator(func, num, den):
6407            for i in num:
6408                for j in den:
6409                    self.checkScript(func, (i, j), frames_up=2)
6410
6411        num_int = [1024, -1024]
6412        den_int = [10, -10]
6413        num_float = [5.3, -5.3]
6414        den_float = [2.0, -2.0]
6415        divmod_test_iterator(func_int, num_int, den_int)
6416        divmod_test_iterator(func_float, num_float, den_float)
6417        divmod_test_iterator(func_int_float, num_int, den_float)
6418        divmod_test_iterator(func_float_int, num_float, den_int)
6419
6420        with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError: integer division or modulo by zero"):
6421            cu = torch.jit.CompilationUnit(dedent(inspect.getsource(func_int)))
6422            cu.func_int(1024, 0)
6423        with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError: float divmod()"):
6424            cu = torch.jit.CompilationUnit(dedent(inspect.getsource(func_float)))
6425            cu.func_float(5.3, 0.0)
6426        with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError: float divmod()"):
6427            cu = torch.jit.CompilationUnit(dedent(inspect.getsource(func_int_float)))
6428            cu.func_int_float(1024, 0.0)
6429        with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError: float divmod()"):
6430            cu = torch.jit.CompilationUnit(dedent(inspect.getsource(func_float_int)))
6431            cu.func_float_int(5.3, 0)
6432
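    # Illustrative sketch (not part of the original tests): the mixed-sign
    # divmod cases above follow Python semantics, where the quotient is floored
    # and the remainder takes the sign of the divisor. The name
    # `_sketch_divmod_signs` is ours.
    def _sketch_divmod_signs(self):
        import math
        assert divmod(-1024, 10) == (-103, 6)
        assert divmod(1024, -10) == (-103, -6)
        q, r = divmod(5.3, -2.0)
        assert q == -3.0 and math.isclose(r, -0.7)
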
6433    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
6434    def test_math_ops(self):
6435        def checkMathWrap(func_name, num_args=1, is_float=True, **args):
6436            if is_float:
6437                checkMath(func_name, num_args, True, **args)
6438                checkMath(func_name, num_args, False, **args)
6439            else:
6440                checkMath(func_name, num_args, is_float, **args)
6441
6442        inf = float("inf")
6443        NaN = float("nan")
6444        mx_int = 2**31 - 1
6445        mn_int = -2**31
6446        float_vals = ([inf, NaN, 0.0, 1.0, 2.2, -1.0, -0.0, -2.2, -inf, 1, 0, 2] +
6447                      [10.0 ** i for i in range(5)] + [-(10.0 ** i) for i in range(5)])
6448        int_vals = list(range(-5, 5, 1)) + [mx_int + 5, mx_int * 2, mn_int - 5, mn_int * 2]
6449
6450        def checkMath(func_name, num_args, is_float=True, ret_type="float", debug=False, vals=None, args_type=None):
6451            funcs_template = dedent('''
6452            def func(a, b):
6453                # type: {args_type} -> {ret_type}
6454                return math.{func}({args})
6455            ''')
6456            if num_args == 1:
6457                args = "a"
6458            elif num_args == 2:
6459                args = "a, b"
6460            else:
6461                raise RuntimeError("Test doesn't support more than 2 arguments")
6462            if args_type is None:
6463                args_type = "(float, float)" if is_float else "(int, int)"
6464            funcs_str = funcs_template.format(func=func_name, args=args, args_type=args_type, ret_type=ret_type)
6465            scope = {}
6466            execWrapper(funcs_str, globals(), scope)
6467            cu = torch.jit.CompilationUnit(funcs_str)
6468            f_script = cu.func
6469            f = scope['func']
6470
6471            if vals is None:
6472                vals = float_vals if is_float else int_vals
6473                vals = [(i, j) for i in vals for j in vals]
6474
6475            for a, b in vals:
6476                res_python = None
6477                res_script = None
6478                try:
6479                    res_python = f(a, b)
6480                except Exception as e:
6481                    res_python = e
6482                try:
6483                    res_script = f_script(a, b)
6484                except Exception as e:
6485                    res_script = e
6486                if debug:
6487                    print("in: ", a, b)
6488                    print("out: ", res_python, res_script)
6489                # We can't use assertEqual because of a couple of differences:
6490                # 1. nan == nan should count as equal here
6491                # 2. when the Python function throws an exception, we usually want to ignore it
6492                #    (e.g. we want the scripted math.sqrt(-5) to return nan where Python raises)
6493                if res_python != res_script:
6494                    if isinstance(res_python, Exception):
6495                        continue
6496
6497                    if type(res_python) == type(res_script):
6498                        if isinstance(res_python, tuple) and (math.isnan(res_python[0]) == math.isnan(res_script[0])):
6499                            continue
6500                        if isinstance(res_python, float) and math.isnan(res_python) and math.isnan(res_script):
6501                            continue
6502                    msg = (f"Failed on {func_name} with inputs {a} {b}. Python: {res_python}, Script: {res_script}")
6503                    # math.pow() behavior has changed in 3.11, see https://docs.python.org/3/library/math.html#math.pow
6504                    if sys.version_info >= (3, 11) and func_name == "pow" and a == 0.0 and b == -math.inf:
6505                        self.assertTrue(res_python == math.inf and type(res_script) is RuntimeError)
6506                    else:
6507                        self.assertEqual(res_python, res_script, msg=msg, atol=(1e-4) * max(abs(res_python), res_script), rtol=0)
6508
6509        unary_float_ops = ["log", "log1p", "log10", "exp", "sqrt", "gamma", "lgamma", "erf",
6510                           "erfc", "expm1", "fabs", "acos", "asin", "atan", "cos", "sin", "tan",
6511                           "asinh", "atanh", "acosh", "sinh", "cosh", "tanh", "degrees", "radians"]
6512        binary_float_ops = ["atan2", "fmod", "copysign"]
6513        for op in unary_float_ops:
6514            checkMathWrap(op, 1)
6515        for op in binary_float_ops:
6516            checkMathWrap(op, 2)
6517
6518        checkMath("modf", 1, ret_type="Tuple[float, float]")
6519        checkMath("frexp", 1, ret_type="Tuple[float, int]")
6520        checkMath("isnan", 1, ret_type="bool")
6521        checkMath("isinf", 1, ret_type="bool")
6522        checkMath("ldexp", 2, is_float=False, ret_type="float", args_type="(float, int)",
6523                  vals=[(i, j) for i in float_vals for j in range(-10, 10)])
6524        checkMath("pow", 2, is_float=False, ret_type="float")
6525        checkMath("pow", 2, is_float=True, ret_type="float")
6526        checkMathWrap("floor", ret_type="int")
6527        checkMathWrap("ceil", ret_type="int")
6528        checkMathWrap("gcd", 2, is_float=False, ret_type="int")
6529        checkMath("isfinite", 1, ret_type="bool")
6530        checkMathWrap("remainder", 2)
6531        checkMathWrap("factorial", 1, is_float=False, ret_type="int", vals=[(i, 0) for i in range(-2, 10)])
6532
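    # Illustrative sketch (not part of the original tests): the comparison in
    # checkMath above boils down to "equal, or both nan, or Python raised where
    # TorchScript returned a value". A standalone predicate for the nan case;
    # the name `_sketch_nan_tolerant_equal` is ours.
    def _sketch_nan_tolerant_equal(self, a, b):
        import math
        both_nan = (isinstance(a, float) and isinstance(b, float)
                    and math.isnan(a) and math.isnan(b))
        return a == b or both_nan
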
6533    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
6534    def test_if_nest_while(self):
6535        def func(a, b):
6536            # type: (int, int) -> int
6537            c = 0
6538            if a > b:
6539                while a > b:
6540                    b = b + 1
6541                    c = -b
6542            return c
6543
6544        inputs = self._make_scalar_vars([4321, 1234], torch.int64)
6545        self.checkScript(func, inputs)
6546
6547    def test_script_optional_none(self):
6548        def none_stmt(x):
6549            output = None
6550            output = x
6551            return output
6552
6553        def none_args(x):
6554            # type: (Optional[Tensor]) -> Optional[Tensor]
6555            return None
6556
6557        self.checkScript(none_stmt, [torch.arange(0, 2)], optimize=True)
6558        self.checkScript(none_args, [None], optimize=True)
6559
6560        # test undefined tensor None as default param
6561        def test_script_optional_tensor_none(x=None):
6562            # type: (Optional[Tensor]) -> Tensor
6563            res = torch.zeros(1, dtype=torch.int8)
6564            if x is None:
6565                res = res + 1
6566            else:
6567                res = x
6568            return res
6569
6570        fn = test_script_optional_tensor_none
6571        scripted_fn = torch.jit.script(fn)
6572        self.assertEqual(fn(), scripted_fn())
6573        self.assertEqual(fn(torch.zeros(1)), scripted_fn(torch.zeros(1)))
6574
6575        # test typical None as default param
6576        def test_script_optional_other_none(x=None):
6577            # type: (Optional[float]) -> float
6578            res = 2.0
6579            if x is None:
6580                res = res + 1.0
6581            else:
6582                res = x
6583            return res
6584
6585        fn = test_script_optional_other_none
6586        scripted_fn = torch.jit.script(fn)
6587        self.assertEqual(fn(), scripted_fn())
6588        self.assertEqual(fn(1.0), scripted_fn(1.0))
6589
6590    def test_script_clamp_none(self):
6591        def test_script_clamp_max_none(x):
6592            return torch.clamp(x, min=2, max=None)
6593
6594        def test_script_clamp_max(x):
6595            return torch.clamp(x, max=2)
6596
6597        def test_script_clamp_min_none(x):
6598            return torch.clamp(x, min=None, max=2)
6599
6600        def test_script_clamp_min(x):
6601            return torch.clamp(x, min=2)
6602
6603        input = [torch.arange(0, 3)]
6604        self.checkScript(test_script_clamp_max_none, input, optimize=True)
6605        self.checkScript(test_script_clamp_max, input, optimize=True)
6606        self.checkScript(test_script_clamp_min_none, input, optimize=True)
6607        self.checkScript(test_script_clamp_min, input, optimize=True)
6608
6609    def test_script_bool_constant(self):
6610        def test_script_bool_constant():
6611            a = True
6612            return a
6613        self.checkScript(test_script_bool_constant, [])
6614
6615    def test_ternary(self):
6616        def func(a, b):
6617            c = 3
6618            c = a + b if bool(a > 3) else b
6619            return c
6620
6621        inputs_true = self._make_scalar_vars([5, 2], torch.int64)
6622        inputs_false = self._make_scalar_vars([1, 0], torch.int64)
6623        self.checkScript(func, inputs_true, optimize=True)
6624        self.checkScript(func, inputs_false, optimize=True)
6625
6626    def test_ternary_module_type_hint(self):
6627        class M1(torch.nn.Module):
6628            def forward(self) -> Any:
6629                return 'out' if self.training else {}
6630
6631        class M2(torch.nn.Module):
6632            def forward(self) -> Any:
6633                out: Any = 'out' if self.training else {}
6634                return out
6635
6636        class M3(torch.nn.Module):
6637            def forward(self) -> Optional[int]:
6638                return None if self.training else 1
6639
6640        for module in [M1, M2, M3]:
6641            self.checkModule(module().train(), ())
6642            self.checkModule(module().eval(), ())
6643
6644    def test_ternary_static_if(self):
6645        # Test for True branch when condition variable
6646        # is annotated as Final
6647        class M1(torch.nn.Module):
6648            flag: torch.jit.Final[bool]
6649
6650            def __init__(self) -> None:
6651                super().__init__()
6652                self.flag = True
6653
6654            def forward(self) -> torch.Tensor:
6655                return torch.ones(3) if self.flag else {}
6656
6657        # Test for False branch when condition variable
6658        # is annotated as Final
6659        class M2(torch.nn.Module):
6660            flag: torch.jit.Final[bool]
6661
6662            def __init__(self) -> None:
6663                super().__init__()
6664                self.flag = False
6665
6666            def forward(self) -> torch.Tensor:
6667                return {} if self.flag else torch.ones(3)
6668
6669        model1 = M1()
6670        model2 = M2()
6671        script_model_1 = torch.jit.script(model1)
6672        script_model_2 = torch.jit.script(model2)
6673        self.assertEqual(model1.forward(), script_model_1.forward())
6674        self.assertEqual(model2.forward(), script_model_2.forward())
6675
6676    def test_ternary_right_associative(self):
6677        def plus_123(x: int):
6678            return x + 1 if x == 1 else x + 2 if x == 2 else x + 3
6679        self.checkScript(plus_123, (1,))
6680        self.checkScript(plus_123, (2,))
6681        self.checkScript(plus_123, (3,))
6682
6683    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
6684    def test_print(self):
6685        def func(x, y):
6686            q = (x + y).sigmoid()
6687            print(q, 1, 2, [1, 2], [1.0, 2.0])
6688            w = -q
6689            return w * w
6690
6691        x = torch.arange(4., requires_grad=True)
6692        y = torch.arange(0., 8, 2, requires_grad=True)
6693        self.checkScript(func, [x, y], optimize=True, capture_output=True)
6694
6695    def test_format(self):
6696        def func(x):
6697            print("{}, I'm a {}".format("Hello", "test"))
6698            print("format blank".format())
6699            print("stuff before {}".format("hi"))
6700            print("{} stuff after".format("hi"))
6701            return x + 1
6702
6703        x = torch.arange(4., requires_grad=True)
6704        self.checkScript(func, [x], optimize=True, capture_output=True)
6705
6706    def test_logical_short_circuit(self):
6707        @torch.jit.script
6708        def testNoThrows(t):
6709            c1 = 1
6710            if (False and bool(t[1])) or (True or bool(t[1])):
6711                c1 = 0
6712            return c1
6713
6714        FileCheck().check_not("prim::If").run(testNoThrows.graph)
6715        self.assertEqual(0, testNoThrows(torch.randn(0)))
6716        self.assertEqual(0, testNoThrows(torch.randn([2, 3])))
6717
6718        @torch.jit.script
6719        def throwsOr(t):
6720            c0 = False or bool(t[1])
6721            print(c0)
6722
6723        @torch.jit.script
6724        def throwsAnd(t):
6725            c0 = True and bool(t[1])
6726            print(c0)
6727
6728        t = torch.randn(0)
6729        with self.assertRaisesRegex(RuntimeError, "index 1 out of range for tensor of size"):
6730            throwsOr(t)
6731        with self.assertRaisesRegex(RuntimeError, "index 1 out of range for tensor of size"):
6732            throwsAnd(t)
6733
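    # Illustrative sketch (not part of the original tests): the condition in
    # testNoThrows never evaluates bool(t[1]) because `and` and `or`
    # short-circuit, exactly as in eager Python, which is why no prim::If is
    # needed. The name `_sketch_short_circuit_eager` is ours.
    def _sketch_short_circuit_eager(self):
        t = torch.randn(0)
        # t[1] would raise IndexError, but the left operands decide the result first.
        assert (False and bool(t[1])) or (True or bool(t[1]))
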
6734    def test_type_cast(self):
6735        template = dedent('''
6736        def func(v):
6737            # type: ({from_type}) -> {to_type}
6738            return {to_type}(v)
6739        ''')
6740
6741        def check_cast(from_type, to_type, value, raises=False):
6742            code = template.format(from_type=from_type, to_type=to_type)
6743            self.checkScript(code, (value,))
6744
6745        check_cast('int', 'float', 1)
6746        check_cast('int', 'bool', 1)
6747        check_cast('int', 'bool', 0)
6748
6749        check_cast('float', 'int', 1.)
6750        check_cast('float', 'bool', 1.)
6751        check_cast('float', 'bool', 0.)
6752
6753        check_cast('bool', 'int', True)
6754        check_cast('bool', 'float', True)
6755
6756    def test_multiple_assignment(self):
6757        def outer_func(x):
6758            return x * 2, x + 2
6759
6760        @torch.jit.script
6761        def func(x):
6762            y, z = outer_func(x)
6763            return y + z
6764
6765        x = torch.arange(4)
6766        self.assertEqual(func(x), x * 2 + x + 2)
6767
6768    def test_literals(self):
6769        def func(a):
6770            return a.view(size=[1, 2, 3])
6771
6772        a = torch.randn(6)
6773        self.checkScript(func, [a], optimize=True)
6774
6775    def test_return(self):
6776        def no_return(a):
6777            a + 1
6778
6779        def void_return(a):
6780            return
6781
6782        def one_return(a):
6783            return a + 1.
6784
6785        def multiple_returns(a):
6786            return a * 1., a * 2., a * 3.
6787
6788        a = torch.randn(1, dtype=torch.float)
6789        self.checkScript(no_return, [a], optimize=True)
6790        self.checkScript(void_return, [a], optimize=True)
6791        self.checkScript(one_return, [a], optimize=True)
6792        self.checkScript(multiple_returns, [a], optimize=True)
6793
6794        with self.assertRaisesRegex(RuntimeError, "does not return along all paths"):
6795            torch.jit.CompilationUnit('''
6796            def no_return_bad_annotation(a):
6797                # type: (Tensor) -> Tensor
6798                a + 1
6799            ''')
6800
6801    def test_error(self):
6802        @torch.jit.script
6803        def foo(a):
6804            return a.t()
6805        s = Variable(torch.rand(5, 5, 5))
6806        # XXX: this should stay quiet in shape propagation and only fail in the interpreter
6807        with self.assertRaisesRegex(RuntimeError, "failed in the TorchScript interpreter"):
6808            foo(s)
6809
6810        @torch.jit.script
6811        def bar(c, b):
6812            return c + b
6813
6814        with self.assertRaisesRegex(RuntimeError, "failed in the TorchScript interpreter"):
6815            bar(Variable(torch.rand(10), requires_grad=True), Variable(torch.rand(9), requires_grad=True))
6816
6817    def test_error_stacktrace(self):
6818        @torch.jit.script
6819        def baz(c, b):
6820            return c + b
6821
6822        @torch.jit.script
6823        def foo(c, b):
6824            return baz(c, b)
6825
6826        @torch.jit.script
6827        def bar(c, b):
6828            return foo(c, b)
6829
6830        with self.assertRaises(RuntimeError) as cm:
6831            bar(torch.rand(10), torch.rand(9))
6832        FileCheck().check("The following operation failed in the TorchScript interpreter") \
6833                   .check("Traceback") \
6834                   .check("in foo").check("in baz").run(str(cm.exception))
6835
6836    def test_error_stacktrace_interface(self):
6837        @torch.jit.script
6838        def baz(c, b):
6839            return c + b
6840
6841        @torch.jit.script
6842        def foo(c, b):
6843            return baz(c, b)
6844
6845        @torch.jit.script
6846        def bar(c, b):
6847            return foo(c, b)
6848
6849        @torch.jit.script
6850        class Bar:
6851            def one(self, x, y):
6852                return bar(x, y)
6853
6854        @torch.jit.interface
6855        class IFace:
6856            def one(self, x, y):
6857                # type: (Tensor, Tensor) -> Tensor
6858                pass
6859
6860        make_global(IFace)
6861
6862        @torch.jit.script
6863        def as_interface(x):
6864            # type: (IFace) -> IFace
6865            return x
6866
6867        f = as_interface(Bar())
6868
6869        with self.assertRaises(RuntimeError) as cm:
6870            x = f.one(torch.rand(10), torch.rand(9))
6871            bar(torch.rand(10), torch.rand(9))
6872        FileCheck().check("The following operation failed in the TorchScript interpreter") \
6873                   .check("Traceback") \
6874                   .check("in foo").check("in baz").run(str(cm.exception))
6875
6876    def test_operator_precedence(self):
6877        def double(x):
6878            # type: (int) -> int
6879            return 2 * x
6880
6881        def complicated_arithmetic_operation():
6882            # TODO we need to test exponent operator '**' and bitwise not
6883            # operator '~' once they are properly supported.
6884            list = [0, 1, 2, 3]
6885            result = list[1:3][0] + double(4) + (-3 + 8) * 6 // 2 % 4 << 2 + 1 >> 1 | 23 & 16 + 3 ^ 4
6886            return result
6887
6888        self.checkScript(complicated_arithmetic_operation, ())
6889
6890    def test_in_operator_with_two_strings(self):
6891        def fn() -> bool:
6892            return "a" in "abcd"
6893        self.checkScript(fn, ())
6894
6895    def test_bitwise_ops(self):
6896
6897        def int_test():
6898            return 2 & 3, 2 ^ 3, 2 | 3, 2 << 3, 2 >> 3
6899
6900        self.checkScript(int_test, ())
6901
6902        def bool_test(x, y):
6903            # type: (bool, bool) -> Tuple[bool, bool, bool]
6904            return x & y, x ^ y, x | y
6905
6906        self.checkScript(bool_test, (True, False))
6907        self.checkScript(bool_test, (True, True))
6908
6909        def tensor_test(x, y):
6910            return x & y, x ^ y, x | y
6911
6912        def tensor_with_int_test(x, y):
6913            # type: (Tensor, int) -> Tuple[Tensor, Tensor]
6914            return x << y, x >> y
6915
6916        x = torch.tensor(2)
6917        y = torch.tensor(3)
6918
6919        self.checkScript(tensor_test, (x, y))
6920        self.checkScript(tensor_with_int_test, (x, 2))
6921
6922        def not_test(x):
6923            return ~x
6924
6925        self.checkScript(not_test, (torch.tensor([2, 4]), ))
6926
6927    def test_all(self):
6928        @torch.jit.script
6929        def test_all_tensor(x):
6930            return all(x)
6931        self.assertFalse(test_all_tensor(torch.tensor([1, 0, 3], dtype=torch.uint8)))
6932        self.assertTrue(test_all_tensor(torch.tensor([3.14, 3, 99], dtype=torch.uint8)))
6933        self.assertTrue(test_all_tensor(torch.tensor([True, True], dtype=torch.uint8)))
6934        self.assertFalse(test_all_tensor(torch.tensor([True, False], dtype=torch.uint8)))
6935
6936        @torch.jit.script
6937        def test_all_bool_list(x):
6938            # type: (List[bool]) -> bool
6939            return all(x)
6940        self.assertTrue(test_all_bool_list([True, True]))
6941        self.assertTrue(test_all_bool_list([True, 1]))
6942        self.assertFalse(test_all_bool_list([True, False]))
6943        self.assertFalse(test_all_bool_list([True, 0]))
6944        self.assertFalse(test_all_bool_list([False, 0]))
6945        self.assertTrue(test_all_bool_list([]))
6946
6947        @torch.jit.script
6948        def test_all_int_list(x):
6949            # type: (List[int]) -> bool
6950            return all(x)
6951        self.assertTrue(test_all_int_list([3, 6]))
6952        self.assertFalse(test_all_int_list([2, 0]))
6953
6954        @torch.jit.script
6955        def test_all_float_list(x):
6956            # type: (List[float]) -> bool
6957            return all(x)
6958        self.assertTrue(test_all_float_list([3.14, 8.1]))
6959        self.assertFalse(test_all_float_list([3.14, 0, 8.9]))
6960
6961
6962    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
6963    def test_number_math(self):
6964        ops_template = dedent('''
6965        def func():
6966            return {scalar1} {op} {scalar2}
6967        ''')
6968        ops = ['+', '-', '*', '%', '<', '<=', '>', '>=', '==', '!=', '//']
6969        funcs_template = dedent('''
6970        def func():
6971            return {func}({scalar1}, {scalar2})
6972        ''')
6973        funcs = ['min', 'max']
6974        scalars = ['7', '2', '3', '-3', '3.14', '0.125', '-0.5', '2.0', '-2.0']
6975        scalar_pairs = [(scalar1, scalar2) for scalar1 in scalars for scalar2 in scalars]
6976
6977        def run_test(code):
6978            scope = {}
6979            execWrapper(code, globals(), scope)
6980            cu = torch.jit.CompilationUnit(code)
6981
6982            self.assertEqual(cu.func(), scope['func']())
6983
6984        for scalar1, scalar2 in scalar_pairs:
6985            for op in ops:
6986                code = ops_template.format(op=op, scalar1=scalar1, scalar2=scalar2)
6987                run_test(code)
6988            for func in funcs:
6989                code = funcs_template.format(func=func, scalar1=scalar1, scalar2=scalar2)
6990                run_test(code)
6991
6992        # test Scalar overloads
6993        for scalar1, scalar2 in scalar_pairs:
6994            item1 = 'torch.tensor(' + scalar1 + ').item()'
6995            item2 = 'torch.tensor(' + scalar2 + ').item()'
6996            for op in ops:
6997                code = ops_template.format(op=op, scalar1=item1, scalar2=scalar2)
6998                run_test(code)
6999                code = ops_template.format(op=op, scalar1=scalar1, scalar2=item2)
7000                run_test(code)
7001                code = ops_template.format(op=op, scalar1=item1, scalar2=item2)
7002                run_test(code)
7003            for func in funcs:
7004                code = funcs_template.format(func=func, scalar1=item1, scalar2=scalar2)
7005                run_test(code)
7006                code = funcs_template.format(func=func, scalar1=scalar1, scalar2=item2)
7007                run_test(code)
7008                code = funcs_template.format(func=func, scalar1=item1, scalar2=item2)
7009                run_test(code)
7010
7011    def test_number_abs(self):
7012        def func1(x):
7013            # type: (float) -> float
7014            return abs(x)
7015
7016        def func2(x):
7017            # type: (int) -> int
7018            return abs(x)
7019
7020        def func3(x):
7021            return abs(x)
7022
7023        self.checkScript(func1, (-3.14,))
7024        self.checkScript(func1, (3.14,))
7025        self.checkScript(func2, (-10,))
7026        self.checkScript(func2, (10,))
7027        self.checkScript(func3, (torch.tensor([-5, -10, -20]),))
7028        self.checkScript(func3, (torch.tensor([5, 10, 20]),))
7029        self.checkScript(func3, (torch.tensor([-5, 10, -20]),))
7030
7031    def test_number_div(self):
7032        self.assertEqual(div_int_future(), torch.jit.script(div_int_future)())
7033        self.checkScript(div_float_future, ())
7034
7035        self.checkScript(div_int_nofuture, ())
7036        self.checkScript(div_float_nofuture, ())
7037
7038    # Testing shorthand (augmented) assignment: |=, &=, ^=, <<=, >>=, **=, +=
7039    def test_bool_augassign_bitwise_or(self):
7040        def func(a: bool, b: bool) -> bool:
7041            a |= b
7042            return a
7043
7044        self.checkScript(func, (True, False), optimize=True)
7045        self.checkScript(func, (True, True), optimize=True)
7046        self.checkScript(func, (False, False), optimize=True)
7047        self.checkScript(func, (False, True), optimize=True)
7048
7049    def test_bool_augassign_bitwise_and(self):
7050        def func(a: bool, b: bool) -> bool:
7051            a &= b
7052            return a
7053
7054        self.checkScript(func, (True, False), optimize=True)
7055        self.checkScript(func, (True, True), optimize=True)
7056        self.checkScript(func, (False, False), optimize=True)
7057        self.checkScript(func, (False, True), optimize=True)
7058
7059    def test_bool_augassign_bitwise_xor(self):
7060        def func(a: bool, b: bool) -> bool:
7061            a ^= b
7062            return a
7063
7064        self.checkScript(func, (True, False), optimize=True)
7065        self.checkScript(func, (True, True), optimize=True)
7066        self.checkScript(func, (False, False), optimize=True)
7067        self.checkScript(func, (False, True), optimize=True)
7068
7069    def test_number_augassign_bitwise_lshift(self):
7070        def func() -> int:
7071            z = 8
7072            z <<= 2
7073            return z
7074
7075        self.checkScript(func, (), optimize=True)
7076
7077    def test_number_augassign_bitwise_rshift(self):
7078        def func() -> int:
7079            z = 8
7080            z >>= 2
7081            return z
7082
7083        self.checkScript(func, (), optimize=True)
7084
7085    def test_number_augassign_bitwise_pow(self):
7086        def func() -> float:
7087            z = 8
7088            z **= 2
7089            return z
7090
7091        self.checkScript(func, (), optimize=True)
7092
7093    def test_number_augassign(self):
7094        def func():
7095            z = 1
7096            z += 2
7097            return z
7098
7099        self.checkScript(func, (), optimize=True)
7100
7101    def test_nested_select_assign(self):
7102        class SubSubModule(torch.nn.Module):
7103            def __init__(self) -> None:
7104                super().__init__()
7105                self.abc = 11
7106
7107            def forward(self, x):
7108                return self.abc
7109
7110        class SubModule(torch.nn.Module):
7111            def __init__(self) -> None:
7112                super().__init__()
7113                self.a = 11
7114                self.nested = SubSubModule()
7115
7116            def forward(self, x):
7117                return self.a
7118
7119        class TestModule(torch.nn.Module):
7120            def __init__(self) -> None:
7121                super().__init__()
7122                self.sub = SubModule()
7123                self.hi = 1
7124
7125            def forward(self):
7126                self.hi = 5
7127                self.sub.a = 1
7128                self.sub.nested.abc = 5
7129                return self.sub.a * 20 + self.sub.nested.abc * 3 + self.hi
7130
7131        self.checkModule(TestModule(), ())
7132
7133    def test_number_neg(self):
7134        # int -> int
7135        def func1():
7136            return -8
7137
7138        # float -> float
7139        def func2():
7140            return -3.14
7141
7142        self.checkScript(func1, (), optimize=True)
7143        self.checkScript(func2, (), optimize=True)
7144
7145    def test_compare_two_bool_inputs(self):
7146        def compare_eq(a: bool, b: bool):
7147            return a == b
7148
7149        def compare_ne(a: bool, b: bool):
7150            return a != b
7151
7152        scripted_fn_eq = torch.jit.script(compare_eq)
7153        scripted_fn_ne = torch.jit.script(compare_ne)
7154        self.assertEqual(scripted_fn_eq(True, False), compare_eq(True, False))
7155        self.assertEqual(scripted_fn_eq(False, True), compare_eq(False, True))
7156        self.assertEqual(scripted_fn_eq(True, True), compare_eq(True, True))
7157        self.assertEqual(scripted_fn_eq(False, False), compare_eq(False, False))
7158
7159        self.assertEqual(scripted_fn_ne(True, False), compare_ne(True, False))
7160        self.assertEqual(scripted_fn_ne(False, True), compare_ne(False, True))
7161        self.assertEqual(scripted_fn_ne(True, True), compare_ne(True, True))
7162        self.assertEqual(scripted_fn_ne(False, False), compare_ne(False, False))
7163
7164
7165    def _test_tensor_number_math(self, device='cpu'):
7166        template = dedent('''
7167        def func(t):
7168            return {lhs} {op} {rhs}
7169        ''')
7170
7171        def test(op, tensor, const, swap_args, template=template):
7172            args = ('t', const)
7173            if swap_args:
7174                args = (const, 't')
7175
7176            code = template.format(lhs=args[0], rhs=args[1], op=op)
7177            scope = {}
7178            execWrapper(code, globals(), scope)
7179            cu = torch.jit.CompilationUnit(code)
7180            message = f'with code `{args[0]} {op} {args[1]}` and t={tensor}'
7181            res1 = cu.func(tensor)
7182            res2 = scope['func'](tensor)
7183            self.assertEqual(res1, res2, msg=message + "\nres1=" + str(res1) + "\nres2=" + str(res2))
7184            self.assertEqual(res1.dtype, res2.dtype, msg=message + "\nres1=" + str(res1) + "\nres2=" + str(res2))
7185
7186        var_int = [2, -2]
7187        var_float = [1.4321, -1.2]
7188
7189        ops = ['+', '-', '*', '%', '<', '<=', '>', '>=', '==', '!=', '/']
7190
7191        float_tensor = torch.randn(5, 5, device=device)
7192        double_tensor = torch.randn(5, 5, dtype=torch.double, device=device)
7193        long_tensor = torch.randint(-5, 5, (5, 5), dtype=torch.long, device=device)
7194        long_tensor[long_tensor == 0] = 2
7195
7196        tensors = [float_tensor, double_tensor, long_tensor]
7197        consts = var_int + var_float
7198
7199        for op, tensor, const, swap_args in product(ops, tensors, consts, [True, False]):
7200            # FIXME: things like 2 / long_tensor are not implemented correctly
7201            # Look in torch/_tensor.py to see how PyTorch implements it.
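            # Illustrative sketch (not asserted by this test): in eager mode
            # `2 / long_tensor` goes through Tensor.__rtruediv__ and promotes
            # to a floating-point result, and the scripted overload presumably
            # does not match that promotion yet, hence the skip below.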
7202            if op == '/' and tensor.data_ptr() == long_tensor.data_ptr():
7203                continue
7204
7205            # the % operator does not support const % tensor
7206            if op == '%' and swap_args is True:
7207                continue
7208
7209            test(op, tensor, const, swap_args)
7210
7211    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
7212    def test_tensor_number_math(self):
7213        self._test_tensor_number_math()
7214
7215    def test_torch_tensor_bad_input(self):
7216        with self.assertRaisesRegex(RuntimeError, "must be of ints, floats, "
7217                                    "or bools, got None"):
7218            @torch.jit.script
7219            def test():
7220                return torch.tensor([None])
7221            test()
7222
7223        with self.assertRaisesRegex(RuntimeError, r"Empty lists default to List\[Tensor\]"):
7224            @torch.jit.script
7225            def tmp():
7226                return torch.tensor([])
7227            tmp()
7228
7229        @torch.jit.script
7230        def foo():
7231            return torch.tensor([[2, 2], [1]])
7232        with self.assertRaisesRegex(RuntimeError, "Expected sequence of length"):
7233            foo()
7234
7235    @suppress_warnings
7236    def test_torch_tensor_as_tensor_empty_list(self):
7237        tensor_template = dedent('''
7238        def func():
7239            empty_list = torch.jit.annotate(List[int], [])
7240            ten1 = torch.{tensor_op}({input})
7241            return ten1
7242        ''')
7243        ops = ['tensor', 'as_tensor']
7244        inputs = ['empty_list', '[empty_list, empty_list]', '[[[empty_list]]]']
7245
7246        for op in ops:
7247            for inp in inputs:
7248                code = tensor_template.format(tensor_op=op, input=inp)
7249                scope = {}
7250                exec(code, globals(), scope)
7251                cu = torch.jit.CompilationUnit(code)
7252                t1 = cu.func()
7253                t2 = scope['func']()
7254                if inp == 'empty_list':
7255                    # TorchScript returns an int tensor, Python returns a float tensor
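                    # Hedged sketch of why: eager code sees a plain empty
                    # Python list and falls back to the default floating
                    # dtype, while TorchScript uses the List[int] annotation
                    # and builds an integer tensor instead.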
7256                    self.assertNotEqual(t1.dtype, t2.dtype)
7257                self.assertEqual(t1, t2, exact_dtype=False)
7258                self.assertEqual(t1.device, t2.device)
7259
7260    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Simple Executor doesn't have any shapes to propagate")
7261    def test_tensor_as_tensor_shape_prop(self):
7262        tensor_template = dedent('''
7263        def func():
7264            return torch.{tensor_op}({input})
7265        ''')
7266        ops = ['tensor', 'as_tensor']
7267        inputs = ['[1]', '[False]', '[2.5]', '0.5', '1', 'False', '[[1]]', 'torch.jit.annotate(List[List[int]], [])']
7268        expected_shape = ["Long(*, device=cpu)", "Bool(*, device=cpu)",
7269                          "Float(*, device=cpu)", "Float(device=cpu)",
7270                          "Long(device=cpu)", "Bool(device=cpu)", "Long(*, *, device=cpu)"]
7271
7272        for op in ops:
7273            for inp, expect in zip(inputs, expected_shape):
7274                code = tensor_template.format(tensor_op=op, input=inp)
7275                scope = {}
7276                exec(code, globals(), scope)
7277                cu = torch.jit.CompilationUnit(code)
7278                torch._C._jit_pass_complete_shape_analysis(cu.func.graph, (), False)
7279                FileCheck().check(expect).check(f"aten::{op}").run(cu.func.graph)
7280
7281        @torch.jit.script
7282        def test_dtype(inp_dtype: torch.dtype):
7283            a = torch.tensor(1.0, dtype=torch.float, requires_grad=True)
7284            return a, torch.tensor(1.0, dtype=inp_dtype)
7285
7286        if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
7287            g = test_dtype.graph_for(5, profile_and_replay=True)
7288            # both should have completed shapes
7289            FileCheck().check("Tensor = aten::tensor").check("Float(device=cpu) = prim::BailOut") \
7290                       .check("Tensor = aten::tensor").check("Half(device=cpu) = prim::BailOut").run(g)
7291        else:
7292            g = test_dtype.graph_for(5)
7293            # the first should have its type set, the second should not
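            # A hedged reading of why: the first torch.tensor call uses the
            # constant dtype=torch.float, so shape propagation can pin its
            # type, while the second depends on the runtime inp_dtype argument
            # and stays a plain Tensor.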
7294            FileCheck().check("Float(requires_grad=1, device=cpu) = aten::tensor") \
7295                       .check("Tensor(requires_grad=0) = aten::tensor").run(g)
7296
7297        @torch.jit.script
7298        def test_as_tensor_tensor_input(input):
7299            a = torch.as_tensor(input, dtype=input.dtype)
7300            return a, torch.as_tensor(input, dtype=torch.float)
7301
7302        if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
7303            g = test_as_tensor_tensor_input.graph_for(torch.ones(3, 4), profile_and_replay=True)
7304            FileCheck().check("Tensor = aten::as_tensor").check("Float(3, 4) = prim::BailOut") \
7305                       .check("Tensor = aten::as_tensor").check("Float(3, 4) = prim::BailOut").run(g)
7306        else:
7307            g = test_as_tensor_tensor_input.graph_for(torch.ones(3, 4))
7308            FileCheck().check("Tensor = aten::as_tensor").check("Float(*, *, requires_grad=0, device=cpu) = aten::as_tensor").run(g)
7309
7310    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "testing legacy behavior")
7311    def test_tensor_requires_grad(self):
7312        @torch.jit.script
7313        def test(b):
7314            # type: (bool) -> Tuple[Tensor, Tensor, Tensor]
7315            a = torch.tensor(1., requires_grad=b)
7316            b = torch.tensor(1., requires_grad=True)
7317            c = torch.tensor(1., requires_grad=False)
7318            return a, b, c
7319
7320        g = test.graph_for(True)
7321        out = next(g.outputs())
7322        out_inp = list(out.node().inputs())
7323
7324        self.assertTrue(out_inp[0].requires_grad())
7325        self.assertTrue(out_inp[1].requires_grad())
7326        self.assertFalse(out_inp[2].requires_grad())
7327
7328    def test_grad_from_script(self):
7329        def test():
7330            a = torch.tensor(2.5, requires_grad=True)
7331            b = a * 2
7332            return a, b
7333
7334        a, b = test()
7335        b.backward()
7336
7337        a_script, b_script = torch.jit.script(test)()
7338        b_script.backward()
7339        self.assertEqual(a.grad, a_script.grad)
7340
7341    def test_torch_tensor_as_tensor(self):
7342        tensor_template = dedent('''
7343        def func():
7344            li = {list_create}
7345            ten1 = torch.{tensor_op}(li {options})
7346            return ten1
7347        ''')
7348
7349        lists = ["2.5", "4", "True", "False", "[2]", "[-.5]", "[False, True, False]", "[2, 2]", "(1, 1)",
7350                 "torch.jit.annotate(List[List[int]], [])",
7351                 "torch.jit.annotate(List[int], [])", "[2.5, 2.5]", "[[2], [2]]", "[[-.5], [2.2]]", "[[False], [True]]"]
7352
7353        dtypes = ["", ", dtype=torch.float", ", dtype=torch.double", ", dtype=torch.half",
7354                  ", dtype=torch.uint8", ", dtype=torch.int8", ", dtype=torch.short",
7355                  ", dtype=torch.int", ", dtype=torch.long", ", dtype=torch.cfloat",
7356                  ", dtype=torch.cdouble"]
7357
7358        ops = ['tensor', 'as_tensor']
7359        devices = ['', ", device='cpu'"]
7360        if RUN_CUDA:
7361            devices.append(", device='cuda'")
7362
7363        option_pairs = [dtype + device for dtype in dtypes for device in devices]
7364        for op in ops:
7365            for li in lists:
7366                for option in option_pairs:
7367                    # a tensor from an empty list is float in Python but takes the annotated type in TorchScript
7368                    if "annotate" in li and "dtype" not in option:
7369                        continue
7370                    # Skip unsigned tensor initialization from negative values on Python 3.10+
7371                    if sys.version_info[:2] >= (3, 10) and "torch.uint8" in option and "-" in li:
7372                        continue
7373                    code = tensor_template.format(list_create=li, tensor_op=op, options=option)
7374                    scope = {}
7375                    exec(code, globals(), scope)
7376                    cu = torch.jit.CompilationUnit(code)
7377                    t1 = cu.func()
7378                    t2 = scope['func']()
7379                    if t1.dtype == torch.float16:  # equality NYI for half tensor
7380                        self.assertTrue(str(t1) == str(t2))
7381                    else:
7382                        self.assertEqual(t1, t2)
7383                    self.assertEqual(t1.dtype, t2.dtype)
7384                    self.assertEqual(t1.device, t2.device)
7385
7386        def test_as_tensor_tensor_input(input):
7387            # type: (Tensor) -> Tuple[Tensor, Tensor, Tensor]
7388            return torch.as_tensor(input, dtype=torch.cfloat), torch.as_tensor(input, dtype=torch.float), \
7389                torch.as_tensor(input, dtype=torch.int32)
7390
7391        inp = torch.randn(3, 4, dtype=torch.cfloat)
7392        self.checkScript(test_as_tensor_tensor_input, (inp,))
7393
7394    def test_torch_tensor_dtype(self):
7395        def foo(s: float):
7396            return torch.tensor(s), torch.tensor([s, s])
7397
7398        # need to clear the function cache so we re-run shape analysis
7399        with set_default_dtype(torch.double):
7400            self.assertEqual(torch.jit.script(foo)(1.), foo(1.), exact_dtype=True)
7401            if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
7402                FileCheck().check("Double").check_same("aten::tensor").run(torch.jit.last_executed_optimized_graph())
7403        with set_default_dtype(torch.float):
7404            del torch.jit._state._jit_caching_layer[foo]
7405            self.assertEqual(torch.jit.script(foo)(1.), foo(1.), exact_dtype=True)
7406            if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
7407                FileCheck().check("Float").check_same("aten::tensor").run(torch.jit.last_executed_optimized_graph())
7408        with set_default_dtype(torch.half):
7409            del torch.jit._state._jit_caching_layer[foo]
7410            self.assertEqual(torch.jit.script(foo)(1.), foo(1.), exact_dtype=True)
7411            if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
7412                FileCheck().check("Half").check_same("aten::tensor").run(torch.jit.last_executed_optimized_graph())
7413
7414    def test_shape_analysis_grad_property(self):
7415        @torch.jit.script
7416        def foo(x):
7417            return torch.sub(x, torch.tanh(x))
7418
7419        torch._C._jit_pass_complete_shape_analysis(foo.graph, (torch.tensor([0.39]),), False)
7420
7421        # requires_grad property shouldn't be accidentally set by shape analysis
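        # (A None result is read here as "unknown"; a concrete True/False
        # would mean the pass filled the property in, which is exactly what
        # this guards against.)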
7422        self.assertTrue(foo.graph.findNode("aten::sub").output().requiresGrad() is None)
7423
7424    def test_empty_like_memory_format_bc(self):
7425        def f(x):
7426            # type: (Tensor) -> Tensor
7427            return torch.zeros_like(x, memory_format=None)
7428
7429        scripted_f = torch.jit.script(f)
7430        x = torch.rand(3, 4)
7431        self.assertEqual(scripted_f(x), f(x))
7432
7433    def test_multiline_string_dedents(self):
7434        def foo() -> None:
7435            multiline_string_dedent_1 = """
7436This is a string dedent """
7437            multiline_string_dedent_2 = """ This is a
7438  string dedent """
7439            multiline_string_dedent_3 = """
7440            This is a string
7441dedent """
7442            multiline_string_dedent_4 = """ This is a string dedent """
7443
7444        scripted_foo = torch.jit.script(foo)
7445        self.assertEqual(scripted_foo(), foo())
7446
7447    def test_class_with_comment_at_lower_indentation(self):
7448        class Foo(torch.nn.Module):
7449            def forward(self, x):
7450                x = torch.neg(x)
7451        # This comment is at the wrong indent
7452                return x
7453
7454        torch.jit.script(Foo())
7455
7456    # adapted from test in test_torch
7457    def test_tensor_to(self):
7458        template = dedent('''
7459        def func(t):
7460            cuda = "{cuda}"
7461            device = "{device}"
7462            non_blocking = {non_blocking}
7463            return {to_str}
7464        ''')
7465
7466        def s(t, to_str, non_blocking=None, device=None, cuda=None):
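            # Roughly, for a CPU tensor with the default arguments the
            # template above is instantiated as (illustrative sketch only):
            #
            #   def func(t):
            #       cuda = "cuda"
            #       device = "cpu"
            #       non_blocking = False
            #       return t.to('cpu')  # i.e. whatever to_str was passed in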
7467            device = device if device is not None else str(t.device)
7468            non_blocking = non_blocking if non_blocking is not None else False
7469            cuda = "cuda" if cuda is None else cuda
7470            code = template.format(to_str=to_str, device=device, non_blocking=non_blocking, cuda=cuda)
7471            scope = {}
7472            cu = torch.jit.CompilationUnit(code)
7473            return cu.func(t, profile_and_replay=True)
7474
7475        def test_copy_behavior(t, non_blocking=False):
7476            self.assertIs(t, s(t, 't.to(t, non_blocking=non_blocking)', non_blocking))
7477            self.assertIs(t, s(t, 't.to(t.dtype, non_blocking=non_blocking)', non_blocking))
7478            self.assertIs(t, s(t, 't.to(torch.empty_like(t), non_blocking=non_blocking)', non_blocking))
7479            self.assertIsNot(t, s(t, 't.to(t, non_blocking=non_blocking, copy=True)', non_blocking))
7480            self.assertIsNot(t, s(t, 't.to(t.dtype, non_blocking=non_blocking, copy=True)', non_blocking))
7481            self.assertIsNot(t, s(t, 't.to(torch.empty_like(t), non_blocking=non_blocking, copy=True)', non_blocking))
7482
7483            devices = [t.device]
7484            if t.device.type == 'cuda':
7485                if t.device.index == -1:
7486                    devices.append(f'cuda:{torch.cuda.current_device()}')
7487                elif t.device.index == torch.cuda.current_device():
7488                    devices.append('cuda')
7489            for device in devices:
7490                self.assertIs(t, s(t, 't.to(device, non_blocking=non_blocking)', non_blocking, device))
7491                self.assertIs(t, s(t, 't.to(device, t.dtype, non_blocking=non_blocking)', non_blocking, device))
7492                self.assertIsNot(t, s(t, 't.to(device, non_blocking=non_blocking, copy=True)', non_blocking, device))
7493                self.assertIsNot(t, s(t, 't.to(device, t.dtype, non_blocking=non_blocking, copy=True)',
7494                                      non_blocking, device))
7495
7496        t = torch.tensor(5)
7497        test_copy_behavior(t)
7498
7499        self.assertEqual(t.device, s(t, "t.to('cpu')").device)
7500        self.assertEqual(t.device, s(t, "t.to('cpu', dtype=torch.float32)").device)
7501        self.assertIs(torch.float32, s(t, "t.to('cpu', dtype=torch.float32)").dtype)
7502        self.assertEqual(t.device, s(t, "t.to(torch.float32)").device)
7503        self.assertIs(torch.float32, s(t, "t.to(dtype=torch.float32)").dtype)
7504        self.assertEqual(t.data_ptr(), s(t, "t.to('cpu')").data_ptr())
7505        self.assertEqual(t.data_ptr(), s(t, "t.to(dtype=t.dtype, device=t.device, copy=False)").data_ptr())
7506        self.assertEqual(t.data_ptr(), s(t, "t.to('cpu', copy=False)").data_ptr())
7507        self.assertNotEqual(t.data_ptr(), s(t, "t.to('cpu', copy=True)").data_ptr())
7508
7509        a = torch.tensor(5)
7510        if torch.cuda.is_available():
7511            for non_blocking in [True, False]:
7512                for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
7513                    b = torch.tensor(5., device=cuda)
7514                    test_copy_behavior(b, non_blocking)
7515                    self.assertEqual(b.device, s(b, "t.to(cuda, non_blocking=non_blocking).device", cuda=cuda))
7516                    self.assertEqual(a.device, s(b, "t.to('cpu', non_blocking=non_blocking).device"))
7517                    self.assertEqual(b.device, s(b, "t.to(cuda, non_blocking=non_blocking).device", cuda=cuda))
7518                    self.assertIs(torch.int32, s(b, "t.to('cpu', dtype=torch.int32, non_blocking=non_blocking)").dtype)
7519                    self.assertEqual(a.device, s(b, "t.to('cpu', dtype=torch.int32, non_blocking=non_blocking)").device)
7520                    self.assertIs(torch.int32, s(b, "t.to(dtype=torch.int32)").dtype)
7521                    self.assertEqual(b.device, s(b, "t.to(dtype=torch.int32)").device)
7522
7523        # Test AD: aten::to(Tensor self, int dtype, bool non_blocking, bool copy) -> Tensor
7524        t = torch.tensor(5).float().requires_grad_()
7525        out_ref = t.to(torch.float32)
7526        out = s(t, "t.to(torch.float32)")
7527        self.assertEqual(out_ref, out)
7528
7529        grad_ref = torch.autograd.grad(out_ref.sum(), t)
7530        grad = torch.autograd.grad(out.sum(), t)
7531        self.assertEqual(grad_ref, grad)
7532
7533        # Test AD: aten::to(Tensor self, Device? device, int? dtype, bool non_blocking, bool copy) -> Tensor
7534        out_ref = t.to('cpu')
7535        out = s(t, "t.to('cpu')")
7536        self.assertEqual(out_ref, out)
7537
7538        grad_ref = torch.autograd.grad(out_ref.sum(), t)
7539        grad = torch.autograd.grad(out.sum(), t)
7540        self.assertEqual(grad_ref, grad)
7541
7542        # Test AD: aten::to(Tensor self, Tensor other, bool non_blocking, bool copy) -> Tensor
7543        @torch.jit.script
7544        def func2(t, t_ref):
7545            return t.to(t_ref)
7546
7547        with disable_autodiff_subgraph_inlining():
7548            t_ref = torch.tensor(4).double()
7549            out_ref = t.to(t_ref)
7550            out = func2(t, t_ref)
7551            grad_ref = torch.autograd.grad(out_ref.sum(), t)
7552            grad = torch.autograd.grad(out.sum(), t)
7553            self.assertEqual(grad_ref, grad)
7554
7555    @unittest.skipIf(not RUN_CUDA, "No CUDA")
7556    def test_tensor_number_math_cuda(self):
7557        self._test_tensor_number_math(device='cuda')
7558
7559    def test_not(self):
7560        # test the `not` operator in Python
7561        # TODO: add more tests when bool conversions are ready
7562        def test_not_op(a):
7563            return not bool(a > 1)
7564
7565        self.checkScript(test_not_op, (torch.tensor(2), ), optimize=True)
7566
7567    def test_is_isnot(self):
7568        # test the `is` and `is not` operators in Python
7569        template = dedent('''
7570        def func():
7571            # type: () -> bool
7572            return {lhs} {op} {rhs}
7573        ''')
7574
7575        def test(op, args):
7576            code = template.format(lhs=args[0], rhs=args[1], op=op)
7577            scope = {}
7578            execWrapper(code, globals(), scope)
7579            cu = torch.jit.CompilationUnit(code)
7580            self.assertEqual(
7581                cu.func(),
7582                scope['func'](),
7583                msg=f"Failed with op: {op}, lhs: {args[0]}, rhs: {args[1]}"
7584            )
7585
7586        ops = ['is', 'is not']
7587        type_literals = [True, False, None, [1, 1], 1, 2, .5, 1.5]
7588
7589        # take the product of literals to try all type combinations
7590        for op, lhs, rhs in product(ops, type_literals, type_literals):
7591            test(op, [lhs, rhs])
7592
7593    def test_isinstance_refinement(self):
7594        @torch.jit.script
7595        def foo(a):
7596            # type: (Optional[int]) -> int
7597            if isinstance(a, int):
7598                return a + 3
7599            else:
7600                return 4
7601        self.assertEqual(foo(4), 7)
7602        self.assertEqual(foo(None), 4)
7603
7604        @torch.jit.script
7605        def foo2(a, b):
7606            # type: (Optional[int], Optional[int]) -> int
7607            if not isinstance(a, int) or not isinstance(b, int):
7608                return 0
7609            else:
7610                return a + b
7611        self.assertEqual(foo2(3, 4), 7)
7612        self.assertEqual(foo2(None, 4), 0)
7613        self.assertEqual(foo2(4, None), 0)
7614
7615        @torch.jit.script
7616        def any_refinement(a, b):
7617            # type: (Any, Any) -> int
7618            if isinstance(a, int) and isinstance(b, int):
7619                return a + b
7620            return 0
7621
7622        self.assertEqual(any_refinement(3, 4), 7)
7623        self.assertEqual(any_refinement(3, "hi"), 0)
7624
7625        @torch.jit.script
7626        def any_refinement2(a):
7627            # type: (Any) -> Tensor
7628            if isinstance(a, Tensor):
7629                return a
7630            return torch.tensor(3)
7631
7632        self.assertEqual(any_refinement2(3), torch.tensor(3))
7633        self.assertEqual(any_refinement2(torch.tensor(5)), torch.tensor(5))
7634
7635    @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "bug persists in deprecated executor")
7636    def test_unspecialized_any_binding(self):
7637        # binding to Any will infer the type; if it infers a specialized
7638        # tensor type for `x`, the Dict isinstance check will fail
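        # Sketch (for the executors this test runs against): the dict call
        # below passes because its tensor value stays unspecialized, while
        # foo(2) cannot satisfy Dict[str, torch.Tensor] and the assert raises.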
7639
7640        @torch.jit.script
7641        def foo(x: Any):
7642            assert isinstance(x, Dict[str, torch.Tensor])
7643
7644        foo({"1": torch.tensor(3)})
7645        with self.assertRaises(Exception):
7646            foo(2)
7647
7648    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
7649    def test_isinstance(self):
7650        # test isinstance operator for static type checking
7651        template = dedent('''
7652        def func(x):
7653            # type: ({type_hint}) -> bool
7654            return isinstance(x, {typ})
7655        ''')
7656
7657        def test(inp, typ, type_hint):
7658            code = template.format(typ=typ, type_hint=type_hint)
7659            scope = {}
7660            execWrapper(code, globals(), scope)
7661            cu = torch.jit.CompilationUnit(code)
7662            self.assertEqual(
7663                cu.func(inp),
7664                scope['func'](inp),
7665                msg=f"Failed with typ: {typ}"
7666            )
7667
7668        inputs = [True, 1, 1.0, torch.tensor(1), [1, 2], (1.0,), [1, 2], 1]
7669        type_literals = ['bool', 'int', 'float', 'torch.Tensor', 'list', 'tuple',
7670                         '(list, tuple)', '(int, float, bool)']
7671        type_annotations = ['bool', 'int', 'float', 'Tensor', 'List[int]', 'Tuple[float]',
7672                            'List[int]', 'int']
7673
7674        # zip inputs with types to try different combinations
7675        for inp, typ, type_hint in zip(inputs, type_literals, type_annotations):
7676            test(inp, typ, type_hint)
7677
7678        # test optional isinstance check
7679        @torch.jit.script
7680        def opt_func(x):
7681            # type: (Optional[int]) -> bool
7682            return isinstance(x, int)
7683        self.assertTrue(opt_func(3))
7684        self.assertFalse(opt_func(None))
7685
7686    def test_dropout_eval(self):
7687        class ScriptedConv2d(torch.jit.ScriptModule):
7688            def __init__(self, in_channels, out_channels, **kwargs):
7689                super().__init__()
7690                self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
7691                self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
7692
7693            @torch.jit.script_method
7694            def forward(self, x):
7695                x = self.conv(x)
7696                x = self.bn(x)
7697                return F.relu(x, inplace=True)
7698
7699        class ScriptMod(torch.jit.ScriptModule):
7700            def __init__(self) -> None:
7701                super().__init__()
7702                self.Conv2d_1a_3x3 = ScriptedConv2d(3, 32, kernel_size=3, stride=2)
7703
7704            @torch.jit.script_method
7705            def forward(self, x):
7706                x = self.Conv2d_1a_3x3(x)
7707                return F.dropout(x, training=self.training)
7708
7709        class EagerConv2d(torch.nn.Module):
7710            def __init__(self, in_channels, out_channels, **kwargs):
7711                super().__init__()
7712                self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
7713                self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
7714
7715            def forward(self, x):
7716                x = self.conv(x)
7717                x = self.bn(x)
7718                return F.relu(x, inplace=True)
7719
7720        class EagerMod(torch.nn.Module):
7721            def __init__(self) -> None:
7722                super().__init__()
7723                self.Conv2d_1a_3x3 = EagerConv2d(3, 32, kernel_size=3, stride=2)
7724
7725            def forward(self, x):
7726                x = self.Conv2d_1a_3x3(x)
7727                return F.dropout(x, training=self.training)
7728
7729        script_input = torch.rand(4, 3, 299, 299)
7730        eager_input = script_input.clone()
7731
7732        with freeze_rng_state():
7733            script_mod = ScriptMod()
7734            script_mod.eval()
7735            script_output = script_mod(script_input)
7736
7737        with freeze_rng_state():
7738            eager_mod = EagerMod()
7739            eager_mod.eval()
7740            eager_output = eager_mod(eager_input)
7741
7742        self.assertEqual(script_output, eager_output)
7743
7744        with freeze_rng_state():
7745            script_mod = ScriptMod()
7746            script_mod.train()
7747            script_output = script_mod(script_input)
7748
7749        with freeze_rng_state():
7750            eager_mod = EagerMod()
7751            eager_mod.train()
7752            eager_output = eager_mod(eager_input)
7753
7754        self.assertEqual(script_output, eager_output)
7755
7756    def test_nested_breaks(self):
7757        def no_bool_loop_outputs(g):
7758            # testing that the "did exit" transform values are not loop block
7759            # outputs (and thus one loop does not affect another)
7760            loops = g.findAllNodes("prim::Loop")
7761            for loop in loops:
7762                for out in loop.outputs():
7763                    self.assertTrue(out.type() != BoolType.get())
7764
7765        def test(y):
7766            # type: (int)
7767            ret = 0
7768            tensor = torch.tensor(0)
7769            while int(tensor.add_(1)) < 4:
7770                if y == 1:
7771                    continue
7772                for i in range(y):
7773                    continue
7774                    ret += 1
7775                ret += 1
7776            return ret, int(tensor)
7777
7778        self.assertEqual(torch.jit.script(test)(1), test(1))
7779        self.assertEqual(torch.jit.script(test)(2), test(2))
7780        no_bool_loop_outputs(torch.jit.script(test).graph)
7781
7782        def foo():
7783            y = torch.tensor(0)
7784            z = 0
7785            while int(y.add_(1)) < 20:
7786                if int(y) < 10:
7787                    for i in range(6):
7788                        if i == 3:
7789                            continue
7790                        else:
7791                            if i > 3:
7792                                break
7793                        z += 2
7794                if int(y) == 18:
7795                    break
7796                if int(y) == 15:
7797                    continue
7798                z += 1
7799            return int(y), z
7800
7801        no_bool_loop_outputs(torch.jit.script(foo).graph)
7802        self.checkScript(foo, ())
7803
7804        def test_nested_two():
7805            i = 0
7806            k = 0
7807            while i < 5:
7808                for j in range(5):
7809                    k += 1
7810                    if j == 3:
7811                        continue
7812                i += 1
7813                k += 1
7814                if i == 4:
7815                    break
7816            return i, k
7817
7818        self.checkScript(test_nested_two, ())
7819        no_bool_loop_outputs(torch.jit.script(test_nested_two).graph)
7820
7821    def test_breaks_continues(self):
7822        def foo_continue(cond):
7823            # type: (int)
7824            j = 1
7825            for i in range(5):
7826                if i == cond:
7827                    continue
7828                j += 1
7829            return j
7830
7831        def foo_break(cond):
7832            # type: (int)
7833            j = 1
7834            for i in range(5):
7835                if i == cond:
7836                    break
7837                j += 1
7838            return j
7839
7840        for i in range(1, 4):
7841            self.checkScript(foo_continue, (i,))
7842            self.checkScript(foo_break, (i,))
7843
7844        def test_refine_outside_loop():
7845            if 1 == 1:
7846                x = None
7847            else:
7848                x = 1
7849            i = 0
7850            j = 0
7851            while (x is None or torch.jit._unwrap_optional(x) > 3):
7852                if i < 3:
7853                    if i < 3:
7854                        x = torch.jit.annotate(Optional[int], None)
7855                        i += 1
7856                        continue
7857                    x = 1
7858                else:
7859                    x = 1 if x is None else x
7860                x = x + 1
7861                j = x + x
7862
7863            return x, j
7864
7865        self.checkScript(test_refine_outside_loop, ())
7866
7867        def assign_after_break(y):
7868            # type: (int)
7869            x = 0
7870            for i in range(y):
7871                x = y * 2 + i
7872                break
7873                x = 4
7874            return x
7875
7876        self.checkScript(assign_after_break, (1,))
7877        self.checkScript(assign_after_break, (2,))
7878        self.checkScript(assign_after_break, (3,))
7879
7880        def assign_after_break_nested(y):
7881            # type: (int)
7882            x = 0
7883            for i in range(y):
7884                if y == 1:
7885                    x = 5
7886                    break
7887                    assert 1 == 2
7888                else:
7889                    x = x + 1
7890                    break
7891                    assert 1 == 2
7892                x = -30
7893                assert 1 == 2
7894            return x
7895
7896        self.checkScript(assign_after_break_nested, (1,))
7897        self.checkScript(assign_after_break_nested, (2,))
7898        self.checkScript(assign_after_break_nested, (3,))
7899
7900        def may_break(y):
7901            # type: (int)
7902            x = 0
7903            for i in range(y):
7904                if y == 1:
7905                    x = 5
7906                else:
7907                    x = x + 1
7908                    break
7909                x = -30
7910            return x
7911
7912        self.checkScript(may_break, (1,))
7913        self.checkScript(may_break, (2,))
7914        self.checkScript(may_break, (3,))
7915
7916        def test(x, y):
7917            # type: (int, int)
7918            a = 1
7919            while (x > 0):
7920                if y == 3:
7921                    for i in range(y):
7922                        a += (1 % (i + 1))
7923                        x -= 1
7924                if x == 3:
7925                    a = x * 3
7926                    break
7927                if x < 3:
7928                    if x == 1:
7929                        a -= 2
7930                        x -= 1
7931                        break
7932                a -= 1
7933                x -= 3
7934            return a, x
7935
7936        self.checkScript(test, (10, 3))
7937        self.checkScript(test, (10, 2))
7938        self.checkScript(test, (3, 2))
7939        self.checkScript(test, (5, 3))
7940        self.checkScript(test, (2, 3))
7941
7942        def test_delete_after_break(x):
7943            # type: (int)
7944            a = 1
7945            b = 1
7946            for i in range(x):
7947                a = i * 3
7948                break
7949                b = i * 5
7950            return a, b
7951
7952        self.checkScript(test_delete_after_break, (0,))
7953        self.checkScript(test_delete_after_break, (1,))
7954
7955        def test_will_break_after_guard(x):
7956            # type: (int)
7957            a = 1
7958            for i in range(x):
7959                if i == 4:
7960                    a = 3
7961                    break
7962                a -= 1
7963                break
7964                assert 1 == 2
7965                a -= -100
7966            return a
7967
7968        self.checkScript(test_will_break_after_guard, (0,))
7969        self.checkScript(test_will_break_after_guard, (2,))
7970        self.checkScript(test_will_break_after_guard, (4,))
7971
7972        def test_varexit(cond):
7973            # type: (int)
7974            m = 0
7975            for i in range(3):
7976                if cond == 2:
7977                    if cond == 2:
7978                        m = 2
7979                        break
7980                    k = 1
7981                else:
7982                    k = 2
7983                m += k
7984            return m
7985
7986        # the use of k tests the pathway where we have to insert an uninitialized value
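        # Hedged sketch of that pathway: when cond == 2 the loop breaks before
        # k is ever assigned, so the emitted graph must carry a placeholder
        # (an uninitialized value) for k on that path even though it is never
        # read there.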
7987        self.checkScript(test_varexit, (3,))
7988        self.checkScript(test_varexit, (2,))
7989
7990        def test_break_true():
7991            i = 0
7992            while True:
7993                i += 1
7994                if i == 3:
7995                    break
7996            while False:
7997                i += 1
7998            return i
7999
8000        self.checkScript(test_break_true, ())
8001
8002    def test_break_continue_error(self):
8003        with self.assertRaisesRegex(RuntimeError, "Syntax"):
8004            cu = torch.jit.CompilationUnit('''
8005            def other_func(a):
8006                break
8007                ''')
8008
8009        with self.assertRaisesRegex(RuntimeError, "Syntax"):
8010            cu = torch.jit.CompilationUnit('''
8011            def other_func(a):
8012                for i in range(5):
8013                    def foo():
8014                        break
8015                ''')
8016
8017        with self.assertRaisesRegex(RuntimeError, "do not support break or continue inside"):
8018            @torch.jit.script
8019            def foo(x):
8020                i = 0
8021                for a in (1, "2", 1.5):
8022                    b = a
8023                    if x:
8024                        break
8025                return b
8026
8027    def test_python_call(self):
8028        def pyfunc(a):
8029            return a * 3.0
8030
8031        cu = torch.jit.CompilationUnit('''
8032        def other_func(a):
8033            return a + a
8034
8035        def test_call_python(a):
8036            b = pyfunc(a)
8037            b = other_func(b)
8038            i = 0
8039            step = 1
8040            while i < 10:
8041                b = pyfunc(b)
8042                if bool(b > 3.0):
8043                    b = pyfunc(b)
8044                i = 11
8045            return b
8046        ''')
8047        inputs = self._make_scalar_vars([1], torch.float)
8048        outputs = self._make_scalar_vars([54], torch.float)
8049
8050        self.assertEqual(cu.test_call_python(*inputs), outputs[0])
8051
8052    def test_python_call_failure(self):
8053        with self.assertRaisesRegex(RuntimeError, "undefined value pyfunc2"):
8054            def pyfunc(a):
8055                return a * 3.0
8056
8057            cu = torch.jit.CompilationUnit('''
8058            def other_func(a):
8059                return a + a
8060
8061            def test_call_python(a):
8062                b = pyfunc(a)
8063                b = other_func(b)
8064                i = 0
8065                step = 1
8066                while i < 10:
8067                    b = pyfunc2(b)
8068                    if b > 3.0:
8069                        b = pyfunc(b)
8070                    i = 11
8071                return b
8072            ''')
8073            inputs = self._make_scalar_vars([1], torch.float)
8074            outputs = self._make_scalar_vars([54], torch.float)
8075
8076            self.assertEqual(cu.test_call_python(*inputs), outputs)
8077
8078    def test_type_call_in_script(self):
8079        @torch.jit.script
8080        def fn(x):
8081            return type(x)
8082
8083        with self.assertRaisesRegex(RuntimeError, "value of type _TensorMeta"):
8084            fn(torch.tensor(.5))
8085
8086    def test_python_call_annotation(self):
8087        def pyfunc(a):
8088            return a * 3.0
8089
8090        @torch.jit.script
8091        def foo(a):
8092            return pyfunc(a) + pyfunc(a)
8093
8094        inputs = self._make_scalar_vars([1], torch.float)
8095        outputs = self._make_scalar_vars([6], torch.float)
8096        self.assertEqual(foo(*inputs), outputs[0])
8097
8098    def test_python_call_annotation_failure(self):
8099        with self.assertRaisesRegex(RuntimeError, "undefined value pyfunc2"):
8100            def pyfunc(a):
8101                return a * 3.0
8102
8103            @torch.jit.script
8104            def foo(a):
8105                return pyfunc2(a) + pyfunc(a)  # noqa: F821
8106
8107            inputs = self._make_scalar_vars([1], torch.float)
8108            outputs = self._make_scalar_vars([6], torch.float)
8109
8110            self.assertEqual(foo(*inputs), outputs[0])
8111
8112    def test_desugar_module(self):
8113        import torch.nn.functional as F
8114
8115        def fn(x, slope):
8116            a = torch.abs(x)
8117            b = torch.nn.functional.prelu(x, slope)
8118            c = F.prelu(x, slope)
8119            return a, b, c
8120
8121        x = torch.arange(-3., 4)
8122        slope = torch.tensor([0.5])
8123        self.checkScript(fn, [x, slope], optimize=True)
8124
8125    def test_script_docstring(self):
8126        @torch.jit.script
8127        def with_docstring(x):
8128            """test str"""
8129            y = x
8130            """y is the same as x"""
8131            return y
8132        self.assertEqual(with_docstring.__doc__, 'test str')
8133
8134    def test_script_method_docstring(self):
8135        class A(torch.jit.ScriptModule):
8136            @torch.jit.script_method
8137            def with_docstring(self, x):
8138                """test str"""
8139                y = x
8140                """y is the same as x"""
8141                return y
8142        a = A()
8143        self.assertEqual(a.with_docstring.__doc__, 'test str')
8144
8145    def test_script_module(self):
8146        class M1(torch.jit.ScriptModule):
8147            def __init__(self) -> None:
8148                super().__init__()
8149                self.weight = nn.Parameter(torch.randn(2))
8150
8151            @torch.jit.script_method
8152            def forward(self, thing):
8153                return self.weight + thing
8154
8155        class PModule(nn.Module):
8156            def __init__(self) -> None:
8157                super().__init__()
8158                self.a = nn.Parameter(torch.randn(2, 3))
8159
8160            def forward(self, a):
8161                return self.a.mm(a)
8162
8163        class M2(torch.jit.ScriptModule):
8164            def __init__(self) -> None:
8165                super().__init__()
8166                # test submodule
8167                self.sub = M1()
8168                self.sub2 = PModule()
8169                # test parameters
8170                self.weight = nn.Parameter(torch.randn(2, 3))
8171                self.bias = nn.Parameter(torch.randn(2))
8172                # test defining a method from a string
8173                self.define("""
8174                    def hi(self, a):
8175                        return self.weight.mm(a)
8176                """)
8177            # test script methods
8178
8179            @torch.jit.script_method
8180            def doit(self, input):
8181                # test use of parameter
8182                return self.weight.mm(input)
8183
8184            @torch.jit.script_method
8185            def doit2(self, input):
8186                return self.weight.mm(input)
8187
8188            @torch.jit.script_method
8189            def forward(self, input):
8190                a = self.doit(input)
8191                b = self.doit2(input)
8192                c = self.hi(input)
8193                d = self.sub2(input)
8194                return a + b + self.bias + self.sub(a) + c + d
8195        with torch.jit.optimized_execution(False):
8196            m2 = M2()
8197            input = torch.randn(3, 2)
8198            a = m2.weight.mm(input)
8199            b = m2.weight.mm(input)
8200            c = m2.weight.mm(input)
8201            d = m2.sub2.a.mm(input)
8202            ref = a + b + m2.bias + m2.sub.weight + a + c + d
8203            self.assertEqual(ref, m2.forward(input))
8204            m2.weight = nn.Parameter(torch.zeros_like(m2.weight))
8205            m2.bias = nn.Parameter(torch.zeros_like(m2.bias))
8206            m2.sub.weight = nn.Parameter(torch.zeros_like(m2.sub.weight))
8207            m2.sub2.a.data.zero_()
8208            self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2)))
8209
8210    def test_irparser(self):
8211        graph_str = """graph(%0 : Double(5, 5)):
8212          # CHECK: aten::relu
8213          %1 : Double(5, 5) = aten::relu(%0)
8214          return (%1)
8215        """
8216        FileCheck().run(graph_str, parse_ir(graph_str))
8217
8218    def test_parse_tensor_constants(self):
8219        def foo():
8220            return torch.zeros([4, 4])
8221
8222        foo_s = torch.jit.script(foo)
8223        torch._C._jit_pass_constant_propagation(foo_s.graph)
8224
8225        g = str(foo_s.graph)
8226        g_parsed = parse_ir(g, parse_tensor_constants=True)
8227        self.assertEqual(str(canonical(g_parsed)), str(canonical(foo_s.graph)))
8228        func = torch._C._create_function_from_graph("forward", g_parsed)
8229
8230        out_parsed = func()
8231        out_func = foo()
8232        # not checking the data, just dtype, size, etc.
8233        out_parsed[:] = 0
8234        out_func[:] = 0
8235        self.assertEqual(out_func, out_parsed)
8236
8237        with self.assertRaises(RuntimeError):
8238            parse_ir(g, parse_tensor_constants=False)
8239
8240    def test_parse_nested_names(self):
8241        g_str = """
8242    graph(%x.1 : Tensor):
8243        %3 : int = prim::Constant[value=1]()
8244        %2 : int = prim::Constant[value=2]()
8245        %hi.submod.value.5 : Tensor = aten::add(%x.1, %2, %3)
8246        return (%hi.submod.value.5)
8247        """
8248        g = parse_ir(g_str)
8249        round_trip_g = parse_ir(str(g))
8250        self.assertEqual(canonical(g), canonical(round_trip_g))
8251
8252        func1 = torch._C._create_function_from_graph("forward", g)
8253        func2 = torch._C._create_function_from_graph("forward", round_trip_g)
8254        self.assertEqual(func1(torch.ones([2])), func2(torch.ones([2])))
8255
8256    def test_is_after_use(self):
8257        def sorted_input_use(g):
8258            uses = list(next(g.inputs()).uses())
8259            return sorted(uses, key=functools.cmp_to_key(type(uses[0]).isAfter))
8260
8261        @torch.jit.script
8262        def foo(x):
8263            a = x + 1
8264            return (x, x, a)
8265
8266        uses_sorted = sorted_input_use(foo.graph)
8267        # sorts last use to the end
8268        self.assertFalse(uses_sorted[0].isAfter(uses_sorted[1]))
8269        self.assertTrue(uses_sorted[0].user.kind() == "aten::add")
8270        self.assertEqual(uses_sorted[1].offset, 0)
8271
8272        @torch.jit.script
8273        def foo(x, cond: bool):
8274            if cond:
8275                return x + 3
8276            else:
8277                return x - 3
8278
8279        uses_sorted = sorted_input_use(foo.graph)
8280        self.assertTrue(uses_sorted[0].user.kind() == "aten::add")
8281        self.assertTrue(uses_sorted[1].user.kind() == "aten::sub")
8282
8283        @torch.jit.script
8284        def foo(x, cond: bool, cond2: bool):
8285            if cond:
8286                return x + 3
8287            elif cond2 :
8288                return x - 3
8289
8290            return x / 3
8291
8292        graph1 = foo.graph
8293
8294        @torch.jit.script
8295        def foo(x, cond: bool, cond2: bool):
8296            if cond:
8297                return x + 3
8298            else:
8299                if cond2 :
8300                    return x - 3
8301                return x / 3
8302
8303        graph2 = foo.graph
8304
8305        for graph in [graph1, graph2]:
8306            uses_sorted = sorted_input_use(graph)
8307            self.assertTrue(uses_sorted[0].user.kind() == "aten::add")
8308            self.assertTrue(uses_sorted[1].user.kind() == "aten::sub")
8309            self.assertTrue(uses_sorted[2].user.kind() == "aten::div")
8310
8311    def test_canonicalize_control_outputs(self):
8312        def test_all_outputs(g):
8313            ifs = g.findAllNodes("prim::If")
8314            loops = g.findAllNodes("prim::Loop")
8315
8316            def contained_blocks(node):
8317                return len(node.findAllNodes("prim::If")) * 2 + len(node.findAllNodes("prim::Loop"))
8318            for node in ifs + loops:
8319                outs = list(node.outputs())
8320                out_name = [x.debugName() for x in outs]
8321                if len(out_name) == 0:
8322                    continue
8323                fc = FileCheck()
8324                # find the last output, then all subsequent uses
8325                fc.check(out_name[-1] + " : ")
8326                # skip past node body
8327                for i in range(contained_blocks(node)):
8328                    fc.check("->")
8329                if (node.kind() == "prim::If"):
8330                    fc.check("->").check("->").check("\n")
8331                else:
8332                    fc.check("->").check("\n")
8333                # the canonical order matches the order in which the first
8334                # use of each output appears in the text
8335                for name in out_name:
8336                    fc.check(name)
8337                fc.run(g)
8338
8339        @torch.jit.script
8340        def test(x):
8341            # type: (bool) -> Tuple[int, int]
8342            b = 2
8343            a = 1
8344            if x:
8345                a = 1
8346                b = 2
8347                x = False
8348            if x:
8349                b = a
8350            else:
8351                a = b
8352
8353            return a, b
8354        test_all_outputs(test.graph)
8355
8356        @torch.jit.script
8357        def test2(x):
8358            # type: (bool) -> Tuple[int, int]
8359            b = 2
8360            a = 1
8361            if x:
8362                a = 1
8363                b = 2
8364                x = False
8365            if x:
8366                print(a)
8367            else:
8368                if x:
8369                    print(b)
8370
8371            return a, b
8372        test_all_outputs(test2.graph)
8373
8374        @torch.jit.script
8375        def test_loop(x, iter):
8376            # type: (bool, int) -> (None)
8377            a = 1
8378            b = 2
8379            c = 3
8380            for i in range(iter):
8381                a = 4
8382                b = 5
8383                c = 6
8384                x = True
8385            print(c)
8386            if x:
8387                print(a, b)
8388        test_all_outputs(test_loop.graph)
8389
8390        @torch.jit.script
8391        def loop_unused(iter):
8392            # type: (int) -> (None)
8393            a = 1
8394            b = 2
8395            c = 3
8396            for i in range(iter):
8397                c = c + 1
8398                b = b + 1
8399                a = a + 1
8400                print(a, b)
8401            print(c)
8402
8403        # c is used first; the remaining unused outputs are ordered alphabetically
8404        FileCheck().check(r"%c : int, %a : int, %b : int").run(loop_unused.graph)
8405
8406    def test_filecheck(self):
8407        def test_check():
8408            file = "232"
8409            FileCheck().check("2").check("3").check("2").run(file)
8410            FileCheck().check("232").run(file)
8411
8412            with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'):
8413                FileCheck().check("22").run(file)
8414            with self.assertRaisesRegex(RuntimeError, "CHECK: 3"):
8415                FileCheck().check("3").check("3").run(file)
8416
8417        test_check()
8418
8419        def test_check_count():
8420            file = "22222"
8421            FileCheck().check_count("2", 5).run(file)
8422            FileCheck().check_count("22", 2).run(file)
8423            FileCheck().check_count("222", 1).run(file)
8424
8425            with self.assertRaisesRegex(RuntimeError, 'Expected to not find'):
8426                FileCheck().check_count("2", 4, exactly=True).run(file)
8427
8428            with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'):
8429                FileCheck().check_count("22", 3).run(file)
8430
8431            with self.assertRaisesRegex(RuntimeError, "CHECK-COUNT-6: 2"):
8432                FileCheck().check_count("2", 6).run(file)
8433
8434        test_check_count()
8435
8436        def test_check_same():
8437            file = "22\n33"
8438            FileCheck().check_same("22").run(file)
8439
8440            with self.assertRaisesRegex(RuntimeError, "Expected to not find"):
8441                FileCheck().check_same("33").run(file)
8442
8443            file = "22  1  3"
8444
8445            FileCheck().check("2").check_same("3").run(file)
8446            FileCheck().check_count("2", 2).check_same("3").run(file)
8447
8448        test_check_same()
8449
8450        def test_check_next():
8451            file = "\n1\n2\n3"
8452            FileCheck().check("1").check_next("2").check_next("3").run(file)
8453            FileCheck().check_next("1").check_next("2").check_next("3").run(file)
8454
8455            with self.assertRaisesRegex(RuntimeError, "Expected to find"):
8456                FileCheck().check("1").check_next("2").run("12")
8457
8458            with self.assertRaisesRegex(RuntimeError, "Expected to not find"):
8459                FileCheck().check("1").check_next("2").run("1\n\n2")
8460
8461        test_check_next()
8462
8463        def test_check_dag():
8464            fc = FileCheck().check_dag("1").check_dag("2").check_not("2")
8465            fc.run("12")
8466            fc.run("21")
8467
8468            fc = FileCheck()
8469            fc.check_not("3").check_dag("1").check_dag("2").check_not("3")
8470            fc.run("1 3 2")
8471            fc.run("2 3 1")
8472
8473            fc = FileCheck().check_dag("1").check_dag("2").check("3")
8474            with self.assertRaisesRegex(RuntimeError, 'Expected to find "3" but did not find it'):
8475                fc.run("1 3 2")
8476
8477        test_check_dag()
8478
8479        def test_check_not():
8480            FileCheck().check_not("2").check("1").run("12")
8481            FileCheck().check("2").check_not("2").run("12")
8482
8483            with self.assertRaisesRegex(RuntimeError, 'Expected to not find "2"'):
8484                FileCheck().check_not("2").check("1").run("21")
8485
8486            with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'):
8487                FileCheck().check("2").check_not("1").run("21")
8488
8489            # checks with distinct range matches
8490            fb = FileCheck().check_count("2", 2).check_count("2", 2).check_not("2")
8491            with self.assertRaisesRegex(RuntimeError, 'Expected to not find "2"'):
8492                fb.run("22 2 22")
8493
8494            fb = FileCheck().check_count("2", 2).check_not("1").check_count("2", 2)
8495            with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'):
8496                fb.run("22 1 22")
8497
8498    def _dtype_to_jit_name(self, dtype):
8499        if dtype == torch.float32:
8500            return "Float"
8501        if dtype == torch.float64:
8502            return "Double"
8503        if dtype == torch.int64:
8504            return "Long"
8505        if dtype == torch.int32:
8506            return "Int"
8507        if dtype == torch.bool:
8508            return "Bool"
8509        raise RuntimeError('dtype not handled')
8510
8511    def _dtype_to_expect(self, dtype, dim=0):
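        # For example (a quick sketch of the format produced below):
        #   _dtype_to_expect(torch.float32, dim=2) -> "Float(*, *, device=cpu)"
        # A negative dim falls through to the lowercase wrapped-number form,
        # e.g. "float".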
8512        param = ', '.join(['*'] * dim + ['device=cpu'])
8513        param = '(' + param + ')'
8514        jit_type = self._dtype_to_jit_name(dtype)
8515        if dim >= 0:
8516            return jit_type + param
8517        # special case representing wrapped number
8518        else:
8519            return jit_type.lower()
8520
8521
8522    def _test_dtype_op_shape(self, ops, args, input_dims=1):
8523        if input_dims < 1:
8524            raise RuntimeError("input dims must be at least 1")
8525        dtypes = [torch.float32, torch.float64, torch.int64, torch.int32]
8526        str_args = ', '.join([str(arg) for arg in args]) + (', ' if len(args) else '')
8527        tensor_data = ('[' * input_dims) + '1, 2, 3' + (input_dims * ']')
8528        template = dedent('''
8529        def func():
8530            return {return_line}
8531        ''')
8532
8533        for op in ops:
8534            for dtype in (dtypes + [None]):
8535                for tensor_type in dtypes:
8536                    # a couple of ops aren't implemented for non-floating types
8537                    if not tensor_type.is_floating_point or (dtype is not None and not dtype.is_floating_point):
8538                        if op in ['mean', 'softmax', 'log_softmax']:
8539                            continue
8540                    return_line = f"torch.tensor({tensor_data}, dtype={tensor_type}).{op}({str_args}dtype={dtype})"
8541                    # uncomment for debugging a failed test:
8542                    # print("testing {}".format(return_line))
8543                    code = template.format(return_line=return_line)
8544                    scope = {}
8545                    exec(code, globals(), scope)
8546                    cu = torch.jit.CompilationUnit(code)
8547                    graph = cu.func.graph
8548                    torch._C._jit_pass_complete_shape_analysis(graph, (), False)
8549                    input_array = [1, 2, 3]
8550                    for _ in range(1, input_dims):
8551                        input_array = [input_array]
8552                    t = torch.tensor(input_array, dtype=tensor_type)
8553                    attr = getattr(t, op)
8554                    kwargs = {'dtype': dtype}
8555                    result = attr(*args, **kwargs)
8556                    expect = self._dtype_to_expect(result.dtype, result.dim())
8557                    FileCheck().check("aten::tensor").check(expect).run(graph)
8558
8559    def test_dtype_op_shape(self):
8560        ops = ['prod']
8561        self._test_dtype_op_shape(ops, args=[])
8562        self._test_dtype_op_shape(ops, args=[0, False])
8563        self._test_dtype_op_shape(ops, args=[0, False])
8564        self._test_dtype_op_shape(ops, args=[0, True])
8565
8566    def test_dtype_op_shape2(self):
8567        ops = ['cumprod', 'cumsum', 'softmax', 'log_softmax']
8568        self._test_dtype_op_shape(ops, args=[0])
8569
8570        self._test_dtype_op_shape(ops, args=[1], input_dims=4)
8571
8572
8573    def _test_binary_op_shape(self, ops, input_dims=1):
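        # scripts torch.<op>(arg1, arg2) over combinations of tensor and scalar
        # arguments and checks that complete shape analysis infers the same
        # dtype/dim as the eager result (scalars are narrowed to int/float)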
8574
8575        dtypes = [torch.float32, torch.float64, torch.int64, torch.int32, torch.bool]
8576
8577        if input_dims == 0:
8578            shape = '1'
8579        else:
8580            shape = '[' + ('1,' * 4) + ']'
8581            for _ in range(1, input_dims):
8582                shape = '[' + ",".join([shape] * 4) + ']'
8583
8584        template = dedent('''
8585        def func():
8586            arg1 = {}
8587            arg2 = {}
8588            return torch.{}(arg1, arg2)
8589        ''')
8590
8591        args = []
8592        for dtype in dtypes:
8593            args = args + [f"torch.tensor({shape}, dtype={dtype})"]
8594        args = args + [1, 1.5]
8595
8596        def isBool(arg):
8597            return type(arg) == bool or (type(arg) == str and "torch.bool" in arg)
8598
8599        for op in ops:
8600            for first_arg in args:
8601                for second_arg in args:
8602                    # sub and div are not supported for bool arguments
8603                    if (op == 'sub' or op == 'div') and (isBool(first_arg) or isBool(second_arg)):
8604                        continue
8605                    # div is not implemented correctly for mixed-type or int params
8606                    if (op == 'div' and (type(first_arg) != type(second_arg) or
8607                       isinstance(first_arg, int) or
8608                       (isinstance(first_arg, str) and 'int' in first_arg))):
8609                        continue
8610                    return_line = f"torch.{op}({first_arg}, {second_arg})"
8611                    # uncomment for debugging a failed test:
8612                    # print(f"testing {return_line}")
8613                    code = template.format(first_arg, second_arg, op)
8614                    scope = {}
8615                    exec(code, globals(), scope)
8616                    non_jit_result = scope['func']()
8617
8618                    cu = torch.jit.CompilationUnit(code)
8619                    graph = cu.func.graph
8620                    torch._C._jit_pass_complete_shape_analysis(graph, (), False)
8621                    # use dim=-1 to represent a python/jit scalar.
8622                    dim = -1 if type(first_arg) != str and type(second_arg) != str else non_jit_result.dim()
8623                    dtype = non_jit_result.dtype
8624                    # jit only supports int/float scalars.
8625                    if dim < 0:
8626                        if dtype == torch.int64:
8627                            dtype = torch.int32
8628                        if dtype == torch.float64:
8629                            dtype = torch.float32
8630                    expect = self._dtype_to_expect(dtype, dim)
8631                    jit_output = next(graph.outputs())
8632
8633                    check = FileCheck()
8634                    check.check(expect).run(str(jit_output))
8635
8636    def test_binary_op_shape(self):
8637        self._test_binary_op_shape(['mul', 'div', 'add', 'sub'], 0)
8638        self._test_binary_op_shape(['mul', 'div', 'add', 'sub'], 3)
8639
8640    def test_no_dtype_shape(self):
8641
8642        @torch.jit.script
8643        def foo(x):
8644            scalar_number = x.item()
8645            return x.add(scalar_number)
8646
8647        @torch.jit.script
8648        def foo2(x):
8649            scalar_number = x.item()
8650            return torch.tensor(1).add(scalar_number)
8651
8652        t = torch.tensor(5)
8653        g = foo.graph_for(t)
8654        type = next(g.outputs())
8655        self.assertTrue(type.type() == torch._C.TensorType.get())
8656        g2 = foo2.graph_for(t)
8657        type = next(g2.outputs())
8658        self.assertTrue(type.type() == torch._C.TensorType.get())
8659
8660
8661    def test_filecheck_parse(self):
8662        def test_check():
8663            file = """
8664                # CHECK: 2
8665                # CHECK: 3
8666                # CHECK: 2
8667                232
8668                """
8669            FileCheck().run(checks_file=file, test_file=file)
8670            file = """
8671                # CHECK: 232
8672                232
8673                """
8674            FileCheck().run(file, "232")
8675            with self.assertRaisesRegex(RuntimeError, 'Expected to find "232"'):
8676                FileCheck().run(file, "22")
8677            with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'):
8678                FileCheck().run("# CHECK: 22", "23")
8679        test_check()
8680
8681        def test_check_count():
8682            file = "22222"
8683            FileCheck().run("# CHECK-COUNT-5: 2", file)
8684            FileCheck().run("# CHECK-COUNT-EXACTLY-5: 2", file)
8685            FileCheck().run("# CHECK-COUNT-2: 22", file)
8686            FileCheck().run("# CHECK-COUNT-1: 222", file)
8687
8688            with self.assertRaisesRegex(RuntimeError, 'Expected to not find'):
8689                FileCheck().run("# CHECK-COUNT-EXACTLY-2: 2", file)
8690        test_check_count()
8691
8692        def test_check_same():
8693            file = "22\n33"
8694            FileCheck().run("# CHECK-SAME: 22", file)
8695
8696            with self.assertRaisesRegex(RuntimeError, "Expected to not find"):
8697                FileCheck().run("# CHECK-SAME: 33", file)
8698
8699            file = "22  1  3"
8700
8701            FileCheck().run("# CHECK: 2\n # CHECK-SAME: 3", file)
8702            FileCheck().run("# CHECK-COUNT-2: 2\n # CHECK-SAME: 3", file)
8703        test_check_same()
8704
8705        def test_bad_input():
8706            with self.assertRaisesRegex(RuntimeError, "Check for bad input"):
8707                FileCheck().run("", "1")
8708
8709            with self.assertRaisesRegex(RuntimeError, "Could not parse check"):
8710                FileCheck().run("# CHECK1", "")
8711
8712        test_bad_input()
8713
8714    def test_script_module_call_noscript(self):
8715        class M(torch.jit.ScriptModule):
8716            def __init__(self) -> None:
8717                super().__init__()
8718                self.value = 1
8719
8720            @torch.jit.ignore
8721            def foo(self):
8722                return torch.ones(2, 2) + self.value
8723
8724            @torch.jit.script_method
8725            def forward(self, input):
8726                return input + self.foo()
8727
8728        with torch.jit.optimized_execution(False):
8729            m = M()
8730            input = torch.randn(2, 2)
8731            o = m(input)
8732            self.assertEqual(o, input + torch.ones(2, 2) + 1)
8733            # check that we can change python attributes
8734            # and that those changes are picked up in script methods
8735            m.value = 2
8736            o = m(input)
8737            self.assertEqual(o, input + torch.ones(2, 2) + 2)
8738
8739    def test_script_module_nochange_submodule(self):
8740        class M(torch.jit.ScriptModule):
8741            def __init__(self) -> None:
8742                super().__init__()
8743                self.sub = nn.Linear(5, 5)
8744
8745            @torch.jit.script_method
8746            def forward(self, input):
8747                return self.sub(input)
8748        with torch.jit.optimized_execution(False):
8749            m = M()
8750            input = torch.randn(1, 5, 5)
8751            o = m(input)
8752            self.assertEqual(o, m.sub(input))
8753            with self.assertRaisesRegex(RuntimeError, "Cannot re-assign"):
8754                m.sub = nn.Linear(5, 5)
8755
8756    def test_module_apis(self):
8757        class Sub(torch.nn.Module):
8758            def forward(self, thing):
8759                return thing - 2
8760
8761        class Double(torch.nn.Module):
8762            def forward(self, thing):
8763                return thing * 2
8764
8765        class MyMod(torch.nn.Module):
8766            def __init__(self) -> None:
8767                super().__init__()
8768                self.mod = Sub()
8769                self.mod2 = Sub()
8770                self.mod3 = nn.Sequential(nn.Sequential(Sub()))
8771                self.mod4 = nn.Sequential(Sub(), Double())
8772
8773            @torch.jit.export
8774            def method(self, x, x1, y, y1):
8775                mod_names = ""
8776                for name, mod in self.named_modules():
8777                    mod_names = mod_names + " " + name
8778                    x = mod(x)
8779
8780                children_names = ""
8781                for name, mod in self.named_children():
8782                    children_names = children_names + " " + name
8783                    x1 = mod(x1)
8784
8785                for mod in self.modules():
8786                    y = mod(y)
8787
8788                for mod in self.children():
8789                    y1 = mod(y1)
8790
8791                return mod_names, children_names, x, x1, y, y1
8792
8793            def forward(self, x):
8794                return x + 2
8795
8796        mod = torch.jit.script(MyMod())
8797        inps = tuple([torch.tensor(i) for i in range(1, 5)])
8798        self.assertEqual(mod.method(*inps), MyMod().method(*inps))
8799
8800    def test_script_module_const(self):
8801        class M(torch.jit.ScriptModule):
8802
8803            __constants__ = ['b', 'i', 'c', 's']
8804
8805            def __init__(self) -> None:
8806                super().__init__()
8807                self.b = False
8808                self.i = 1
8809                self.c = 3.5
8810                self.s = ["hello"]
8811
8812            @torch.jit.script_method
8813            def forward(self):
8814                return self.b, self.i, self.c
8815
8816        with torch.jit.optimized_execution(False):
8817            m = M()
8818            o0, o1, o2 = m()
8819        self.assertEqual(o0, 0)
8820        self.assertEqual(o1, 1)
8821        self.assertEqual(o2, 3.5)
8822
8823    def test_script_module_fail_exist(self):
8824        class M(torch.jit.ScriptModule):
8825            @torch.jit.script_method
8826            def forward(self, x):
8827                return x + self.whatisgoingon
8828        with self.assertRaisesRegex(RuntimeError, "Module 'M' has no attribute"):
8829            M()
8830
8831    @unittest.skip("[module dedupe] currently NoneType refinement on optional attributes doesn't work.")
8832    def test_script_module_none_exist_fail(self):
8833        class M(torch.jit.ScriptModule):
8834            def __init__(self, my_optional):
8835                super().__init__()
8836                self.my_optional = my_optional
8837
8838            @torch.jit.script_method
8839            def forward(self, x):
8840                if self.my_optional is not None:
8841                    return torch.neg(x) + self.my_optional
8842                return torch.neg(x)
8843        with self.assertRaisesRegex(RuntimeError, "has no attribute 'my_optional'"):
8844            x = torch.rand(3, 4)
8845            fb = M(None)
8846            fb(x)
8847
8848    def test_script_module_invalid_consts(self):
8849        class Foo(torch.jit.ScriptModule):
8850            __constants__ = ['invalid']
8851
8852            def __init__(self) -> None:
8853                super().__init__()
8854                self.invalid = [nn.Linear(3, 4)]
8855
8856        with self.assertRaisesRegex(
8857                TypeError,
8858                "Linear' object in attribute 'Foo.invalid' is not a valid constant"):
8859            Foo()
8860
8861        class Foo2(torch.jit.ScriptModule):
8862            __constants__ = ['invalid']
8863
8864            def __init__(self) -> None:
8865                super().__init__()
8866                self.invalid = int
8867
8868        with self.assertRaisesRegex(TypeError, "not a valid constant"):
8869            Foo2()
8870
8871        class Foo3(torch.jit.ScriptModule):
8872            __constants__ = ['invalid']
8873
8874            def __init__(self) -> None:
8875                super().__init__()
8876                self.invalid = (3, 4, {})
8877
8878        with self.assertRaisesRegex(TypeError, "not a valid constant"):
8879            Foo3()
8880
8881        class Foo4(torch.jit.ScriptModule):
8882            __constants__ = ['invalid']
8883
8884            def __init__(self) -> None:
8885                super().__init__()
8886                self.invalid = np.int64(5)
8887
8888        # verify that the error message uses a human-understandable class name
8889        with self.assertRaisesRegex(TypeError, "numpy.int64"):
8890            Foo4()
8891
8892    def test_script_module_param_buffer_mutation(self):
8893        # TODO: add a param mutation test case after the JIT supports it
8894        class ModuleBufferMutate(torch.jit.ScriptModule):
8895            def __init__(self) -> None:
8896                super().__init__()
8897                self.running_var = nn.Buffer(torch.tensor(0, dtype=torch.long))
8898
8899            @torch.jit.script_method
8900            def forward(self):
8901                if self.training:
8902                    self.running_var += 1
8903                return self.running_var
8904
8905        with torch.jit.optimized_execution(False):
8906            m = ModuleBufferMutate()
8907            self.assertEqual(m(), 1)
8908            m.eval()
8909            self.assertEqual(m(), 1)
8910
8911    def test_script_module_for(self):
8912        class M(torch.jit.ScriptModule):
8913            __constants__ = ['b']
8914
8915            def __init__(self) -> None:
8916                super().__init__()
8917                self.b = [1, 2, 3, 4]
8918
8919            @torch.jit.script_method
8920            def forward(self):
8921                sum = 0
8922                for i in self.b:
8923                    sum += i
8924                return sum
8925
8926        with torch.jit.optimized_execution(False):
8927            m = M()
8928            self.assertEqual(m(), 10)
8929
8930    def test_override_magic(self):
8931        class OverrideMagic(nn.Module):
8932            @torch.jit.export
8933            def __len__(self):
8934                return 10
8935
8936        mod = OverrideMagic()
8937        self.assertEqual(len(mod), len(torch.jit.script(mod)))
8938
8939        class OverrideMagicSeq(nn.Sequential):
8940            @torch.jit.export
8941            def __len__(self):
8942                return 10
8943
8944        mod = OverrideMagicSeq()
8945        self.assertEqual(len(mod), len(torch.jit.script(mod)))
8946        self.assertTrue(torch.jit.script(mod))
8947
8948    def test_script_module_for2(self):
8949        class Sub(torch.jit.ScriptModule):
8950            def __init__(self) -> None:
8951                super().__init__()
8952                self.weight = nn.Parameter(torch.randn(2))
8953
8954            @torch.jit.script_method
8955            def forward(self, thing):
8956                return self.weight + thing
8957
8958        class M(torch.jit.ScriptModule):
8959            def __init__(self) -> None:
8960                super().__init__()
8961                self.mods = nn.ModuleList([Sub() for i in range(10)])
8962
8963            @torch.jit.script_method
8964            def forward(self, v):
8965                for m in self.mods:
8966                    v = m(v)
8967                return v
8968
8969        with torch.jit.optimized_execution(False):
8970            i = torch.empty(2)
8971            m = M()
8972            o = m(i)
8973            v = i
8974            for sub in m.mods:
8975                v = sub(v)
8976            self.assertEqual(o, v)
8977            with self.assertRaisesRegex(Exception, "object is not iterable"):
8978                print(list(m))
8979
8980    def test_attr_qscheme_script(self):
8981        class Foo(torch.nn.Module):
8982            def __init__(self) -> None:
8983                super().__init__()
8984                self.qscheme = torch.per_tensor_affine
8985
8986            def forward(self):
8987                if self.qscheme == torch.per_tensor_symmetric:
8988                    return 3
8989                else:
8990                    return 4
8991
8992        f = Foo()
8993        scripted = torch.jit.script(f)
8994        self.assertEqual(f(), scripted())
8995
8996    def test_script_module_const_submodule_fail(self):
8997        class Sub(torch.jit.ScriptModule):
8998            def __init__(self) -> None:
8999                super().__init__()
9000                self.weight = nn.Parameter(torch.randn(2))
9001
9002            @torch.jit.script_method
9003            def forward(self, thing):
9004                return self.weight + thing
9005
9006        class M(torch.jit.ScriptModule):
9007            def __init__(self) -> None:
9008                super().__init__()
9009                self.mods = [Sub() for _ in range(10)]
9010
9011            @torch.jit.script_method
9012            def forward(self):
9013                for _ in self.mods:
9014                    print(1)
9015                return 4
9016
9017        with self.assertRaisesRegex(RuntimeError, "has no attribute 'mods'"):
9018            M()
9019
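    # ScriptModule whose `derived` buffer is recomputed from `param`; the
    # _pack/_unpack hooks let serialization drop the derived state and rebuild
    # it afterwards, while the *_called buffers record that each hook ran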
9020    class DerivedStateModule(torch.jit.ScriptModule):
9021        def __init__(self) -> None:
9022            super(TestScript.DerivedStateModule, self).__init__()
9023            self.param = torch.nn.Parameter(torch.ones(3, 4, dtype=torch.float))
9024            self.derived = nn.Buffer(torch.neg(self.param).detach().clone())
9025
9026            # This is a flag so we can test that the pack method was called
9027            self.pack_called = nn.Buffer(torch.zeros(1, dtype=torch.long))
9028            # This is a flag so we can test that the unpack method was called
9029            self.unpack_called = nn.Buffer(torch.zeros(1, dtype=torch.long))
9030
9031        @torch.jit.script_method
9032        def _pack(self):
9033            self.pack_called.set_(torch.ones(1, dtype=torch.long))
9034            self.derived.set_(torch.rand(1).detach())
9035
9036        @torch.jit.script_method
9037        def _unpack(self):
9038            self.unpack_called.set_(torch.ones(1, dtype=torch.long))
9039            self.derived.set_(torch.neg(self.param).detach())
9040
9041        @torch.jit.script_method
9042        def forward(self, x):
9043            return x + self.derived
9044
9045    def test_pack_unpack_state(self):
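        # export/import with packing should call _pack before serialization and
        # _unpack afterwards, leaving both the original and the imported module
        # with a correctly rebuilt `derived` buffer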
9046        sm = TestScript.DerivedStateModule()
9047        x = torch.rand(3, 4)
9048        torch.testing.assert_close(sm(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float)))
9049
9050        # Test save path
9051        self.assertFalse(sm.pack_called.item())
9052        self.assertFalse(sm.unpack_called.item())
9053        imported = self.getExportImportCopyWithPacking(sm)
9054        # ensure pack was called before serialization
9055        self.assertTrue(sm.pack_called.item())
9056        # ensure unpack was called after serialization so as to leave the module in an initialized state
9057        self.assertTrue(sm.unpack_called.item())
9058
9059        torch.testing.assert_close(sm.derived, torch.neg(sm.param))
9060
9061        # Test load paths
9062        self.assertTrue(imported.unpack_called.item())
9063        torch.testing.assert_close(imported(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float)))
9064
9065    @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
9066    @unittest.skipIf(True, "Skipping while landing PR stack")
9067    def test_torch_functional(self):
9068        def stft(input, n_fft):
9069            # type: (Tensor, int) -> Tensor
9070            return torch.stft(input, n_fft, return_complex=True)
9071
9072        inps = (torch.randn(10), 7)
9073        self.assertEqual(stft(*inps), torch.jit.script(stft)(*inps))
9074
9075        def istft(input, n_fft):
9076            # type: (Tensor, int) -> Tensor
9077            return torch.istft(input, n_fft)
9078
9079        inps2 = (stft(*inps), inps[1])
9080        self.assertEqual(istft(*inps2), torch.jit.script(istft)(*inps2))
9081
9082        def lu_unpack(x):
9083            A_LU, pivots = torch.linalg.lu_factor(x)
9084            return torch.lu_unpack(A_LU, pivots)
9085
9086        for shape in ((3, 3), (5, 3, 3), (7, 3, 5, 5), (7, 5, 3, 3, 3)):
9087            a = torch.randn(*shape)
9088            self.checkScript(lu_unpack, (a,))
9089
9090        def cdist_fn():
9091            a = torch.tensor([[0.9041, 0.0196], [-0.3108, -2.4423], [-0.4821, 1.059]])
9092            b = torch.tensor([[-2.1763, -0.4713], [-0.6986, 1.3702]])
9093            return torch.cdist(a, b, compute_mode="use_mm_for_euclid_dist")
9094
9095        self.checkScript(cdist_fn, ())
9096
9097        def norm():
9098            c = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float)
9099            return torch.norm(c, p="fro"), torch.norm(c, p="nuc"), torch.norm(c), torch.norm(c, p=.5)
9100
9101        self.checkScript(norm, ())
9102
9103        def torch_unique(dim: Optional[int]):
9104            ten = torch.unique(torch.tensor([[1, 3], [2, 3]], dtype=torch.long))
9105            a = torch.unique(ten, dim=dim)
9106            b = torch.unique(ten, return_counts=True, dim=dim)
9107            c = torch.unique(ten, return_inverse=True, dim=dim)
9108            d = torch.unique(ten, return_counts=True, return_inverse=True, dim=dim)
9109            return a, b, c, d
9110
9111        self.checkScript(torch_unique, (None,))
9112        self.checkScript(torch_unique, (0,))
9113
9114        def torch_unique_consecutive(dim: Optional[int]):
9115            ten = torch.unique(torch.tensor([[1, 3], [3, 2], [3, 2], [2, 3]], dtype=torch.long))
9116            a = torch.unique_consecutive(ten, dim=dim)
9117            b = torch.unique_consecutive(ten, return_counts=True, dim=dim)
9118            c = torch.unique_consecutive(ten, return_inverse=True, dim=dim)
9119            d = torch.unique_consecutive(ten, return_counts=True, return_inverse=True, dim=dim)
9120            return a, b, c, d
9121
9122        self.checkScript(torch_unique_consecutive, (None,))
9123        self.checkScript(torch_unique_consecutive, (0,))
9124
9125    def test_torch_functional_tensordot_int(self):
9126        def tensordot_dims_int(a: torch.Tensor, b: torch.Tensor, dims: int):
9127            return torch.tensordot(a, b, dims=dims)
9128
9129        a = torch.arange(120.).reshape(2, 3, 4, 5)
9130        b = torch.arange(840.).reshape(4, 5, 6, 7)
9131        dims = 2
9132        self.checkScript(tensordot_dims_int, (a, b, dims))
9133
9134        for dims in [-1, 5]:
9135            try:
9136                tensordot_dims_int(a, b, dims)
9137            except RuntimeError as error:
9138                if dims < 0:
9139                    self.assertEqual(str(error), "tensordot expects dims >= 0, but got dims=" + str(dims))
9140                if dims > min(a.dim(), b.dim()):
9141                    self.assertEqual(str(error), "tensordot expects dims < ndim_a or ndim_b, but got dims=" + str(dims))
9142
9143    def test_torch_functional_tensordot_tensor(self):
9144        def tensordot_dims_tensor(a: torch.Tensor, b: torch.Tensor, dims: torch.Tensor):
9145            return torch.tensordot(a, b, dims=dims)
9146
9147        a = torch.arange(120.).reshape(2, 3, 4, 5)
9148        b = torch.arange(840.).reshape(4, 5, 6, 7)
9149        dims = torch.tensor([2])
9150        self.checkScript(tensordot_dims_tensor, (a, b, dims))
9151
9152        a = torch.arange(60.).reshape(3, 4, 5)
9153        b = torch.arange(24.).reshape(4, 3, 2)
9154        dims = torch.tensor([[1, 0], [0, 1]], dtype=torch.long)
9155        self.checkScript(tensordot_dims_tensor, (a, b, dims))
9156
9157    def test_torch_functional_tensordot_list(self):
9158        def tensordot_dims_list(a: torch.Tensor, b: torch.Tensor, dims: List[List[int]]):
9159            return torch.tensordot(a, b, dims=dims)
9160
9161        a = torch.arange(60.).reshape(3, 4, 5)
9162        b = torch.arange(24.).reshape(4, 3, 2)
9163        dims = [[1, 0], [0, 1]]
9164        self.checkScript(tensordot_dims_list, (a, b, dims))
9165
9166    def test_torch_functional_tensordot_tuple(self):
9167        def tensordot_dims_tuple(a: torch.Tensor, b: torch.Tensor, dims: Tuple[List[int], List[int]]):
9168            return torch.tensordot(a, b, dims=dims)
9169
9170        a = torch.arange(60.).reshape(3, 4, 5)
9171        b = torch.arange(24.).reshape(4, 3, 2)
9172        dims = ([1, 0], [0, 1])
9173        self.checkScript(tensordot_dims_tuple, (a, b, dims))
9174
9175    def test_missing_getstate(self):
9176        class Foo(torch.nn.Module):
9177            def __init__(self) -> None:
9178                super().__init__()
9179                self.x = 1
9180
9181            def forward(self, x):
9182                return x * self.x
9183
9184            @torch.jit.export
9185            def __setstate__(self, state):
9186                self.x = state[0]
9187                self.training = state[1]
9188
9189        with self.assertRaisesRegex(RuntimeError, "getstate"):
9190            scripted = torch.jit.script(Foo())
9191
9192    def test_inlining_cleanup(self):
9193        def foo(x):
9194            return F.linear(x, x)
9195
9196        @torch.jit.script
9197        def fee(x):
9198            return foo(x)
9199
9200        # inlining optimizations should have cleaned up the if statement inside linear
9201        self.run_pass("inline", fee.graph)
9202        FileCheck().check_not("prim::If").run(fee.graph)
9203
9204    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
9205    def test_pack_unpack_nested(self):
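        # each nested module contributes its buffer (1 + 2 + 3), so the unpacked
        # model maps zeros to sixes; after _pack the buffers are zeroed and the
        # output drops to zero until _unpack restores them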
9206        class SubSubMod(torch.jit.ScriptModule):
9207            def __init__(self) -> None:
9208                super().__init__()
9209                self.buf = nn.Buffer(torch.ones(3, 4) * 3)
9210
9211            @torch.jit.script_method
9212            def _pack(self):
9213                self.buf.set_(torch.zeros(1))
9214
9215            @torch.jit.script_method
9216            def _unpack(self):
9217                self.buf.set_(torch.ones(3, 4) * 3)
9218
9219            @torch.jit.script_method
9220            def forward(self, x):
9221                return x + self.buf
9222
9223        class SubMod(torch.jit.ScriptModule):
9224            def __init__(self) -> None:
9225                super().__init__()
9226                self.buf = nn.Buffer(torch.ones(3, 4) * 2)
9227                self.ssm = SubSubMod()
9228
9229            @torch.jit.script_method
9230            def _pack(self):
9231                self.buf.set_(torch.zeros(1))
9232
9233            @torch.jit.script_method
9234            def _unpack(self):
9235                self.buf.set_(torch.ones(3, 4) * 2)
9236
9237            @torch.jit.script_method
9238            def forward(self, x):
9239                return self.ssm(x + self.buf)
9240
9241        class Mod(torch.jit.ScriptModule):
9242            def __init__(self) -> None:
9243                super().__init__()
9244                self.submod = SubMod()
9245                self.buf = nn.Buffer(torch.ones(3, 4) * 1)
9246
9247            @torch.jit.script_method
9248            def _pack(self):
9249                self.buf.set_(torch.zeros(1))
9250
9251            @torch.jit.script_method
9252            def _unpack(self):
9253                self.buf.set_(torch.ones(3, 4))
9254
9255            @torch.jit.script_method
9256            def forward(self, x):
9257                return self.submod(x + self.buf)
9258
9259        m = Mod()
9260        torch.testing.assert_close(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6)
9261        m.apply(lambda s: s._pack())
9262        torch.testing.assert_close(m(torch.zeros(3, 4)), torch.zeros(3, 4))
9263        m.apply(lambda s: s._unpack())
9264        torch.testing.assert_close(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6)
9265
9266    def test_torch_any(self):
9267        def fn(x):
9268            return torch.any(x)
9269
9270        def fn1(x, dim: int):
9271            return torch.any(x, dim)
9272
9273        self.checkScript(fn, (torch.randn(3, 4), ))
9274        self.checkScript(fn, (torch.empty(3), ))
9275        self.checkScript(fn, (torch.empty(1), ))
9276        self.checkScript(fn, (torch.ones(3, 4),))
9277        self.checkScript(fn, (torch.zeros(5, 7, 1),))
9278        self.checkScript(fn1, (torch.empty(3, 4), -2))
9279        self.checkScript(fn1, (torch.randn(3, 8), 1))
9280        self.checkScript(fn1, (torch.zeros(3, 6, 9), -3))
9281        self.checkScript(fn1, (torch.empty(5), 0))
9282
9283    def test_any(self):
9284        def fn(x: List[int]):
9285            return any(x)
9286
9287        def fn1(x: List[float]):
9288            return any(x)
9289
9290        def fn2(x: List[bool]):
9291            return any(x)
9292
9293        def fn3(x: List[str]):
9294            return any(x)
9295
9296        self.checkScript(fn, ([0, 0, 0, 0], ))
9297        self.checkScript(fn, ([0, 3, 0], ))
9298        self.checkScript(fn, ([], ))
9299        self.checkScript(fn1, ([1.0, 2.0, 3.0], ))
9300        self.checkScript(fn1, ([0.0, 0.0, 0.0], ))
9301        self.checkScript(fn1, ([0, 0, 0], ))
9302        self.checkScript(fn1, ([], ))
9303        self.checkScript(fn2, ([True, False, False], ))
9304        self.checkScript(fn2, ([False, False, False], ))
9305        self.checkScript(fn2, ([True, True, True, True], ))
9306        self.checkScript(fn2, ([], ))
9307        self.checkScript(fn3, (["", "", ""], ))
9308        self.checkScript(fn3, (["", "", "", "-1"], ))
9309        self.checkScript(fn3, ([], ))
9310
9311    def test_script_module_not_tuple(self):
9312        class M(torch.jit.ScriptModule):
9313            __constants__ = ['mods']
9314
9315            def __init__(self) -> None:
9316                super().__init__()
9317                self.mods = 1
9318
9319            @torch.jit.script_method
9320            def forward(self, v):
9321                for m in self.mods:
9322                    print(m)
9323                return v
9324        with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"):
9325            M()
9326
9327    def test_attr_module_constants(self):
9328        class M2(torch.jit.ScriptModule):
9329            def __init__(self, mod_list):
9330                super().__init__()
9331                self.mods = mod_list
9332
9333            @torch.jit.script_method
9334            def forward(self, x):
9335                return self.mods.forward(x)
9336
9337        with torch.jit.optimized_execution(False):
9338            m = M2(nn.Sequential(nn.ReLU()))
9339            self.assertExportImportModule(m, (torch.randn(2, 2),))
9340
9341    def test_script_sequential_for(self):
9342        class Sub(torch.jit.ScriptModule):
9343            def __init__(self) -> None:
9344                super().__init__()
9345                self.weight = nn.Parameter(torch.randn(2))
9346
9347            @torch.jit.script_method
9348            def forward(self, thing):
9349                return self.weight + thing
9350
9351        class M(torch.jit.ScriptModule):
9352            def __init__(self) -> None:
9353                super().__init__()
9354                self.mods = nn.Sequential(Sub(), Sub(), Sub())
9355
9356            @torch.jit.script_method
9357            def forward(self, v):
9358                for m in self.mods:
9359                    v = m(v)
9360                return v
9361
9362            @torch.jit.script_method
9363            def forward2(self, v):
9364                return self.mods(v)
9365
9366        with torch.jit.optimized_execution(False):
9367            i = torch.empty(2)
9368            m = M()
9369            o = m(i)
9370            v = i
9371            for sub in m.mods._modules.values():
9372                v = sub(v)
9373            self.assertEqual(o, v)
9374
9375            o2 = m.forward2(i)
9376            self.assertEqual(o2, v)
9377
9378    def test_script_sequential_sliced_iteration(self):
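        # iterating over slices of an nn.Sequential (layers[1:3] and layers[2:])
        # should compile and match eager execution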
9379        class seq_mod(nn.Module):
9380            def __init__(self) -> None:
9381                super().__init__()
9382                self.layers = [nn.ReLU(), nn.ReLU(), nn.ReLU()]
9383                self.layers = nn.Sequential(*self.layers)
9384
9385            def forward(self, input):
9386                x = self.layers[0].forward(input)
9387                for layer in self.layers[1:3]:
9388                    x = layer.forward(x)
9389                for layer in self.layers[2:]:
9390                    x = layer.forward(x)
9391                return x
9392
9393        seq = seq_mod()
9394        self.checkModule(seq, [torch.tensor([-2, 1, -1, 2])])
9395
9396    def test_script_sequential_orderdict(self):
9397        class M(torch.jit.ScriptModule):
9398            def __init__(self) -> None:
9399                super().__init__()
9400                self.mods = nn.Sequential(OrderedDict([
9401                    ("conv", nn.Conv2d(1, 20, 5)),
9402                    ("relu", nn.ReLU())
9403                ]))
9404
9405            @torch.jit.script_method
9406            def forward(self, input):
9407                return self.mods(input)
9408
9409        m = M()
9410        self.assertTrue('mods.conv.weight' in m.state_dict().keys())
9411
9412    def test_script_sequential_multi_output_fail(self):
9413        class Sub(torch.jit.ScriptModule):
9414            def __init__(self) -> None:
9415                super().__init__()
9416                self.weight = nn.Parameter(torch.randn(2))
9417
9418            @torch.jit.script_method
9419            def forward(self, thing):
9420                return self.weight + thing
9421
9422        class ReturnMulti(torch.jit.ScriptModule):
9423            @torch.jit.script_method
9424            def forward(self, x):
9425                return x, x, x
9426
9427        class HaveSequential(torch.jit.ScriptModule):
9428            def __init__(self) -> None:
9429                super().__init__()
9430                self.someseq = nn.Sequential(
9431                    Sub(),
9432                    ReturnMulti(),
9433                    Sub()
9434                )
9435
9436            @torch.jit.script_method
9437            def forward(self, x):
9438                return self.someseq(x)
9439
9440        with self.assertRaisesRegex(RuntimeError, "(Tensor, Tensor, Tensor)"):
9441            with torch.jit.optimized_execution(False):
9442                hs = HaveSequential()
9443                i = torch.empty(2)
9444                hs(i)
9445
9446    @_tmp_donotuse_dont_inline_everything
9447    def test_script_sequential_in_mod_list(self):
9448        class Sub(torch.jit.ScriptModule):
9449            def __init__(self) -> None:
9450                super().__init__()
9451                self.weight = nn.Parameter(torch.randn(2))
9452
9453            @torch.jit.script_method
9454            def forward(self, thing):
9455                return self.weight + thing
9456
9457        class M(torch.jit.ScriptModule):
9458            def __init__(self) -> None:
9459                super().__init__()
9460                self.mods = nn.ModuleList([Sub(), nn.Sequential(Sub(), nn.Sequential(Sub(), Sub()), Sub())])
9461
9462            @torch.jit.script_method
9463            def forward(self, v):
9464                for mod in self.mods:
9465                    v = mod(v)
9466                return v
9467
9468        m = M()
9469        graph = str(m.graph)
9470        self.assertTrue(graph.count("prim::CallMethod") == 2)
9471        self.assertTrue("python" not in graph)
9472
9473    @_tmp_donotuse_dont_inline_everything
9474    def test_script_nested_mod_list(self):
9475        class Sub(torch.jit.ScriptModule):
9476            def __init__(self) -> None:
9477                super().__init__()
9478                self.weight = nn.Parameter(torch.randn(2))
9479
9480            @torch.jit.script_method
9481            def forward(self, thing):
9482                return self.weight + thing
9483
9484        class M(torch.jit.ScriptModule):
9485            def __init__(self) -> None:
9486                super().__init__()
9487                self.mods = nn.ModuleList([nn.ModuleList([Sub()]), nn.Sequential(Sub()), nn.ModuleList([Sub(), Sub()])])
9488
9489            @torch.jit.script_method
9490            def forward(self, v):
9491                for mod in self.mods:
9492                    for m in mod:
9493                        v = m(v)
9494                return v
9495
9496        m = M()
9497        graph = str(m.graph)
9498        self.assertTrue(graph.count("prim::CallMethod") == 4)
9499        self.assertTrue("python" not in graph)
9500
9501    def test_constant_as_attr(self):
9502        class M(torch.jit.ScriptModule):
9503            __constants__ = ['dim']
9504
9505            def __init__(self) -> None:
9506                super().__init__()
9507                self.dim = 1
9508
9509            @torch.jit.script_method
9510            def forward(self, v):
9511                return torch.cat([v, v, v], dim=self.dim)
9512        v = torch.zeros(1, 1)
9513        with torch.jit.optimized_execution(False):
9514            self.assertEqual(torch.cat([v, v, v], dim=1), M()(v))
9515
9516    class StarTestSumStarred(torch.nn.Module):
9517        def __init__(self) -> None:
9518            super(TestScript.StarTestSumStarred, self).__init__()
9519
9520        def forward(self, *inputs):
9521            output = inputs[0]
9522            for i in range(1, len(inputs)):
9523                output += inputs[i]
9524            return output
9525
9526    class StarTestReturnThree(torch.nn.Module):
9527        def __init__(self) -> None:
9528            super(TestScript.StarTestReturnThree, self).__init__()
9529
9530        def forward(self, rep):
9531            return rep, rep, rep
9532
9533    def test_script_star_expr(self):
9534
9535        class M2(torch.jit.ScriptModule):
9536            def __init__(self) -> None:
9537                super().__init__()
9538                self.m = torch.jit.trace(TestScript.StarTestSumStarred(),
9539                                         (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)))
9540                self.g = torch.jit.trace(TestScript.StarTestReturnThree(), torch.ones(4, 3))
9541
9542            @torch.jit.script_method
9543            def forward(self, rep):
9544                tup = self.g(rep)
9545                return self.m(*tup)
9546
9547        m = M2()
9548        self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
9549
9550    def test_script_star_expr_string(self):
9551        class M2(torch.jit.ScriptModule):
9552            def __init__(self) -> None:
9553                super().__init__()
9554                self.m = torch.jit.trace(TestScript.StarTestSumStarred(),
9555                                         (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)))
9556                self.g = torch.jit.trace(TestScript.StarTestReturnThree(), torch.ones(4, 3))
9557
9558                self.define('''
9559            def forward(self, rep):
9560                tup = self.g(rep)
9561                return self.m(*tup)
9562                ''')
9563
9564        m = M2()
9565        self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
9566
9567    class StarTestSumAndReturnThree(torch.nn.Module):
9568        def __init__(self) -> None:
9569            super(TestScript.StarTestSumAndReturnThree, self).__init__()
9570
9571        def forward(self, *inputs):
9572            output = inputs[0]
9573            for i in range(1, len(inputs)):
9574                output += inputs[i]
9575            return output, output, output
9576
9577    def test_script_star_assign(self):
9578        class M2(torch.jit.ScriptModule):
9579            def __init__(self) -> None:
9580                super().__init__()
9581                self.g = torch.jit.trace(TestScript.StarTestSumAndReturnThree(), torch.ones(4, 3))
9582                self.define('''
9583            def forward(self, rep):
9584                head, *tail = self.g(rep)
9585                return head
9586                ''')
9587
9588        m = M2()
9589        self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
9590
9591    def test_script_module_star_assign2(self):
9592        class M2(torch.jit.ScriptModule):
9593            def __init__(self) -> None:
9594                super().__init__()
9595                self.g = torch.jit.trace(
9596                    TestScript.StarTestSumAndReturnThree(),
9597                    (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)),
9598                    _force_outplace=True)
9599                self.define('''
9600            def forward(self, rep):
9601                *head, tail = self.g(rep, rep, rep)
9602                return tail
9603                ''')
9604
9605        m = M2()
9606        self.assertEqual(m(torch.ones(4, 3)), 3 * torch.ones(4, 3))
9607
9608    def test_script_module_star_assign2_inplace(self):
9609        class M2(torch.jit.ScriptModule):
9610            def __init__(self) -> None:
9611                super().__init__()
9612                self.g = torch.jit.trace(
9613                    TestScript.StarTestSumAndReturnThree(),
9614                    (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)),
9615                    _force_outplace=False)
9616                self.define('''
9617            def forward(self, rep):
9618                *head, tail = self.g(rep, rep, rep)
9619                return tail
9620                ''')
9621
9622        m = M2()
9623        # since forward() passes three aliases of the input `rep` to
9624        # StarTestSumAndReturnThree(), the in-place behavior differs from the
9625        # out-of-place case above.
9626        self.assertEqual(m(torch.ones(4, 3)), 4 * torch.ones(4, 3))
9627
9628    def test_script_module_star_assign_fail_pythonop(self):
9629
9630        with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
9631            class M2(torch.jit.ScriptModule):
9632                def __init__(self) -> None:
9633                    super().__init__()
9634
9635                    @torch.jit.ignore
9636                    def myfunc():
9637                        return torch.zeros(1, 2, 3), torch.zeros(1, 2, 3)
9638
9639                    self.define('''
9640                def forward(self, rep):
9641                    a, *b = myfunc()
9642                    return a
9643                    ''')
9644
9645            m = M2()
9646            m(torch.zeros(4, 3))
9647
9648    def test_script_module_star_assign_fail_builtin(self):
9649        with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
9650            class M2(torch.jit.ScriptModule):
9651                def __init__(self) -> None:
9652                    super().__init__()
9653
9654                    self.define('''
9655                def forward(self, rep):
9656                    a, *b = torch.neg(rep)
9657                    return a
9658                    ''')
9659
9660            m = M2()
9661            m(torch.zeros(4, 3))
9662
9663    def test_script_pack_padded_sequence(self):
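        # round-trips pack_padded_sequence -> pad_packed_sequence and checks that
        # the scripted function and module match eager execution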
9664        from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
9665
9666        def pack_padded_pad_packed_script(x, seq_lens):
9667            x = pack_padded_sequence(x, seq_lens)
9668            x, lengths = pad_packed_sequence(x)
9669            return x, lengths
9670
9671        T, B, C = 3, 5, 7
9672        x = torch.ones((T, B, C))
9673        seq_lens = torch.tensor([3, 3, 2, 2, 1])
9674        # set padding value so we can test equivalence
9675        for b in range(B):
9676            if seq_lens[b] < T:
9677                x[seq_lens[b]:, b, :] = 0
9678
9679        eager_seq, eager_lengths = pack_padded_pad_packed_script(x, seq_lens)
9680        with torch._jit_internal._disable_emit_hooks():
9681            scripted_pack_padded_seq = torch.jit.script(pack_padded_pad_packed_script)
9682        script_seq, script_lengths = scripted_pack_padded_seq(x, seq_lens)
9683        self.assertEqual(eager_seq, script_seq)
9684        self.assertEqual(eager_lengths, script_lengths)
9685
9686        class ExperimentalLSTM(torch.nn.Module):
9687            def __init__(self, input_dim, hidden_dim):
9688                super().__init__()
9689
9690            def forward(self, input):
9691                # type: (Tensor)
9692                packed = pack_padded_sequence(
9693                    input=input, lengths=torch.tensor([1, 2]), enforce_sorted=False
9694                )
9695                output, lengths = pad_packed_sequence(
9696                    sequence=packed, total_length=2
9697                )
9698                # lengths is flipped, so is output
9699                return output[0]
9700
9701        lstm = ExperimentalLSTM(input_dim=2, hidden_dim=2)
9702
9703        with torch._jit_internal._disable_emit_hooks():
9704            self.checkModule(lstm, [torch.ones(2, 2)])
9705
9706    def test_script_pad_sequence_pack_sequence(self):
9707        from torch.nn.utils.rnn import pad_sequence, pack_sequence, pad_packed_sequence
9708
9709        def pad_sequence_func(tensor_list, batch_first=False, padding_value=0.0, padding_side="right"):
9710            # type: (List[Tensor], bool, float, str) -> Tensor
9711            return pad_sequence(tensor_list, batch_first, padding_value, padding_side)
9712
9713        def pack_sequence_func(tensor_list, enforce_sorted=True):
9714            # type: (List[Tensor], bool) -> Tensor
9715            return pad_packed_sequence(pack_sequence(tensor_list, enforce_sorted))[0]
9716
9717        ones3 = torch.ones(3, 5)
9718        ones4 = torch.ones(4, 5)
9719        ones5 = torch.ones(5, 5)
9720        tensor1 = torch.tensor([1, 2, 3])
9721        tensor2 = torch.tensor([4, 5])
9722        tensor3 = torch.tensor([6])
9723        with torch._jit_internal._disable_emit_hooks():
9724            self.checkScript(pad_sequence_func,
9725                             ([ones3, ones4, ones5],))
9726            self.checkScript(pad_sequence_func,
9727                             ([ones3, ones4, ones5], True))
9728            self.checkScript(pad_sequence_func,
9729                             ([ones3, ones4, ones5], True, 2.5))
9730            self.checkScript(pad_sequence_func,
9731                             ([ones3, ones4, ones5], True, 2.5, "left"))
9732            self.checkScript(pad_sequence_func,
9733                             ([ones3, ones4, ones5], False, 2.5, "left"))
9734            self.checkScript(pack_sequence_func,
9735                             ([tensor1, tensor2, tensor3],))
9736            self.checkScript(pack_sequence_func,
9737                             ([tensor1, tensor2, tensor3], False))
9738
9739    def test_script_get_tracing_state(self):
9740        def test_if_tracing(x):
9741            if torch._C._get_tracing_state():
9742                return x + 1
9743            else:
9744                return x - 1
9745
9746        inp = torch.randn(3, 3)
9747        self.checkScript(test_if_tracing, (inp,))
9748
9749    def test_script_is_tracing(self):
9750        def test_is_tracing(x):
9751            if torch.jit.is_tracing():
9752                return x + 1
9753            else:
9754                return x - 1
9755
9756        inp = torch.randn(3, 3)
9757        self.checkScript(test_is_tracing, (inp,))
9758
9759    def test_is_scripting(self):
9760        def foo():
9761            return torch.jit.is_scripting()
9762
9763        self.assertFalse(foo())
9764        scripted = torch.jit.script(foo)
9765        self.assertTrue(scripted())
9766
9767    def test_comment_ignore_indent(self):
9768        class Model(torch.nn.Module):
9769            def __init__(self) -> None:
9770    # useless comment that is not indented correctly  # noqa: E115
9771                super().__init__()
9772
9773            def forward(self):
9774                return 5
9775
9776        # should compile without an error
9777        self.checkModule(Model(), ())
9778
9779    def test_script_outputs(self):
9780        with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
9781            @torch.jit.script
9782            def foo(a):
9783                c, d = a + a
9784                return c + d
9785
9786        @torch.jit.script
9787        def return3():
9788            return 1, 2, 3
9789
9790        with self.assertRaisesRegex(RuntimeError, "too many values to unpack"):
9791            @torch.jit.script
9792            def bind2():
9793                a, b = return3()
9794                print(a)
9795                print(b)
9796
9797    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
9798    def test_script_get_device_cuda(self):
9799        @torch.jit.script
9800        def foo(a):
9801            return a.get_device()
9802
9803        v = torch.randn(1, device='cuda')
9804        self.assertEqual(foo(v), 0)
9805
9806    def test_script_chunk(self):
9807        @torch.jit.script
9808        def foo(a):
9809            b, c = torch.chunk(a, dim=0, chunks=2)
9810            return b
9811        v = torch.rand(10, 3)
9812        self.assertEqual(torch.chunk(v, dim=0, chunks=2)[0], foo(v))
9813
9814    def test_script_copy(self):
9815        class M(torch.nn.Module):
9816            __annotations__ = {
9817                "val": Optional[torch.Tensor]
9818            }
9819
9820            def __init__(self) -> None:
9821                super().__init__()
9822                self.val = None
9823
9824            def some_method(self):
9825                return 3
9826
9827            def forward(self, x):
9828                # type: (Tensor) -> Tensor
9829                self.val = x + self.some_method()
9830                return x
9831
9832        m = torch.jit.script(M())
9833        # test copy
9834        copy.copy(m)
9835        copy.deepcopy(m)
9836
9837    def test_script_forward_method_replacement(self):
9838        # We want to support the use case of attaching a different `forward` method
9839        class LowLevelModule(torch.nn.Module):
9840            def forward(self, input: torch.Tensor):
9841                # Generic forward dispatch
9842                return self.forward_pytorch(input) * 2
9843
9844        class TestModule(LowLevelModule):
9845            def __init__(self) -> None:
9846                super().__init__()
9847                # Replace the forward method
9848                self.forward = types.MethodType(LowLevelModule.forward, self)
9849
9850            def forward_pytorch(self, input: torch.Tensor):
9851                return torch.tensor(123)
9852
9853            def forward(self, input: torch.Tensor):
9854                # Should not use this forward method
9855                raise AssertionError("This method should not be used")
9856                return self.forward_pytorch(input)
9857
9858        m = TestModule()
9859        self.assertEqual(m(torch.tensor(1)), torch.tensor(246))
9860
9861        m_scripted = torch.jit.script(m)
9862        self.assertEqual(m_scripted(torch.tensor(1)), torch.tensor(246))
9863
9864    def test_python_call_non_tensor(self):
9865        def foo(a, b, c):
9866            # type: (Tensor, int, Tuple[Tensor, int]) -> Tuple[int, Tensor]
9867            d, e = c
9868            return b + e, a + d
9869
9870        @torch.jit.script
9871        def bar():
9872            x = torch.ones(3, 4)
9873            a, b = foo(x, 3, (x, 3))
9874            return a, b
9875
9876        self.assertEqual((6, torch.ones(3, 4) + 1), bar())
9877
9878    def test_python_call_non_tensor_wrong(self):
9879        with self.assertRaisesRegex(RuntimeError, r"but instead got value of type tuple"):
9880            @torch.jit.ignore
9881            def foo():
9882                # type: () -> Tensor
9883                return ((3, 4),)  # noqa: T484
9884
9885            @torch.jit.script
9886            def bar():
9887                return foo()
9888
9889            bar()
9890
9891    def test_if_different_type(self):
9892        with self.assertRaisesRegex(RuntimeError, "c0 is set to type "
9893                                    "int in the true branch and type "
9894                                    "float in the false branch"):
9895            @torch.jit.script
9896            def diff_type_used():
9897                if 1 == 2:
9898                    c0 = 1
9899                else:
9900                    c0 = 1.0
9901                return c0
9902
9903        with self.assertRaisesRegex(RuntimeError, "Variable 'c0' previously had type float"):
9904            @torch.jit.script
9905            def diff_existing_type(x):
9906                c0 = 1.0
9907                if 1 == 2:
9908                    c0 = 1
9909                    print(x)
9910                return x
9911
9912        @torch.jit.script
9913        def diff_type_unused():
9914            if 1 == 1:
9915                c0 = 1
9916                print(c0)
9917            else:
9918                c0 = 1.0
9919                print(c0)
9920            return 1
9921
9922    def test_if_not_defined_error(self):
9923        with self.assertRaisesRegex(RuntimeError, "c0 is not defined in the false branch"):
9924            @torch.jit.script
9925            def test():
9926                if 1 == 1:
9927                    c0 = 1
9928                return c0
9929        with self.assertRaisesRegex(RuntimeError, "c0 is not defined in the true branch"):
9930            @torch.jit.script
9931            def test2():
9932                if 1 == 1:
9933                    pass
9934                else:
9935                    c0 = 1
9936                return c0
9937
9938    def test_if_list_cat(self):
9939        # testing that lists of different lengths don't throw an error on cat during shape propagation
9940        @torch.jit.script
9941        def test_list(x):
9942            if bool(x.sum() < 1):
9943                c = [x, x]
9944            else:
9945                c = [x, x, x]
9946            return torch.cat(c)
9947
9948        b = torch.zeros(2, 4)
9949        _propagate_shapes(test_list.graph, (b,), False)
9950
9951    def test_if_supertype(self):
9952        @torch.jit.script
9953        def tensor_unifying(x, y, z):
9954            # testing dynamic is appropriately set for y and z
9955            if bool(x):
9956                x, y, z = x + 1, y, z
9957            else:
9958                x, y, z = x + 1, x, y
9959
9960            return x, y, z
9961
9962        a = torch.zeros(2, 2, dtype=torch.float)
9963        b = torch.zeros(2, 4, dtype=torch.long)
9964        c = torch.zeros(2, 4, dtype=torch.float)
9965
9966        graph = _propagate_shapes(tensor_unifying.graph, (a, b, c), False)
9967        if_outputs = list(graph.findNode("prim::If").outputs())
9968        self.assertTrue(if_outputs[0].type().str() == "Float(*, *, requires_grad=0, device=cpu)")
9969        self.assertTrue(if_outputs[1].type().str() == "Tensor(*, *, requires_grad=0, device=cpu)")
9970        self.assertTrue(if_outputs[2].type().str() == "Tensor(*, *, requires_grad=0, device=cpu)")
9971
9972    def test_list_unify(self):
9973        # allowing a unified int?[] would cause a runtime error because
9974        # the index operation expects int?[] to be a generic list,
9975        # but in the true branch the IValue will be an int list
9976        with self.assertRaisesRegex(RuntimeError, "int[] in the true branch and type None[]"):
9977            @torch.jit.script
9978            def list_optional_fails(x):
9979                # type: (bool) -> Optional[int]
9980                if x:
9981                    y = [1]
9982                else:
9983                    y = [None]  # noqa: T484
9984                return y[0]
9985
9986        @torch.jit.script
9987        def list_tensors(x):
9988            # type: (bool) -> Tuple[Tensor, List[Tensor]]
9989            if x:
9990                a = torch.zeros([1, 1])
9991                y = [a]
9992            else:
9993                a = torch.zeros([1, 2])
9994                y = [a]
9995            return a, y
9996
9997        self.run_pass('constant_propagation', list_tensors.graph)
9998        m = self.createFunctionFromGraph(list_tensors.graph)
        # testing that the tensor type of the lists is unified
10000        self.getExportImportCopy(m)
10001
10002    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
10003    @_inline_everything
10004    def test_import_constants_not_specialized(self):
10005        class Mod(torch.nn.Module):
10006            def forward(self, x):
10007                return torch.cat(2 * [x], dim=0)
10008
10009        class ScriptMod(torch.jit.ScriptModule):
10010            def __init__(self, mod):
10011                super().__init__()
10012                x = torch.zeros(1, 3)
10013                mod_fn = lambda : mod(x)  # noqa: E731
10014                self.mod = torch.jit.trace(mod_fn, ())
10015
10016            @torch.jit.script_method
10017            def forward(self):
10018                return self.mod()
10019
10020        cm = ScriptMod(Mod())
10021        # specialized tensor in graph
10022        FileCheck().check("Float(1, 3, strides=[3, 1], requires_grad=0, device=cpu)").run(cm.forward.graph)
10023        buffer = io.BytesIO()
10024        torch.jit.save(cm, buffer)
10025        buffer.seek(0)
        # when the tensor is loaded as a constant it isn't specialized
10027        cm_load = torch.jit.load(buffer)
10028        FileCheck().check_not("Float(1, 3)").run(cm_load.forward.graph)
10029
10030    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
10031    def test_type_annotations_repeated_list(self):
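        # BroadcastingListN[T] annotations accept either a list of N elements or
        # a single value of type T that is broadcast to a list of length N
        # (similar in spirit to _single/_pair/_triple from torch.nn.modules.utils)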
10032        @torch.jit.script
10033        def float_fn(x, y):
10034            # type: (float, BroadcastingList3[float]) -> List[float]
10035            return y
10036        self.assertEqual(float_fn(2.0, 1.0), float_fn(2.0, [1.0, 1.0, 1.0]))
10037        self.assertEqual(float_fn(2.0, 1.0), float_fn(2.0, (1.0, 1.0, 1.0)))
10038
10039        @torch.jit.script
10040        def float_fn_call():
10041            print(float_fn(1.0, 1.0))
10042            print(float_fn(1.0, (1.0, 1.0, 1.0)))
10043
10044        @torch.jit.script
10045        def int_fn(x):
10046            # type: (BroadcastingList3[int]) -> List[int]
10047            return x
10048        self.assertEqual(int_fn(1), int_fn([1, 1, 1]))
10049        self.assertEqual(int_fn(1), int_fn((1, 1, 1)))
10050
10051        @torch.jit.script
10052        def int_fn_call():
10053            print(int_fn(1))
10054            print(int_fn((1, 1, 1)))
10055
10056        with self.assertRaisesRegex(RuntimeError, "must be a positive integer:"):
10057            @torch.jit.script  # noqa: T484
10058            def fn(x):
10059                # type: (BroadcastingListx[int]) -> List[int]  # noqa: T484
10060                return x
10061
        # using a CompilationUnit so that the flake8 error on int[2] is not raised (noqa not working)
10063        with self.assertRaisesRegex(RuntimeError, "Unknown type constructor"):
10064            cu = torch.jit.CompilationUnit('''
10065                def nested(x, y):
10066                    # type: (int, Tuple[int, int[2]]) -> List[int]
10067                    return x  # noqa: T484
10068            ''')
10069
10070        @torch.jit.script
10071        def f(x: BroadcastingList2[int]):
10072            return x
10073
10074        out = f(1)
10075        self.assertTrue(isinstance(out[0], int))
10076        self.assertEqual(out, [1, 1])
10077
10078    def test_ntuple_builtins(self):
10079        from torch.nn.modules.utils import _single, _pair, _triple, _quadruple
10080
10081        def test_ints():
10082            return _single(1), _pair(2), _triple(3), _quadruple(4)
10083
10084        def test_floats():
10085            return _single(1), _pair(2.1), _triple(3.1), _quadruple(4.1)
10086
10087        self.checkScript(test_ints, ())
10088        self.checkScript(test_floats, ())
10089
10090    def test_embedding_renorm_grad_error(self):
        # Testing that the builtin call to embedding_renorm_ correctly throws an
        # error when .backward() is called on its input
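        # (F.embedding with max_norm renormalizes the weight in place, so a graph
        # that already captured the original weight values sees it as modified)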
10093
10094        def embedding_norm(input, embedding_matrix, max_norm):
10095            F.embedding(input, embedding_matrix, max_norm=0.01)
10096
10097        @torch.jit.script
10098        def embedding_norm_script(input, embedding_matrix, max_norm):
10099            # type: (Tensor, Tensor, float) -> None
10100            F.embedding(input, embedding_matrix, max_norm=0.01)
10101
10102        for _ in [embedding_norm, embedding_norm_script]:
10103            input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
10104            embedding_matrix = torch.randn(10, 3)
10105
10106            var1 = torch.randn(10, 3, requires_grad=True)
10107            var2 = var1.detach().requires_grad_()
10108            output1 = var1 * embedding_matrix
10109            output2 = var2 * embedding_matrix
10110
10111            output1.sum().backward()
10112
10113            ignore = F.embedding(input, embedding_matrix, max_norm=0.01)
10114            with self.assertRaisesRegex(RuntimeError, "modified"):
10115                output2.sum().backward()
10116
10117    def test_type_annotations(self):
10118        def fn(x, y):
10119            # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
10120            return x, x * 2, x * 3
10121
10122        with self.assertRaisesRegex(RuntimeError, r"need 4 values .* found only 3"):
10123            @torch.jit.script
10124            def script_fn(x):
10125                x, y, z, w = fn(x, x)
10126
10127        with self.assertRaisesRegex(RuntimeError, r"too many values .* need 2 but found 3"):
10128            @torch.jit.script
10129            def script_fn2(x):
10130                x, y = fn(x, x)
10131
10132        def fn_unpack(x):
10133            y, z, w = fn(x, x)
10134            return y
10135
10136        def fn_index(x):
10137            q = fn(x, x)
10138            return x
10139
10140        def fn_string(str, strpair):
10141            # type: (str, Tuple[str, str]) -> Tuple[str, int, str, str]
10142            str1, str2 = strpair
10143            return str, 2, str1, str2
10144
10145        x = torch.ones(2, 2)
10146        self.checkScript(fn_unpack, (x,), optimize=True)
10147        self.checkScript(fn_index, (x,), optimize=True)
10148        self.checkScript(fn_string, ("1", ("3", "4")), optimize=True)
10149
10150    def test_type_annotations_varargs(self):
10151        @torch.jit.ignore
10152        def fn_varargs(x, *args):
10153            return args[0] if args else x
10154
10155        def fn1(x, y, z):
10156            return fn_varargs(x)
10157
10158        def fn2(x, y, z):
10159            return fn_varargs(x, y)
10160
10161        def fn3(x, y, z):
10162            return fn_varargs(x, y, z)
10163
10164        x, y, z = (torch.randn(2, 2) for _ in range(3))
10165        self.checkScript(fn1, (x, y, z), optimize=True)
10166        self.checkScript(fn2, (x, y, z), optimize=True)
10167        self.checkScript(fn3, (x, y, z), optimize=True)
10168
10169    def test_type_annotation_py3(self):
10170        code = dedent("""
10171        import torch
10172        from torch import Tensor
10173        from typing import Tuple
10174
10175        def fn(x : torch.Tensor, y : Tensor, z) -> Tuple[Tensor, Tensor, Tensor]:
10176            return (x, y + z, z)
10177        """)
10178
10179        with tempfile.TemporaryDirectory() as tmp_dir:
10180            script_path = os.path.join(tmp_dir, 'script.py')
10181            with open(script_path, 'w') as f:
10182                f.write(code)
10183            fn = get_fn('test_type_annotation_py3', script_path)
10184            fn = torch.jit.ignore(fn)
10185
10186            with self.assertRaisesRegex(RuntimeError, r"Expected a value of type 'Tensor' for argument"
10187                                                      r" 'x' but instead found type 'Tuple\[Tensor,"):
10188                @torch.jit.script
10189                def bad_fn(x):
10190                    x, y = fn((x, x), x, x)
10191                    return y
10192
10193            with self.assertRaisesRegex(RuntimeError, r"too many values .* need 2 but found 3"):
10194                @torch.jit.script
10195                def bad_fn2(x):
10196                    x, y = fn(x, x, x)
10197                    return y
10198
10199            with self.assertRaisesRegex(RuntimeError, r"need 4 values .* found only 3"):
10200                @torch.jit.script
10201                def bad_fn3(x):
10202                    x, y, z, w = fn(x, x, x)
10203                    return y
10204
10205            def good_fn(x):
10206                y, z, w = fn(x, x, x)
10207                return y, z, w
10208
10209            self.checkScript(good_fn, (torch.ones(2, 2),), optimize=True)
10210
10211    def test_type_annotation_module(self):
10212        class BaseModule(torch.jit.ScriptModule):
10213            @torch.jit.ignore
10214            def foo(self, x):
10215                # type: (Tensor) -> Tensor
10216                return x + 1
10217
10218            @torch.jit.ignore
10219            def bar(self, x, y):
10220                # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor]
10221                return x + y, y
10222
10223            @torch.jit.ignore
10224            def baz(self, x, y):
10225                return x
10226
10227        class ModuleTooMany(BaseModule):
10228            @torch.jit.script_method
10229            def method(self, x):
10230                return self.foo(x, x)
10231
10232        class ModuleTooFew(BaseModule):
10233            @torch.jit.script_method
10234            def method(self, x):
10235                return self.bar(x)
10236
10237        class ModuleTooManyAssign(BaseModule):
10238            @torch.jit.script_method
10239            def method(self, x):
10240                y, z, w = self.bar(x, x)
10241                return x
10242
10243        class ModuleDefault(BaseModule):
10244            @torch.jit.script_method
10245            def method(self, x):
10246                y = self.baz(x)
10247                return x
10248
10249        with self.assertRaisesRegex(RuntimeError, "Expected at most 2 arguments but found 3"):
10250            ModuleTooMany()
10251        with self.assertRaisesRegex(RuntimeError, "Argument y not provided"):
10252            ModuleTooFew()
10253        with self.assertRaisesRegex(RuntimeError, "need 3 values .* found only 2"):
10254            ModuleTooManyAssign()
10255        with self.assertRaisesRegex(RuntimeError, "Argument y not provided."):
10256            ModuleDefault()
10257
10258    def test_type_inferred_from_empty_annotation(self):
10259        """
        Test that the type inferred from an empty or missing annotation is torch.Tensor with `inferred=true`
10261        """
10262        @torch.jit.script
10263        def fn(x):
10264            return x
10265
10266        graph = fn.graph
10267        n = next(graph.inputs())
10268        self.assertTrue(n.type() == torch._C.TensorType.getInferred())
10269
10270        with self.assertRaisesRegex(RuntimeError, "Inferred 'x' to be of type 'Tensor"):
10271            fn("1")
10272
10273    def test_script_define_order(self):
10274        class M(torch.jit.ScriptModule):
10275
10276            @torch.jit.script_method
10277            def call_foo(self, input):
10278                return self.foo(input)
10279
10280            @torch.jit.script_method
10281            def foo(self, input):
10282                return input + 1
10283        m = M()
10284        self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64)))
10285
10286    def test_script_define_order_recursive_fail(self):
10287        class M(torch.jit.ScriptModule):
10288
10289            @torch.jit.script_method
10290            def call_foo(self, input):
10291                return self.foo(input)
10292
10293            @torch.jit.script_method
10294            def foo(self, input):
10295                self.call_foo(input)
10296
10297        with self.assertRaisesRegex(RuntimeError, 'called recursively'):
10298            M()
10299
10300    def test_script_kwargs_fn_call(self):
10301        class M(torch.jit.ScriptModule):
10302
10303            @torch.jit.script_method
10304            def call_foo(self, input):
10305                return self.foo(input=input, bar=1)
10306
10307            @torch.jit.script_method
10308            def foo(self, bar, input):
10309                # type: (int, Tensor) -> Tensor
10310                return input + bar
10311        m = M()
10312        self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64)))
10313
10314    def test_if_define(self):
10315        @torch.jit.script
10316        def foo(a):
10317            if bool(a == 0):
10318                b = 1
10319            else:
10320                b = 0
10321            return b + 1
10322
10323        @torch.jit.script
10324        def foo2(a):
10325            b = 0
10326            if bool(a == 0):
10327                b = 1
10328            return b + 1
10329
10330        @torch.jit.script
10331        def foo3(a):
10332            b = 1
10333            if bool(a == 0):
10334                c = 4
10335            else:
10336                b = 0
10337            return b + 1
10338
10339        a = torch.ones(1, dtype=torch.long)
10340        b = torch.zeros(1, dtype=torch.long)
10341        self.assertEqual(1, foo(a))
10342        self.assertEqual(2, foo(b))
10343        self.assertEqual(1, foo2(a))
10344        self.assertEqual(2, foo2(b))
10345        self.assertEqual(1, foo3(a))
10346        self.assertEqual(2, foo3(b))
10347
10348    def test_script_module_export_submodule(self):
10349        class M1(torch.jit.ScriptModule):
10350            def __init__(self) -> None:
10351                super().__init__()
10352                self.weight = nn.Parameter(torch.randn(2))
10353
10354            @torch.jit.script_method
10355            def forward(self, thing):
10356                return self.weight + thing
10357
10358        class M2(torch.jit.ScriptModule):
10359            def __init__(self) -> None:
10360                super().__init__()
10361                # test submodule
10362                self.sub = M1()
10363                self.weight = nn.Parameter(torch.randn(2, 3))
10364                self.bias = nn.Parameter(torch.randn(2))
10365                self.define("""
10366                    def hi(self, a):
10367                        return self.weight.mm(a)
10368                """)
10369
10370            @torch.jit.script_method
10371            def doit(self, input):
10372                return self.weight.mm(input)
10373
10374            @torch.jit.script_method
10375            def doit2(self, input):
10376                return self.weight.mm(input)
10377
10378            @torch.jit.script_method
10379            def doit3(self, input):
10380                return input + torch.ones([1], dtype=torch.double)
10381
10382            @torch.jit.script_method
10383            def forward(self, input):
10384                a = self.doit(input)
10385                b = self.doit2(input)
10386                c = self.hi(input)
10387                return a + b + self.bias + c
10388
10389        with torch.jit.optimized_execution(False):
10390            m_orig = M2()
10391            m_import = self.getExportImportCopy(m_orig)
10392
10393            input = torch.randn(3, 2)
10394            self.assertEqual(m_orig.doit(input), m_import.doit(input))
10395            self.assertEqual(m_orig.hi(input), m_import.hi(input))
10396            self.assertEqual(m_orig.doit3(input), m_import.doit3(input))
10397            self.assertEqual(m_orig.forward(input), m_import.forward(input))
10398
10399    @slowTest
10400    def test_compile_module_with_constant(self):
10401        class Double(nn.Module):
10402            def __init__(self, downsample=None):
10403                super().__init__()
10404
10405            def forward(self, input):
10406                return input * 2
10407
10408        class Mod(nn.Module):
10409            __constants__ = ['downsample']
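            # with downsample=None the attribute becomes a compile-time constant
            # None and the branch in forward can be folded away; with a Module it
            # is compiled as an ordinary submodule call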
10410
10411            def __init__(self, downsample=None):
10412                super().__init__()
10413                self.downsample = downsample
10414
10415            def forward(self, input):
10416                if self.downsample is not None:
10417                    return self.downsample(input)
10418                return input
10419
10420        none_mod = torch.jit.script(Mod(None))
10421        double_mod = torch.jit.script(Mod(Double()))
10422        self.assertEqual(none_mod(torch.tensor(1)), torch.tensor(1))
10423        self.assertEqual(double_mod(torch.tensor(1)), torch.tensor(1) * 2)
10424
10425    def test_device_kwarg(self):
10426        from torch import device
10427
10428        def f():
10429            return device(type='cuda'), torch.device(type='cpu')
10430        self.checkScript(f, ())
10431
10432    def test_script_module_export_tensor_type(self):
10433        class M(torch.jit.ScriptModule):
10434            def __init__(self, type):
10435                super().__init__()
10436                self.param = torch.nn.Parameter(torch.zeros((5, 5), dtype=type).random_())
10437
10438            @torch.jit.script_method
10439            def foo(self):
10440                return self.param
10441
10442        with torch.jit.optimized_execution(False):
10443            for type in [torch.float, torch.double]:
10444                m_orig = M(type)
10445                m_import = self.getExportImportCopy(m_orig)
10446                # check to make sure the storage wasn't resized
10447                self.assertTrue(m_orig.param.storage().size() == 25)
10448                self.assertEqual(m_orig.foo(), m_import.foo())
10449                self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype)
10450
10451    @unittest.skipIf(not RUN_CUDA, "testing cuda tensors require CUDA")
10452    def test_script_module_export_tensor_cuda(self):
10453        class M(torch.jit.ScriptModule):
10454
10455            def __init__(self) -> None:
10456                super().__init__()
10457                self.param = torch.nn.Parameter(torch.zeros((5, 5), device='cuda:0').random_())
10458
10459            @torch.jit.script_method
10460            def foo(self):
10461                return self.param
10462
10463        m_orig = M()
10464        m_import = self.getExportImportCopy(m_orig)
10465        # check to make sure the storage wasn't resized
10466        self.assertTrue(m_orig.param.storage().size() == 25)
10467        self.assertTrue(m_import.foo().device == torch.device('cuda:0'))
10468        self.assertEqual(m_orig.foo(), m_import.foo())
10469        self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype)
10470
10471    def test_script_module_export_blocks(self):
10472        class M(torch.jit.ScriptModule):
10473            def __init__(self, n, m):
10474                super().__init__()
10475                self.weight = torch.nn.Parameter(torch.rand(n, m))
10476
10477            @torch.jit.script_method
10478            def forward(self, input):
10479                if bool(input.sum() > 0):
10480                    output = self.weight.mv(input)
10481                else:
10482                    output = self.weight + input
10483                return output
10484
10485        m_orig = M(200, 200)
10486        m_import = self.getExportImportCopy(m_orig)
10487
10488        t = torch.rand(200)
10489        self.assertEqual(m_orig(t), m_import(t))
10490
10491    def test_script_module_export_shared_storage(self):
10492        class M(torch.jit.ScriptModule):
10493
10494            def __init__(self) -> None:
10495                super().__init__()
10496                self.param1 = torch.nn.Parameter(torch.rand(5, 5))
10497                self.param2 = torch.nn.Parameter(self.param1[3])
10498                self.param3 = torch.nn.Parameter(torch.rand(5, 5))
10499                self.param4 = torch.nn.Parameter(torch.rand(11, 5)[1:6])
10500
10501            @torch.jit.script_method
10502            def foo(self):
10503                return self.param1 + self.param2 + self.param3 + self.param4
10504
10505        with torch.jit.optimized_execution(False):
10506            m_orig = M()
10507            m_import = self.getExportImportCopy(m_orig)
10508
10509            self.assertEqual(m_orig.foo(), m_import.foo())
10510
10511            self.assertTrue(m_import.param1.storage().data_ptr() == m_import.param2.storage().data_ptr())
10512            self.assertTrue(m_import.param1.storage().data_ptr() != m_import.param3.storage().data_ptr())
10513
10514    def test_sequential_intermediary_types(self):
10515        class A(torch.nn.Module):
10516            def forward(self, x):
10517                return x + 3
10518
10519        class B(torch.nn.Module):
10520            def forward(self, x):
10521                return {"1": x}
10522
10523        class C(torch.nn.Module):
10524            def __init__(self) -> None:
10525                super().__init__()
10526                self.foo = torch.nn.Sequential(A(), B())
10527
10528            def forward(self, x):
10529                return self.foo(x)
10530
10531        self.checkModule(C(), (torch.tensor(1),))
10532
10533    def test_ellipsis_const_mid(self):
10534        def ellipsize(x):
10535            # type: (Tensor) -> List[int]
10536            return x[2, Ellipsis, 0:4, 4:8].size()
10537
10538        dummy = torch.zeros(8, 8, 8, 8, 8)
10539        self.checkScript(ellipsize, (dummy,), optimize=True)
10540
10541    def test_ellipsis_const_mid_select(self):
10542        def ellipsize(x):
10543            # type: (Tensor) -> List[int]
10544            return x[2, Ellipsis, 4, 4, 4:8, 2].size()
10545
10546        dummy = torch.zeros(8, 8, 8, 8, 8, 8, 8)
10547        self.checkScript(ellipsize, (dummy,), optimize=True)
10548
10549    def test_ellipsis_const_start(self):
10550        def ellipsize(x):
10551            # type: (Tensor) -> List[int]
10552            return x[Ellipsis, 0:4, 4:8].size()
10553        dummy = torch.zeros(8, 8, 8, 8, 8)
10554        self.checkScript(ellipsize, (dummy,), optimize=True)
10555
10556    def test_ellipsis_const_end(self):
10557        def ellipsize(x):
10558            # type: (Tensor) -> List[int]
10559            return x[0:4, 2, Ellipsis].size()
10560        dummy = torch.zeros(8, 8, 8, 8, 8)
10561        self.checkScript(ellipsize, (dummy,), optimize=True)
10562
10563    def test_ellipsis_mid(self):
10564        def ellipsize(x):
10565            # type: (Tensor) -> List[int]
10566            return x[2, ..., 0:4, 4:8].size()
10567
10568        dummy = torch.zeros(8, 8, 8, 8, 8)
10569        self.checkScript(ellipsize, (dummy,), optimize=True)
10570
10571    def test_ellipsis_mid_select(self):
10572        def ellipsize(x):
10573            # type: (Tensor) -> List[int]
10574            return x[2, ..., 4, 4, 4:8, 2].size()
10575
10576        dummy = torch.zeros(8, 8, 8, 8, 8, 8, 8)
10577        self.checkScript(ellipsize, (dummy,), optimize=True)
10578
10579    def test_ellipsis_start(self):
10580        def ellipsize(x):
10581            # type: (Tensor) -> List[int]
10582            return x[..., 0:4, 4:8].size()
10583        dummy = torch.zeros(8, 8, 8, 8, 8)
10584        self.checkScript(ellipsize, (dummy,), optimize=True)
10585
10586    def test_ellipsis_end(self):
10587        def ellipsize(x):
10588            # type: (Tensor) -> List[int]
10589            return x[0:4, 2, ...].size()
10590        dummy = torch.zeros(8, 8, 8, 8, 8)
10591        self.checkScript(ellipsize, (dummy,), optimize=True)
10592
10593    def test_torch_manual_seed(self):
10594        with freeze_rng_state():
10595            def test():
10596                torch.manual_seed(2)
10597                return torch.rand(1)
10598
10599            script = torch.jit.script(test)
10600            self.assertEqual(test(), script())
10601            graph = script.graph_for()
10602            FileCheck().check("aten::manual_seed").run(graph)
10603
10604    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
10605    def test_index_select_shape_prop(self):
10606
10607        @torch.jit.script
10608        def foo(x, y):
10609            return torch.index_select(x, index=y, dim=1)
10610
10611        a = torch.zeros(2, 2)
10612        b = torch.zeros(4, dtype=torch.long)
10613        torch._C._jit_pass_complete_shape_analysis(foo.graph, (a, b), False)
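        # index_select along dim=1 with a 4-element index turns the (2, 2) input
        # into a (2, 4) output, which shape analysis should record on the graph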
10614        FileCheck().check("Float(2, 4, strides=[4, 1], requires_grad=0, device=cpu)").run(str(foo.graph))
10615
10616    def test_shape_analysis_loop(self):
10617        def foo(a, b, x):
10618            c = a
            # on the first iteration of the loop it appears that
            # c should be expanded to the size of b,
            # but on the second and later iterations there is no broadcast and the
            # sizes are different.
            # previously this would cause the compiler to (1) enter an infinite
            # loop trying to compute the shape, and (2) insert invalid
            # broadcasts.
            # this test ensures we don't regress on these issues
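            # for the shapes used below (a: [1], b: [4], x: [5]):
            #   iteration 1: a = c + b broadcasts [1] + [4] -> [4], then c and b become [5]
            #   iteration 2: a = c + b is [5] + [5] -> [5], so no single static shape fits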
10627            for _ in range(2):
10628                a = c + b
10629                c = x
10630                b = x
10631            return a
10632
10633        self.checkScript(foo, (torch.zeros(1), torch.zeros(4), torch.zeros(5)), optimize=False)
10634
10635    def test_intlist_args(self):
10636        def func_1(x):
10637            return torch.nn.functional.adaptive_avg_pool1d(x, 1)
10638
10639        def func_2(x):
10640            return torch.nn.functional.adaptive_avg_pool1d(x, output_size=1)
10641
10642        def func_3(x):
10643            return torch.nn.functional.adaptive_avg_pool1d(x, output_size=[1])
10644
10645        x = torch.randn(8, 8, 8)
10646        self.checkScript(func_1, [x], optimize=True)
10647        self.checkScript(func_2, [x], optimize=True)
10648        self.checkScript(func_3, [x], optimize=True)
10649
10650    def test_wrong_implicit_expand(self):
10651
10652        @_trace(torch.zeros(3), torch.zeros(1))
10653        def foo(a, b):
10654            return a + b
10655
10656        a = torch.rand(4)
10657        b = torch.rand(4)
10658        self.assertEqual(a + b, foo(a, b))
10659
10660    def test_builtin_args_fails(self):
10661
10662        with self.assertRaisesRegex(RuntimeError, 'Argument self not provided'):
10663            @torch.jit.script
10664            def f1(a):
10665                torch.sum(foo=4)
10666
10667        with self.assertRaisesRegex(RuntimeError, 'specified twice'):
10668            @torch.jit.script
10669            def f2(a):
10670                torch.sum(a, self=a)
10671
10672        with self.assertRaisesRegex(RuntimeError, 'not provided'):
10673            @torch.jit.script
10674            def f3(a):
10675                torch.sum(dim=4)
10676
10677        with self.assertRaisesRegex(RuntimeError, 'for argument \'tensors\' but instead found type \'Tensor'):
10678            @torch.jit.script
10679            def f4(a):
10680                torch.cat(a)
10681
10682        with self.assertRaisesRegex(RuntimeError, r'argument \'tensors\' but instead found type \'List\[int\]'):
10683            @torch.jit.script
10684            def f5(a):
10685                torch.cat([3])
10686
10687        with self.assertRaisesRegex(RuntimeError, r'Expected a value of'
10688                                    r' type \'List\[int\]\' for argument'
10689                                    r' \'size\' but instead found type '
10690                                    r'\'List\[Union\[List\[int\], int\]\]'):
10691            @torch.jit.script
10692            def f6(a):
10693                a.expand(size=[3, [4]])
10694
10695    def test_builtin_args(self):
10696
10697        def t0(a):
10698            # default arg dim
10699            return torch.cat([a, a])
10700
10701        self.checkScript(t0, (torch.zeros(1, 1),))
10702
10703        def t1(a):
10704            # keywords out of order
10705            return torch.cat(dim=1, tensors=[a, a])
10706
10707        self.checkScript(t1, (torch.zeros(1, 1, 2),))
10708
10709        def t2(a):
10710            # mix const/non-const attributes
10711            if 1 == 1:
10712                b = 1
10713            else:
10714                b = 0
10715            return torch.sum(a, dim=b, keepdim=False)
10716
10717        self.checkScript(t2, (torch.zeros(1, 1, 2),))
10718
10719    def test_parser_type_annotations(self):
10720        cu = torch.jit.CompilationUnit('''
10721            def foo(x : Tensor, y : Tuple[Tuple[Tensor, Tensor], Tensor]) -> Tuple[Tensor, Tensor]:
10722                return x, x
10723        ''')
10724
10725        self.assertExpected(str(cu.foo.schema))
10726
10727    def test_parser_type_annotations_comment(self):
10728        cu = torch.jit.CompilationUnit('''
10729            def foo(x, y):
10730                # type: (Tensor, Tuple[Tuple[Tensor, Tensor], Tensor]) -> Tuple[Tensor, Tensor]
10731                return x, x
10732        ''')
10733
10734        self.assertExpected(str(cu.foo.schema))
10735
10736    def test_parser_type_annotations_unknown_type(self):
10737        with self.assertRaisesRegex(RuntimeError, "Unknown type name 'Foo'"):
10738            cu = torch.jit.CompilationUnit('''
10739                def foo(x : Tensor, y : Tuple[Tuple[Foo, Tensor], Tensor]) -> Tuple[Tensor, Tensor]:
10740                    return x, x
10741            ''')
10742
10743    def test_parser_type_annotations_subscript_non_ident(self):
10744        with self.assertRaisesRegex(RuntimeError, r'Subscripted type must be a type identifier'):
10745            cu = torch.jit.CompilationUnit('''
10746                def foo(x : Tensor, y : Tuple[Tensor, Tensor][Tensor]) -> Tuple[Tensor, Tensor]:
10747                    return x, x
10748            ''')
10749
10750    def test_parser_type_annotations_subscript_tensor(self):
10751        with self.assertRaisesRegex(RuntimeError, r'Unknown type constructor Tensor'):
10752            cu = torch.jit.CompilationUnit('''
10753                def foo(x : Tensor, y : Tensor[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
10754                    return x, x
10755            ''')
10756
10757    def test_parser_type_annotations_incompatible_expression(self):
10758        with self.assertRaisesRegex(RuntimeError, r'Expression of type \+ cannot be used in a type expression'):
10759            cu = torch.jit.CompilationUnit('''
10760                def foo(x : Tensor, y : Tuple[3 + 4, Tensor]) -> Tuple[Tensor, Tensor]:
10761                    return x, x
10762            ''')
10763
10764    def test_gather_dynamic_index(self):
10765        def t(x):
10766            gather1 = x[0]
10767            idx = 0 + 1
10768            gather2 = x[idx]
10769            return gather1 + gather2
10770
10771        self.checkScript(t, (torch.zeros(3, 2, 3),))
10772
10773    def test_torch_ignore_conversion_to_none(self):
10774        class A(torch.nn.Module):
10775            @torch.jit.ignore
10776            def ignored(self, a: int) -> None:
10777                l: int = len([2 for i in range(a) if i > 2])
10778                return
10779
10780            def forward(self) -> int:
10781                a: int = 4
10782                b: int = 5
10783                self.ignored(a)
10784                return a + b
10785
10786        class B(torch.nn.Module):
10787            @torch.jit.ignore
10788            def ignored(self, a: int):
10789                l: int = len([2 for i in range(a) if i > 2])
10790                return
10791
10792            def forward(self) -> int:
10793                a: int = 4
10794                b: int = 5
10795                self.ignored(a)
10796                return a + b
10797
10798        modelA = torch.jit.script(A())
10799        self.assertEqual(modelA(), 9)
10800
10801        modelB = torch.jit.script(B())
10802        self.assertEqual(modelB(), 9)
10803
10804    def test_addmm_grad(self):
10805        """ This test checks several things:
10806            1. An expand node was inserted before the addmm operating on the
10807               bias term.
10808            2. The fused form of addmm appears in the ultimate graph that's
10809               executed.
10810            3. A sum op was emitted for accumulating gradients along the 0th
10811               (expanded) dimension of the bias term.
10812            4. The correct symbolic representation for the backward pass of the
10813               mm operator was emitted (x.t() -> mm)
10814
10815            TODO: we should actually check these conditions once we have a way
10816            to dump the GraphExecutor state. Namely the processed forward graph
10817            and the backward graph.
10818        """
10819        @torch.jit.script
10820        def addmm_grad_test(b, x, w):
10821            return torch.addmm(b, x, w)
10822
10823        # Initialize param and input values
10824        w_init = torch.rand(2, 5)
10825        b_init = torch.rand(5)
10826        x = torch.rand(3, 2)
10827
10828        # Clone trainable params
10829        b = b_init.clone()
10830        b.requires_grad_()
10831        w = w_init.clone()
10832        w.requires_grad_()
10833
10834        # Test symbolic differentiation
10835        y = addmm_grad_test(b, x, w)
10836        y.sum().backward()
10837
10838        # clone params for autograd reference
10839        b_ref = b_init.clone()
10840        b_ref.requires_grad_()
10841        w_ref = w_init.clone()
10842        w_ref.requires_grad_()
10843        y_ref = torch.addmm(b_ref, x, w_ref)
10844        y_ref.sum().backward()
10845
10846        self.assertEqual(w.grad, w_ref.grad)
10847        self.assertEqual(b.grad, b_ref.grad)
10848
10849    @unittest.skipIf(not RUN_CUDA, "running tests on cuda to verify cudnn fix")
10850    def test_batch_norm_inference_backward_cuda(self):
10851        with enable_profiling_mode_for_profiling_tests():
10852            class MyBatchNorm(torch.nn.Module):
10853                def __init__(self, num_features, affine, track_running_stats):
10854                    super().__init__()
10855                    self.bn = torch.nn.BatchNorm2d(
10856                        num_features, 1e-5, affine=affine, track_running_stats=track_running_stats).float()
10857
10858                def forward(self, x: torch.Tensor):
10859                    o = self.bn(x)
10860                    o = torch.nn.functional.relu(o)
10861                    return o
10862
10863            batch = 4
10864            c = 2
10865            hw = 3
10866            # Initialize param and input values
10867            x_init = torch.randn(batch, c, hw, hw, dtype=torch.float).cuda()
10868            grad = torch.randn(batch, c, hw, hw, dtype=torch.float).cuda()
10869
10870            training = False
10871            affine = True
10872            track_running_stats = True
10873
10874            module = torch.jit.script(MyBatchNorm(c, affine, track_running_stats)).cuda()
10875            ref_module = MyBatchNorm(c, affine, track_running_stats).cuda()
10876            module.eval()
10877            ref_module.eval()
10878
10879            jit_module = torch.jit.script(module)
10880            ref_module.load_state_dict(module.state_dict())
10881
10882            x = x_init.detach().clone()
10883            x.requires_grad_()
10884            x_ref = x_init.detach().clone()
10885            x_ref.requires_grad_()
10886
10887            # Test symbolic differentiation
            # Run forward and backward three times to trigger the autodiff graph
10889            for i in range(0, 3):
10890                y = jit_module(x)
10891                y.backward(grad)
10892            x.grad.zero_()
10893
10894            module.bn.running_mean.zero_()
10895            module.bn.running_var.fill_(1.0)
10896            ref_module.bn.running_mean.zero_()
10897            ref_module.bn.running_var.fill_(1.0)
10898
10899            # run jitted module
10900            y = jit_module(x)
10901            y.backward(grad)
10902            # reference computation
10903            y_ref = ref_module(x_ref)
10904            y_ref.backward(grad)
10905
10906            self.assertEqual(y_ref, y)
10907            self.assertEqual(x.grad, x_ref.grad)
10908            self.assertEqual(module.bn.running_mean, ref_module.bn.running_mean)
10909            self.assertEqual(module.bn.running_var, ref_module.bn.running_var)
10910
10911    def test_zeros(self):
10912        class M(torch.jit.ScriptModule):
10913            __constants__ = ['d']
10914
10915            def __init__(self) -> None:
10916                super().__init__()
10917                self.d = torch.device('cpu')
10918
10919            @torch.jit.script_method
10920            def create(self):
10921                return torch.zeros([1, 1, 2], dtype=torch.float, device=self.d, layout=torch.strided)
10922
10923        r = M().create()
10924        self.assertEqual(r.dtype, torch.float)
10925        self.assertEqual(torch.zeros([1, 1, 2], dtype=torch.float), r)
10926
10927        def fn():
10928            return torch.zeros((1, 2, 3))
10929
10930        self.checkScript(fn, ())
10931
10932    def test_vararg_zeros(self):
10933        def foo():
10934            return torch.zeros(3, 4, 5, dtype=torch.int)
10935
10936        self.checkScript(foo, ())
10937
10938    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "the original version of test_rand")
10939    def test_rand(self):
10940        def test_rand():
10941            a = torch.rand([3, 4])
10942            return a + 1.0 - a
10943
10944        self.checkScript(test_rand, ())
10945        fn = torch.jit.script(test_rand)
10946        out = fn()
10947        self.assertEqual(out.dtype, torch.get_default_dtype())
10948        g = fn.graph_for()
        # Testing that shape analysis correctly sets the type
10950        if GRAPH_EXECUTOR != ProfilingMode.SIMPLE:
10951            FileCheck().check("Double(*, *, requires_grad=0, device=cpu)") \
10952                       .check_not("Float(*, *, requires_grad=0, device=cpu)").run(g)
10953
10954        @torch.jit.script
10955        def randint():
10956            return torch.randint(0, 5, [1, 2])
10957        out = randint()
10958        self.assertEqual(out.dtype, torch.int64)
10959        if GRAPH_EXECUTOR != ProfilingMode.SIMPLE:
10960            FileCheck().check("Long(*, *, requires_grad=0, device=cpu)") \
10961                       .check_not("Float(*, *, requires_grad=0, device=cpu)") \
10962                       .check_not("Double(*, *, requires_grad=0, device=cpu)") \
10963                       .run(randint.graph_for())
10964
10965    @unittest.skipIf(not RUN_CUDA, "no CUDA")
10966    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "skip if profiling isn't enabled")
10967    def test_autodiff_complex(self):
10968        def foo(x: torch.Tensor, y: torch.Tensor, W: torch.Tensor):
10969            return torch.exp(torch.mm(torch.complex(x, y), W.cfloat()))
10970
10971        @torch.jit.script
10972        def jitted_foo(x: torch.Tensor, y: torch.Tensor, W: torch.Tensor):
10973            return torch.exp(torch.mm(torch.complex(x, y), W.cfloat()))
10974
10975        x = torch.randn(128, 16, dtype=torch.float32, device='cuda:0')
10976        y = torch.randn(128, 16, dtype=torch.float32, device='cuda:0')
10977        W = torch.randn(16, 1, dtype=torch.float32, device='cuda:0', requires_grad=True)
10978        W.data /= 4
10979
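        # across several profiling runs, the scripted function should agree with
        # the eager version on whether the output has a grad_fn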
10980        with enable_profiling_mode_for_profiling_tests():
10981            for i in range(4):
10982                self.assertTrue((foo(x, y, W).grad_fn is None) == (jitted_foo(x, y, W).grad_fn is None))
10983
10984
10985    def test_linear_grad(self):
10986        with enable_profiling_mode_for_profiling_tests():
10987            def t(x: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor]):
10988                return torch.nn.functional.linear(x, w, b)
10989
10990            x_init = torch.randn(4, 2)
10991            w_init = torch.randn(3, 2)
10992            b_init = torch.randn(3)
10993            grad = torch.randn(4, 3)
10994
10995            with disable_autodiff_subgraph_inlining():
10996                # script module
10997                jit_t = torch.jit.script(t)
10998
10999                x = x_init.detach().requires_grad_()
11000                w = w_init.detach().requires_grad_()
11001                b = b_init.detach().requires_grad_()
11002                x_ref = x_init.detach().requires_grad_()
11003                w_ref = w_init.detach().requires_grad_()
11004                b_ref = b_init.detach().requires_grad_()
11005
11006                # profiling/optimization runs
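                # (the first calls warm up the profiling executor before the
                # differentiable graph is used; their gradients are zeroed out
                # below before the actual comparison)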
11007                jit_o = jit_t(x, w, b)
11008                jit_o.backward(grad)
11009                jit_o = jit_t(x, w, b)
11010                jit_o.backward(grad)
11011
11012                x.grad.zero_()
11013                w.grad.zero_()
11014                b.grad.zero_()
11015                jit_o = jit_t(x, w, b)
11016                jit_o.backward(grad)
11017                o = t(x_ref, w_ref, b_ref)
11018                o.backward(grad)
11019
11020                self.assertEqual(jit_o, o)
11021                self.assertEqual(x.grad, x_ref.grad)
11022                self.assertEqual(w.grad, w_ref.grad)
11023                self.assertEqual(b.grad, b_ref.grad)
11024
11025                x.grad.zero_()
11026                w.grad.zero_()
11027                x_ref.grad.zero_()
11028                w_ref.grad.zero_()
11029                jit_o = jit_t(x, w, None)
11030                jit_o.backward(grad)
11031                o = t(x_ref, w_ref, None)
11032                o.backward(grad)
11033
11034                self.assertEqual(jit_o, o)
11035                self.assertEqual(x.grad, x_ref.grad)
11036                self.assertEqual(w.grad, w_ref.grad)
11037
11038    @skipIfTorchDynamo("TorchDynamo doesn't support profile")
11039    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "the profiling version of test_rand")
11040    def test_rand_profiling(self):
11041        def test_rand():
11042            a = torch.rand([3, 4])
11043            return a + 1.0 - a
11044
        # Testing that shape analysis correctly sets the type
11046        with enable_profiling_mode_for_profiling_tests():
11047            with num_profiled_runs(1):
11048                fn = torch.jit.script(test_rand)
11049                out = fn()
11050                graph_str = torch.jit.last_executed_optimized_graph()
11051                self.assertEqual(out.dtype, torch.float)
11052                FileCheck().check("Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)") \
11053                           .check_not("Double(3, 4, strides=[4, 1], requires_grad=0, device=cpu)").run(graph_str)
11054
11055            # fn = self.checkScript(test_rand, ())
11056            # out = fn()
11057            # self.assertEqual(out.dtype, torch.float)
11058
11059        @torch.jit.script
11060        def randint():
11061            return torch.randint(0, 5, [1, 2])
11062
11063        with enable_profiling_mode_for_profiling_tests():
11064            with num_profiled_runs(1):
11065                out = randint()
11066                graph_str = torch.jit.last_executed_optimized_graph()
11067                self.assertEqual(out.dtype, torch.int64)
11068                FileCheck().check("profiled_type=Long(1, 2, strides=[2, 1], requires_grad=0, device=cpu)").run(graph_str)
11069
11070
11071    def test_erase_number_types(self):
11072        def func(a):
11073            b = 7 + 1 + 3
11074            c = a + b
11075            c += b
11076            return c
11077
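        # the erase_number_types pass should remove the int-typed prim::Constant
        # nodes from the graph, as the checks below verify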
11078        graph = torch.jit.script(func).graph
11079        FileCheck().check("int = prim::Constant").check("aten::add_").run(str(graph))
11080        self.run_pass("erase_number_types", graph)
11081        FileCheck().check_not("int = prim::Constant").run(str(graph))
11082
11083    def test_refine_tuple_types(self):
11084        # TupleConstruct output type is not correct here.
11085        graph_str = """
11086        graph(%a : Float(123), %b : Float(4, 5, 6)):
11087          %c : (Tensor, Tensor) = prim::TupleConstruct(%a, %b)
11088          return (%c)
11089        """
11090        graph = parse_ir(graph_str)
11091        torch._C._jit_pass_refine_tuple_types(graph)
11092
11093        # After the pass, the output type should've been updated.
11094        self.assertTrue('(Float(123), Float(4, 5, 6))' in str(graph.findNode('prim::TupleConstruct').output()))
11095
11096    # TODO(henrytu): Add test for RefineTypes for NamedTuple when it's supported by IR parser.
11097
11098    def test_remove_dropout(self):
11099        weight_0_shape = (20, 5)
11100        weight_1_shape = (20, 20)
11101        input_shape = (10, 5)
11102
11103        class M(torch.nn.Module):
11104            def __init__(self) -> None:
11105                super().__init__()
11106                self.weight_0 = torch.nn.Parameter(torch.rand(weight_0_shape))
11107                self.weight_1 = torch.nn.Parameter(torch.rand(weight_1_shape))
11108
11109            def forward(self, x):
11110                o = F.linear(x, self.weight_0)
11111                o = F.dropout(o, training=self.training)
11112                o = F.linear(o, self.weight_1)
11113                return o
11114
11115        data = torch.rand(input_shape)
11116        m = M()
11117        m = torch.jit.script(m)
11118        with self.assertRaisesRegex(RuntimeError, r'Dropout removal module in training mode is not yet supported'):
11119            torch._C._jit_pass_remove_dropout(m._c)
11120        m.eval()
11121        ref_res = m(data)
        # Need to inline, otherwise we see instances of Function.
        # We would have to use torch.linear/dropout to get around it.
11124        from torch.jit._recursive import wrap_cpp_module
11125        m = wrap_cpp_module(torch._C._freeze_module(m._c))
11126        torch._C._jit_pass_remove_dropout(m._c)
11127        res = m(data)
11128        FileCheck().check_not("aten::dropout").run(str(m.graph))
11129        torch.testing.assert_close(ref_res, res, rtol=1e-2, atol=1e-3)
11130
11131    def test_unfold_zero_dim(self):
11132        def fn(x):
11133            return x.unfold(0, 1, 1)
11134
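        # complete shape analysis should give aten::unfold on a 0-dim input the
        # same number of output dimensions as eager execution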
11135        graph = torch.jit.script(fn).graph
11136        torch._C._jit_pass_complete_shape_analysis(graph, (torch.tensor(0.39),), False)
11137        out_dims = fn(torch.tensor(0.3923)).ndim
11138        self.assertEqual(graph.findNode("aten::unfold").output().type().dim(), out_dims)
11139
11140    def test_mm_batching(self):
11141
11142        with enable_profiling_mode_for_profiling_tests():
11143            lstm_cell = torch.jit.script(LSTMCellS)
11144
11145            def lstm(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
11146                for i in range(x.size(0)):
11147                    hx, cx = lstm_cell(x[i], hx, cx, w_ih, w_hh, b_ih, b_hh)
11148                return hx
11149
11150            slstm = torch.jit.script(lstm)
11151
11152            inputs = get_lstm_inputs('cpu', training=True, seq_length=10)
11153            slstm(*inputs, profile_and_replay=True).sum().backward(retain_graph=True)
11154            if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
11155                slstm(*inputs, profile_and_replay=True).sum().backward()
11156
11157            fw_graph = slstm.graph_for(*inputs)
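            # the batch_mm optimization is expected to rewrite groups of mms in
            # the LSTM forward/backward graphs into prim::MMBatchSide and
            # prim::MMTreeReduce nodes (only checked for the legacy executor)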
11158            if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
11159                bw_graph = backward_graph(slstm, diff_graph_idx=0)
11160                self.assertTrue('prim::MMBatchSide' in str(fw_graph))
11161                self.assertTrue('prim::MMTreeReduce' in str(bw_graph))
11162
11163            sout = slstm(*inputs)
11164            out = lstm(*inputs)
11165            self.assertEqual(sout, out)
11166            self.assertEqual(torch.autograd.grad(sout.sum(), inputs),
11167                             torch.autograd.grad(out.sum(), inputs))
11168
11169    def test_loop_unrolling(self):
11170        def fn(x):
11171            y = 0
11172            for i in range(int(x)):
11173                y -= i
11174            return y
11175
11176        graph = torch.jit.script(fn).graph
11177        self.run_pass('loop_unrolling', graph)
11178        unroll_factor = 8
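        # after unrolling, the main loop body should contain unroll_factor copies
        # of the subtraction, followed by an epilogue loop with a single copy for
        # the remaining iterations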
11179        FileCheck().check("prim::Loop").check_count("aten::sub", unroll_factor) \
11180            .check("prim::Loop").check("aten::sub").run(str(graph))
11181        self.checkScript(fn, (torch.tensor(10),))
11182
11183    def test_loop_unrolling_const(self):
11184        def fn():
11185            y = 0
11186            for _ in range(10):
11187                y -= 1
11188            return y
11189
11190        def fn2():
11191            y = 0
11192            for i in range(10):
11193                y -= i
11194            return y
11195
11196        def check(fn, name):
11197            graph = torch.jit.script(fn).graph
11198            self.run_pass('loop_unrolling', graph)
11199            # entirely unrolled
11200            FileCheck().check_not("prim::Loop'").run(str(graph))
11201            self.checkScript(fn, ())
11202
11203        check(fn, 'add_const')
11204        check(fn2, 'add_iter')
11205
11206    def test_loop_unrolling_nested(self):
11207        def fn(x):
11208            y = 0
11209            for _ in range(10):
11210                for j in range(int(x)):
11211                    y -= j
11212            return y
11213
11214        graph = torch.jit.script(fn).graph
11215        self.run_pass('loop_unrolling', graph)
11216        # inner loop with 8 subs followed by loop epilogue
11217        unroll_factor = 8
11218        FileCheck().check("prim::Loop").check("prim::Loop").check_count('aten::sub', unroll_factor) \
11219            .check("prim::Loop").check("aten::sub").run(str(graph))
11220        self.checkScript(fn, (torch.tensor(10),))
11221
11222    def test_loop_unroll_unused_counter(self):
11223        def fn(x):
11224            y = 0
11225            for _ in range(int(x)):
11226                y -= 1
11227            return y
11228
11229        graph = torch.jit.script(fn).graph
11230        self.run_pass('loop_unrolling', graph)
11231        FileCheck().check("prim::Loop").check_not("aten::add").check("return") \
11232            .run(str(graph))
11233
11234    def test_loop_unroll_negative(self):
11235        def fn(x):
11236            y = 0
11237            for _ in range(int(x)):
11238                y += 1
11239            return y
11240
11241        self.checkScript(fn, (torch.tensor(-20),))
11242        self.checkScript(fn, (torch.tensor(-2),))
11243        self.checkScript(fn, (torch.tensor(-1),))
11244        self.checkScript(fn, (torch.tensor(0),))
11245        self.checkScript(fn, (torch.tensor(1),))
11246        self.checkScript(fn, (torch.tensor(2),))
11247
11248    def test_where(self):
11249        def fn(x, y):
11250            return torch.where(x > 0.0, x, y)
11251
11252        self.checkScript(fn, (torch.randn(3, 2, dtype=torch.float), torch.ones(3, 2, dtype=torch.float)))
11253
11254    def test_where_method(self):
11255        def fn(x, y):
11256            return x.where(x > 0.0, y)
11257
11258        self.checkScript(fn, (torch.randn(3, 2, dtype=torch.float), torch.ones(3, 2, dtype=torch.float)))
11259
11260    def test_union_to_number(self):
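        # a Union of int, float and complex inputs should be collapsed into the
        # Scalar type in the scripted graph, as the FileCheck below verifies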
11261        @torch.jit.script
11262        def fn(x: Union[int, complex, float], y: Union[int, complex, float]):
11263            return x + y
11264        FileCheck().check(": Scalar):").run(fn.graph)
11265
11266    def test_reassign_module_lhs(self):
11267        with self.assertRaisesRegex(RuntimeError, 'Cannot re-assign \'self\''):
11268            class ReassignSelfLHS(torch.jit.ScriptModule):
11269                @torch.jit.script_method
11270                def forward(self, x):
11271                    for _ in range(20):
11272                        self = x
11273                    return self
11274
11275            ReassignSelfLHS()
11276
11277    def test_reassign_module_rhs(self):
11278        with self.assertRaisesRegex(RuntimeError, 'Cannot re-assign \'x\' to a value of type module'):
11279            class ReassignSelfRHS(torch.jit.ScriptModule):
11280                @torch.jit.script_method
11281                def forward(self, x):
11282                    for _ in range(20):
11283                        x = self
11284                    return self
11285
11286            ReassignSelfRHS()
11287
11288    def test_unknown_builtin(self):
11289        with self.assertRaisesRegex(RuntimeError, 'object has no attribute or method'):
11290            @torch.jit.script
11291            def unknown_builtin(x):
11292                return x.splork(3)
11293
11294    def test_return_tuple(self):
11295        def return_tuple(x):
11296            a = (x, x)
11297            return a, x
11298        self.checkScript(return_tuple, (torch.rand(4),))
11299
11300    def test_add_tuple_optional(self):
11301        def foo(input: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]) -> Optional[torch.Tensor]:
11302            changed_input = input[0] + 1
11303            value: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]] = (changed_input,) + input[1:]
11304            return value[2]
11305        inp: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]] = (torch.rand(4), None, None)
11306        self.checkScript(foo, (inp,))
11307
11308    def test_add_tuple_non_optional(self):
11309        def foo(input: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor:
11310            changed_input = input[0] + 1
11311            value: Tuple[torch.Tensor, torch.Tensor, torch.Tensor] = (changed_input,) + input[1:]
11312            return torch.sum(value[2]) + 4
11313        inp: Tuple[torch.Tensor, torch.Tensor, torch.Tensor] = (torch.rand(4), torch.rand(4), torch.rand(4))
11314        self.checkScript(foo, (inp,))
11315
11316    def test_add_tuple_different_types(self):
11317        def foo(a: Tuple[int, float], b: Tuple[int]) -> int:
11318            c: Tuple[int, float, int] = a + b
11319            d: Tuple[int, float, int, int] = c + b
11320            return d[3] + 1
11321        a = (1, 2.0)
11322        b = (3,)
11323        self.checkScript(foo, (a, b))
11324
11325    def test_add_tuple_same_types(self):
11326        def foo(a: Tuple[int, int], b: Tuple[int, int, int]) -> int:
11327            c: Tuple[int, int, int, int, int] = a + b
11328            d: Tuple[int, int, int, int, int, int, int, int] = c + b
11329            return d[6] - 2
11330        a = (1, 2)
11331        b = (3, 4, 5)
11332        self.checkScript(foo, (a, b))
11333
11334    def test_method_no_self(self):
11335        with self.assertRaisesRegex(RuntimeError, 'methods must have a self argument'):
11336            class MethodNoSelf(torch.jit.ScriptModule):
11337                @torch.jit.script_method  # noqa: B902
11338                def forward():  # noqa: B902
11339                    return torch.zeros(3, 4)
11340
11341            MethodNoSelf()
11342
11343    def test_return_stmt_not_at_end(self):
11344        def return_stmt(x):
11345            if bool(x > 3):
11346                return x + 3
11347            else:
11348                return x
11349        self.checkScript(return_stmt, (torch.rand(1),))
11350
11351    def test_for_in_range(self):
11352        def fn():
11353            c = 0
11354            for i in range(100):
11355                c += i
11356            return c
11357        self.checkScript(fn, ())
11358
11359    def test_for_in_range_dynamic(self):
11360        def fn():
11361            c = 0
11362            for i in range(100):
11363                acc = 0
11364                for j in range(i):
11365                    acc += j
11366                c += acc
11367            return c
11368        self.checkScript(fn, (), optimize=False)
11369
11370    def test_for_in_range_ast(self):
11371        def test_script_for_in_range_ast():
11372            c = 0
11373            for i in range(100):
11374                acc = 0
11375                for j in range(i):
11376                    acc += j
11377                c += acc
11378            return c
11379
11380        self.checkScript(test_script_for_in_range_ast, ())
11381
11382    def test_for_in_range_if_ast(self):
11383        @torch.jit.script
11384        def test_script_for_in_range_if_ast(x):
11385            output = x
11386            for i in range(20):
11387                if i == 0:
11388                    output = x.unsqueeze(0)
11389                else:
11390                    output = torch.cat((output, x.unsqueeze(0)), dim=0)
11391            return output
11392        inputs = self._make_scalar_vars([0], torch.int64)
11393
11394        self.assertEqual(test_script_for_in_range_if_ast(*inputs).shape[0], 20)
11395
11396    def test_for_in_range_start_end(self):
11397        def fn():
11398            x = 0
11399            for i in range(7, 100):
11400                x += i
11401            return x
11402        self.checkScript(fn, ())
11403
11404    def test_for_in_range_start_end_step(self):
11405        def fn(start, end, step):
11406            # type: (int, int, int) -> int
11407            x = 0
11408            for i in range(start, end, step):
11409                x += i
11410            return x
11411
11412        self.checkScript(fn, (7, 100, 7))
11413        self.checkScript(fn, (7, 100, -7))
11414        self.checkScript(fn, (2, -11, -3))
11415        self.checkScript(fn, (2, -11, 3))
11416        self.checkScript(fn, (2, 10, 3))
11417        self.checkScript(fn, (-2, -10, -10))
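        # A few expected values for the negative-step cases: range(2, -11, -3)
        # yields 2, -1, -4, -7, -10 (sum -20), range(2, -11, 3) is empty (sum 0),
        # and range(-2, -10, -10) yields only -2.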
11418
11419    def test_for_in_range_zero_step(self):
11420        @torch.jit.script
11421        def fn():
11422            x = 0
11423            for i in range(2, -11, 0):
11424                x += i
11425            return x
11426
11427        with self.assertRaisesRegex(RuntimeError, "must not be zero"):
11428            fn()
11429
11430    def test_range_args(self):
11431        with self.assertRaisesRegex(RuntimeError, r'range expected at least 1 arguments, got 0'):
11432            @torch.jit.script
11433            def range_no_arg(x):
11434                for _ in range():
11435                    x += 1
11436                return x
11437        with self.assertRaisesRegex(RuntimeError, r'found float'):
11438            @torch.jit.script
11439            def range_non_float():
11440                for i in range(.5):
11441                    print(i)
11442
11443    def test_parse_empty_tuple_annotation(self):
11444        cu = torch.jit.CompilationUnit('''
11445            def foo(x : Tuple[()]) -> Tuple[()]:
11446                return x
11447        ''')
11448
11449        foo_code = cu.find_function('foo').code
11450        FileCheck().check("Tuple[()]").check("Tuple[()]").run(foo_code)
11451
11452    def test_parse_empty_tuple_annotation_element_error(self):
11453        with self.assertRaisesRegex(
11454                RuntimeError, 'Tuple literal in Tuple type annotation must not have any elements'):
11455            cu = torch.jit.CompilationUnit('''
11456                def foo(x : Tuple[(int,)]) -> Tuple[(int,)]:
11457                    return x
11458            ''')
11459
11460    def test_parse_none_type_annotation(self):
11461        cu = torch.jit.CompilationUnit('''
11462            def foo(x : NoneType) -> NoneType:
11463                return x
11464        ''')
11465
11466        foo_code = cu.find_function('foo').code
11467        FileCheck().check(": NoneType").check("-> NoneType").run(foo_code)
11468
11469    def test_empty_tuple_str(self):
11470        empty_tuple_type = torch._C.TupleType([])
11471        g = {'Tuple' : typing.Tuple}
11472        python_type = eval(empty_tuple_type.annotation_str, g)
11473        assert python_type is typing.Tuple[()]
11474
11475    def test_tuple_str(self):
11476        tuple1_type = torch._C.TupleType([torch._C.StringType.get()])
11477        self.assertEqual(tuple1_type.annotation_str, "Tuple[str]")
11478        tuple2_type = torch._C.TupleType([torch._C.StringType.get(), torch._C.StringType.get()])
11479        self.assertEqual(tuple2_type.annotation_str, "Tuple[str, str]")
11480
11481    def test_dict_str(self):
11482        dict_type = torch._C.DictType(torch._C.StringType.get(), torch._C.StringType.get())
11483        self.assertEqual(dict_type.annotation_str, "Dict[str, str]")
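        # A round-trip analogous to the tuple/None tests in this file would look
        # roughly like this (a minimal sketch, not executed here):
        #   g = {'Dict': typing.Dict, 'str': str}
        #   assert eval(dict_type.annotation_str, g) == typing.Dict[str, str]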
11484
11485    def test_none_type_str(self):
11486        none_type = torch._C.NoneType.get()
11487        g = {'NoneType' : type(None)}
11488        python_type = eval(none_type.annotation_str, g)
11489        assert python_type is type(None)
11490
11491    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
11492    def test_zip_enumerate_modulelist(self):
11493        class Sub(torch.nn.Module):
11494            def forward(self, thing):
11495                return thing - 2
11496
11497        class Double(torch.nn.Module):
11498            def forward(self, thing):
11499                return thing * 2
11500
11501        # zipping over two ModuleLists
11502        class ZipModLists(torch.nn.Module):
11503            def __init__(self, mods, mods2):
11504                super().__init__()
11505                self.mods = mods
11506                self.mods2 = mods2
11507
11508            def forward(self, x):
11509                iter = 0
11510                for mod1, mod2 in zip(self.mods, self.mods2):
11511                    x = mod2(mod1(x))
11512                    iter += 1
11513                return x, iter
11514
11515        class ZipWithValues(torch.nn.Module):
11516            __constants__ = ['tup_larger', 'tup_smaller']
11517
11518            def __init__(self, mods, mods2):
11519                super().__init__()
11520                self.mods = mods
11521                self.mods2 = mods2
11522                self.tup_larger = list(range(len(mods2) + 1))
11523                self.tup_smaller = list(range(max(len(mods2) - 1, 1)))
11524
11525            def forward(self, x):
11526                iter = 0
11527                x2 = x
11528                for val, mod1, mod2 in zip(self.tup_larger, self.mods, self.mods2):
11529                    x = mod2(mod1(x)) + val
11530                    iter += 1
11531                for val, mod1, mod2 in zip(self.tup_smaller, self.mods, self.mods2):
11532                    x2 = mod2(mod1(x2)) + val
11533                    iter += 1
11534                return x, iter
11535
11536        mods = nn.ModuleList([Double()]), nn.ModuleList([Double(), Sub(), Sub()]), nn.ModuleList([Sub(), Double()])
11537        for i in range(len(mods)):
11538            for j in range(len(mods)):
11539                mod = ZipModLists(mods[i], mods[j])
11540                self.checkModule(mod, (torch.tensor(.5),))
11541                mod2 = ZipWithValues(mods[i], mods[j])
11542                self.checkModule(mod2, (torch.tensor(.5),))
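                # zip over two ModuleLists unrolls to the length of the shorter
                # list; the `iter` counter returned from forward lets checkModule
                # verify that trip count for every pairing of lists above.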
11543
11544
11545    def test_enumerate_modlist_range(self):
11546        class Double(torch.nn.Module):
11547            def forward(self, thing):
11548                return thing * 2
11549
11550        class Mod(torch.nn.Module):
11551            def __init__(self) -> None:
11552                super().__init__()
11553                self.mods = nn.ModuleList([Double(), Double()])
11554
11555            def forward(self, x):
11556                x2 = x
11557                iter = 0
11558                for val, mod in enumerate(self.mods):
11559                    x2 = mod(x2) * val
11560                    iter += 1
11561                return iter, x, x2
11562
11563        self.checkModule(Mod(), (torch.tensor(.5),))
11564
11565        # variable length, modulelist
11566        class Mod2(Mod):
11567            def forward(self, x):
11568                for val, mod in zip(range(int(x)), self.mods):
11569                    x = mod(x) * val
11570                return x
11571
11572        with self.assertRaisesRegex(Exception, "that does not have a statically determinable length"):
11573            torch.jit.script(Mod2())
11574
11575        # modulelist, variable length
11576        class Mod3(Mod):
11577            def forward(self, x):
11578                for val, mod in zip(self.mods, range(int(x))):
11579                    x = mod(x) * val
11580                return x
11581
11582        with self.assertRaisesRegex(Exception, "that does not have a statically determinable length"):
11583            torch.jit.script(Mod3())
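        # Mod2 and Mod3 fail because iterating a ModuleList is unrolled at compile
        # time, so anything zipped with it must have a length known statically;
        # range(int(x)) depends on a runtime value, hence the error in either order.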
11584
11585    def test_for_in_enumerate(self):
11586        def fn(x):
11587            # type: (List[int]) -> int
11588            sum = 0
11589            for (i, v) in enumerate(x):
11590                sum += i * v
11591
11592            return sum
11593
11594        self.checkScript(fn, ([1, 2, 3, 4, 5],))
11595
11596        def fn_enumerate_start_arg(x):
11597            # type: (List[int]) -> int
11598            sum = 0
11599            for (i, v) in enumerate(x, 1):
11600                sum += i * v
11601
11602            return sum
11603
11604        self.checkScript(fn_enumerate_start_arg, ([1, 2, 3, 4, 5],))
11605
11606        def fn_enumerate_start_kwarg(x):
11607            # type: (List[int]) -> int
11608            sum = 0
11609            for (i, v) in enumerate(x, start=1):
11610                sum += i * v
11611
11612            return sum
11613
11614        self.checkScript(fn_enumerate_start_kwarg, ([1, 2, 3, 4, 5],))
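        # Worked example for the start argument: with x = [1, 2, 3, 4, 5] the
        # default enumerate gives 0*1 + 1*2 + 2*3 + 3*4 + 4*5 = 40, while
        # enumerate(x, start=1) gives 1*1 + 2*2 + 3*3 + 4*4 + 5*5 = 55.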
11615
11616        def fn_nested_enumerate(x):
11617            # type: (List[int]) -> int
11618            sum = 0
11619            for (i, (j, v)) in enumerate(enumerate(x)):
11620                sum += i * j * v
11621
11622            return sum
11623
11624        self.checkScript(fn_nested_enumerate, ([1, 2, 3, 4, 5],))
11625
11626        with self.assertRaisesRegex(RuntimeError, r'enumerate expected at least 1 arguments, got 0'):
11627            @torch.jit.script
11628            def enumerate_no_arg(x):
11629                # type: (List[int]) -> int
11630                sum = 0
11631                for _ in enumerate():
11632                    sum += 1
11633
11634                return sum
11635
11636        with self.assertRaisesRegex(RuntimeError, r'enumerate expected at most 2 arguments, got 3'):
11637            @torch.jit.script
11638            def enumerate_too_many_args(x):
11639                # type: (List[int]) -> int
11640                sum = 0
11641                for _ in enumerate(x, x, x):
11642                    sum += 1
11643
11644                return sum
11645
11646    def test_list_comprehension_modulelist(self):
11647        class Inner(torch.nn.Module):
11648            def forward(self, x):
11649                return x + 10
11650
11651        class M(torch.nn.Module):
11652            def __init__(self, mod_list):
11653                super().__init__()
11654                self.module_list = mod_list
11655
11656            def forward(self, x):
11657                out = torch.jit.annotate(List[Tensor], [mod(x) for mod in self.module_list])
11658                return out
11659
11660        mod = M(nn.ModuleList([Inner(), Inner()]))
11661        self.checkModule(mod, (torch.tensor(3),))
11662
11663        mod = M(nn.ModuleList([]))
11664        torch.jit.script(mod)
11665
11666        class M2(M):
11667            def __init__(self, mod_list):
11668                super().__init__(mod_list)
11669
11670            def forward(self, x):
11671                out = [mod(x) for mod in self.module_list]
11672                return out
11673
11674        mod = M2(nn.ModuleList([Inner(), Inner()]))
11675        self.checkModule(mod, (torch.tensor(3),))
11676
11677        mod = M2(nn.ModuleList([]))
11678        # defaults to List of Tensor for empty modulelist
11679        self.assertEqual(torch.jit.script(mod)(torch.tensor(.5)), [])
11680
11681        def bad_type_annotation():
11682            out = torch.jit.annotate(int, [x for x in [1, 2, 3]])  # noqa: C416
11683            return out
11684
11685        with self.assertRaisesRegex(Exception, "Expected an annotation"
11686                                    " of type List"):
11687            torch.jit.script(bad_type_annotation)
11688
11689    def test_list_comprehension_variable_write(self):
11690        # i in comprehension doesn't write to function scope
11691        def foo():
11692            i = 1
11693            x = [i if i != 5 else 3 for i in range(7)]  # noqa: C416
11694            return i, x
11695
11696        self.assertEqual(foo(), torch.jit.script(foo)())
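        # With Python 3 scoping, the comprehension above leaves `i` untouched, so
        # both the eager and scripted versions return (1, [0, 1, 2, 3, 4, 3, 6]).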
11697
11698    def test_for_in_zip(self):
11699        def fn(x, y):
11700            # type: (List[int], List[int]) -> int
11701            sum = 0
11702            for (i, j) in zip(x, y):
11703                sum += i * j
11704
11705            return sum
11706
11707        self.checkScript(fn, ([1, 2, 3, 4, 5], [2, 3, 4, 5, 6]))
11708
11709        def fn_multi_inputs(x, y, z):
11710            # type: (List[int], List[int], List[int]) -> int
11711            sum = 0
11712            for (i, j, k) in zip(x, y, z):
11713                sum += i * j * k
11714
11715            return sum
11716
11717        self.checkScript(fn_multi_inputs, ([1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6]))
11718
11719        def fn_nested_zip(x, y, z):
11720            # type: (List[int], List[int], List[int]) -> int
11721            sum = 0
11722            for (i, (j, k)) in zip(x, zip(y, z)):
11723                sum += i * j * k
11724
11725            return sum
11726
11727        self.checkScript(fn_nested_zip, ([1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6]))
11728
11729        with self.assertRaisesRegex(RuntimeError, r'zip expected at least 1 arguments, got 0'):
11730            @torch.jit.script
11731            def zip_no_arg(x):
11732                # type: (List[int]) -> int
11733                sum = 0
11734                for _ in zip():
11735                    sum += 1
11736
11737                return sum
11738
11739        with self.assertRaisesRegex(RuntimeError, r'too many values to unpack: need 2 but found 3'):
11740            @torch.jit.script
11741            def fn_nested_zip_wrong_target_assign(x, y, z):
11742                # type: (List[int], List[int], List[int]) -> int
11743                sum = 0
11744                for (i, (j, k)) in zip(x, y, z):
11745                    sum += i * j * k
11746
11747                return sum
11748
11749    def test_for_in_zip_enumerate(self):
11750        def fn_zip_enumerate(x, y):
11751            # type: (List[int], List[int]) -> int
11752            sum = 0
11753            for (i, (j, v), k) in zip(x, enumerate(y), range(0, 100)):
11754                sum += i * j * v * k
11755
11756            return sum
11757
11758        self.checkScript(fn_zip_enumerate, ([1, 2, 3, 4], [2, 3, 4, 5]))
11759
11760        def fn_enumerate_zip(x, y):
11761            # type: (List[int], List[int]) -> int
11762            sum = 0
11763            for (i, (j, v)) in enumerate(zip(x, y)):
11764                sum += i * j * v
11765
11766            return sum
11767
11768        self.checkScript(fn_enumerate_zip, ([1, 2, 3, 4], [2, 3, 4, 5]))
11769
11770    def test_for_in_tensors(self):
11771        def test_sizes(x):
11772            sumz = 0
11773            for s in x:
11774                sumz += 1
11775            return sumz
11776        self.checkScript(test_sizes, (torch.rand(5, 4, 3, 2, 1),))
11777        self.checkScript(test_sizes, (torch.rand(777),))
11778        self.checkScript(test_sizes, (torch.rand(0),))
11779
11780    def test_for_in_tensors_rank0(self):
11781        with self.assertRaisesRegex(RuntimeError, "of a 0-d tensor"):
11782            @torch.jit.script
11783            def test_sizes(x):
11784                sumz = 0
11785                for s in x:
11786                    sumz += 1
11787                return sumz
11788
11789            test_sizes(torch.tensor(1))
11790
11791    def test_for_in_tensors_fail_scalar(self):
11792        with self.assertRaisesRegex(RuntimeError, "'float' object is not iterable"):
11793            @torch.jit.script
11794            def test_sizes(x):
11795                # type: (float) -> int
11796                sumz = 0
11797                for s in x:
11798                    sumz += 1
11799                return sumz
11800
11801            test_sizes(0.0)
11802
11803    def test_for_in_tensors_nested(self):
11804        def test_sizes(x):
11805            sumz = 0
11806            for n in x:
11807                for t in n:
11808                    sumz += 1
11809            return sumz
11810
11811        self.checkScript(test_sizes, (torch.rand(5, 4, 3, 2, 1),))
11812
11813    # to avoid defining sum_list in multiple tests
11814    def get_sum_list_fn(self):
11815        def sum_list(a):
11816            # type: (List[int]) -> int
11817            sum = 0
11818            for i in a:
11819                sum += i
11820
11821            return sum
11822
11823        return sum_list
11824
11825    def test_sum_list_diff_elms(self):
11826        self.checkScript(self.get_sum_list_fn(), ([1, 2, 3, 4, 5],))
11827
11828    def test_sum_list_empty(self):
11829        self.checkScript(self.get_sum_list_fn(), ([],))
11830
11831    def test_sum_list_one(self):
11832        self.checkScript(self.get_sum_list_fn(), ([1],))
11833
11834    def test_sum_list_literal(self):
11835
11836        def sum_list():
11837            # type: () -> int
11838            sum = 0
11839            for i in [1, 2, 3, 4, 5]:
11840                sum += i
11841
11842            return sum
11843
11844        self.checkScript(sum_list, ())
11845
11846    def test_sum_list_wrong_type(self):
11847
11848        with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"):
11849            @torch.jit.script
11850            def sum_list(a):
11851                # type: (int) -> int
11852                sum = 0
11853                for i in a:  # noqa: T484
11854                    sum += i
11855
11856                return sum
11857
11858            sum_list(1)
11859
11860    def test_list_iterables(self):
11861        with self.assertRaisesRegex(RuntimeError, 'List of iterables is not supported currently'):
11862            cu = torch.jit.CompilationUnit('''
11863            def list_iterables(x):
11864                for i, j in [2, 3, 4], [5, 6, 7]:
11865                    x += i
11866                    x += j
11867                return x
11868            ''')
11869
11870    def test_for_in_string(self):
11871        def test_strings(x):
11872            # type: (str) -> str
11873            reverse = ""
11874            for c in x:
11875                reverse = c + reverse
11876            return reverse
11877
11878        self.checkScript(test_strings, ("hello",))
11879        self.checkScript(test_strings, ("",))
11880
11881        def test_list_strings(x):
11882            # type: (List[str]) -> str
11883            result = ""
11884            for sub_str in x:
11885                result += sub_str
11886            return result
11887
11888        self.checkScript(test_list_strings, (["hello", "world"],))
11889        self.checkScript(test_list_strings, (["hello", " ", "world", ""],))
11890
11891    def test_for_in_dict(self):
11892        def test_dicts(x):
11893            # type: (Dict[str, int]) -> int
11894            sum = 0
11895            for key in x:
11896                sum += x[key]
11897            return sum
11898
11899        self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},))
11900
11901        def test_dict_keys_values(x):
11902            # type: (Dict[str, int]) -> Tuple[str, int]
11903            key_str = ""
11904            sum = 0
11905            for key in x.keys():
11906                key_str += key
11907            for val in x.values():
11908                sum += val
11909            return key_str, sum
11910
11911        self.checkScript(test_dict_keys_values, ({"a": 1, "b": 2, "c": 3},))
11912
11913    def test_for_tuple_unpack(self):
11914        def for_tuple_unpack(x, y):
11915            for i, j in [[3, 4], [5, 6], [7, 8]]:
11916                x += i
11917                y += j
11918            return x, y
11919
11920        self.checkScript(for_tuple_unpack, (torch.tensor(3), torch.tensor(5)))
11921
11922        def nested_tuple_unpack(x, y):
11923            # type: (List[int], List[int]) -> int
11924            sum = 0
11925            for i, (j, k), v in zip(x, enumerate(x), y):
11926                sum += i + j + k + v
11927            return sum
11928
11929        self.checkScript(nested_tuple_unpack, ([1, 3, 5], [2, 4, 6]))
11930
11931    def test_for_tuple_assign(self):
11932        def test_simple_assign(x):
11933            # type: (Tuple[int, float]) -> float
11934            sum = 0.0
11935            for a in x:
11936                sum += float(a)
11937            return sum
11938
11939        self.checkScript(test_simple_assign, ((1, 2.5),))
11940
11941        def test_tuple_assign(x):
11942            # type: (Tuple[Tuple[int, int], Tuple[int, int]]) -> int
11943            sum = 0
11944            for a in x:
11945                sum += a[0]
11946                sum += a[1]
11947            return sum
11948
11949        self.checkScript(test_tuple_assign, (((1, 2), (4, 7)), ))
11950
11951    def test_single_starred_lhs(self):
11952        with self.assertRaisesRegex(RuntimeError, 'A Starred expression may only appear on the lhs within the presence'
11953                                                  ' of another non-starred expression'):
11954            cu = torch.jit.CompilationUnit('''
11955            def single_starred_lhs(x):
11956                a = (x, x, x)
11957                *b, = a
11958                return b
11959            ''')
11960
11961    def test_singleton_tuple_unpack(self):
11962        def foo(a):
11963            b, = (a,)
11964            return b + 1
11965        self.checkScript(foo, (torch.rand(3),))
11966
11967    def test_tuple_assignments(self):
11968        def var_tuple_assign(x, y):
11969            # type: (Tuple[Tensor, Tensor], Tensor) -> Tensor
11970            (a, b), c = x, y
11971            return a + b + c
11972
11973        tuple_inputs = (torch.randn(1, 4), torch.randn(3, 4))
11974        self.checkScript(var_tuple_assign, (tuple_inputs, torch.randn(3, 4)))
11975
11976        def nested_tuple_assign(x, y, z):
11977            # type: (int, Tuple[int, Tuple[int, int]], Tuple[int, int]) -> int
11978            a, (b, (c, d)), (e, f) = x, y, z
11979            return a + b + c + d + e + f
11980
11981        self.checkScript(nested_tuple_assign, (1, (2, (3, 4)), (5, 6)))
11982
11983        def subscript_tuple_assign(a, x, i):
11984            # type: (List[int], Tensor, int) -> Tuple[int, Tensor, int]
11985            a[i], (x[i], b) = 1, (2, 3)
11986            return a[i] + 1, x + 5, b
11987
11988        self.checkScript(subscript_tuple_assign, ([12, 7, 9, 11], torch.tensor((3, 13, 17)), 0))
11989
11990        def star_tuple_assign():
11991            # type: () -> Tuple[int, int, Tuple[int, int], Tuple[int, int]]
11992            a, (b, *c), *d = 1, (2, 3, 4), 5, 6
11993            return a, b, c, d
11994
11995        self.checkScript(star_tuple_assign, ())
11996
11997        def subscript_tuple_augmented_assign(a):
11998            # type: (Tuple[int, int]) -> Tuple[int, int]
11999            a[0] += 1
12000            return a
12001
12002        with self.assertRaisesRegex(RuntimeError, 'does not support augmented assign'):
12003            scripted_aug_assign = torch.jit.script(subscript_tuple_augmented_assign)
12004
12005        class AttrTupleAssignmentTestClass:
12006            def __init__(self, a: int, b: int):
12007                self.a = a
12008                self.b = b
12009
12010            def set_ab(self, a: int, b: int):
12011                self.a, self.b = (a, b)
12012
12013            def get(self) -> Tuple[int, int]:
12014                return (self.a, self.b)
12015
12016        make_global(AttrTupleAssignmentTestClass)
12017
12018        @torch.jit.script
12019        def attr_tuple_assignment(o: AttrTupleAssignmentTestClass, a: int, b: int):
12020            o.set_ab(a, b)
12021            return o
12022
12023        o = AttrTupleAssignmentTestClass(1, 2)
12024        self.assertEqual(attr_tuple_assignment(o, 3, 4).get(), (3, 4))
12025
12026    def test_multiple_assign(self):
12027        def test():
12028            a = b, c = d, f = (1, 1)
12029
12030            # side effect
12031            ten = torch.tensor(1)
12032            ten1 = ten2 = ten.add_(1)
12033
12034            # ordering
12035            x = 1
12036            y = 3
12037            x, y = y, x + y
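            # the RHS is evaluated in full before any assignment, so x, y become 3, 4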
12038
12039            return a, b, c, d, f, ten, ten1, ten2, x, y
12040
12041        self.checkScript(test, ())
12042
12043    def test_multi_reduction(self):
12044        with self.assertRaisesRegex(
12045                RuntimeError,
12046                'augmented assignment can only have one LHS expression'):
12047            cu = torch.jit.CompilationUnit('''
12048            def multi_reduction(x):
12049                a, b += x
12050                return a, b
12051            ''')
12052
12053    def test_invalid_call_arguments(self):
12054        with self.assertRaisesRegex(RuntimeError, 'but instead found type '):
12055            @torch.jit.script
12056            def invalid_call_arguments(x):
12057                return torch.unsqueeze(3, 4, 5, 6, 7, 8)
12058
12059    def test_invalid_lhs_assignment(self):
12060        with self.assertRaisesRegex(RuntimeError, 'unexpected expression'):
12061            cu = torch.jit.CompilationUnit('''
12062            def invalid_lhs_assignment(x):
12063                x + 1 = x
12064                return x
12065            ''')
12066
12067    def test_multi_starred_expr_lhs(self):
12068        with self.assertRaisesRegex(RuntimeError, 'Only one starred expression is allowed on the lhs'):
12069            cu = torch.jit.CompilationUnit('''
12070            def multi_starred_expr_lhs():
12071                a, *b, *c = [1, 2, 3, 4, 5, 6]
12072                return a
12073            ''')
12074
12075    def test_pack_tuple_into_non_var(self):
12076        with self.assertRaisesRegex(RuntimeError, 'Cannot pack a tuple into a non-variable'):
12077            cu = torch.jit.CompilationUnit('''
12078            def pack_tuple_into_non_var(x):
12079                a, *1 = (3, 4, 5)
12080                return x
12081            ''')
12082
12083    def test_print_kwargs(self):
12084        with self.assertRaisesRegex(RuntimeError, 'print doesn\'t accept any keyword arguments'):
12085            cu = torch.jit.CompilationUnit('''
12086            def print_kwargs(x):
12087                print(x, flush=True)
12088                return x
12089            ''')
12090
12091    def test_builtin_use_as_value(self):
12092        with self.assertRaisesRegex(RuntimeError, 'builtin cannot be used as a value'):
12093            @torch.jit.script
12094            def builtin_use_as_value(x):
12095                return x.unsqueeze
12096
12097    def test_wrong_use_as_tuple(self):
12098        with self.assertRaisesRegex(RuntimeError, 'cannot be used as a tuple'):
12099            def test_fn():
12100                return 3
12101
12102            @torch.jit.script
12103            def wrong_use_as_tuple(self):
12104                a, b = test_fn
12105                return a
12106
12107    def test_wrong_attr_lookup(self):
12108        with self.assertRaisesRegex(RuntimeError, 'attribute lookup is not defined on builtin'):
12109            @torch.jit.script
12110            def wrong_attr_lookup(self, x):
12111                a = x.unsqueeze.myattr
12112                return a
12113
12114    def test_wrong_use_as_callable(self):
12115        with self.assertRaisesRegex(RuntimeError, 'cannot call a value'):
12116            @torch.jit.script
12117            def wrong_use_as_callable(x):
12118                return x(3, 4, 5)
12119
12120    def test_python_val_doesnt_have_attr(self):
12121        with self.assertRaisesRegex(RuntimeError, 'object has no attribute abcd'):
12122
12123            @torch.jit.script
12124            def python_val_doesnt_have_attr():
12125                # this has to be a module otherwise attr lookup would not be
12126                # allowed in the first place
12127                return shutil.abcd
12128
12129    def test_wrong_module_attr_lookup(self):
12130        with self.assertRaisesRegex(RuntimeError, 'python value of type \'type\' cannot be used as a value'):
12131            import io
12132
12133            @torch.jit.script
12134            def wrong_module_attr_lookup():
12135                return io.BytesIO
12136
12137    def test_wrong_method_call_inputs(self):
12138        with self.assertRaisesRegex(RuntimeError, 'Argument y not provided'):
12139            class SomeModule(torch.jit.ScriptModule):
12140
12141                @torch.jit.script_method
12142                def foo(self, x, y):
12143                    return x
12144
12145                @torch.jit.script_method
12146                def forward(self, x, y):
12147                    return self.foo(x)
12148            SomeModule()
12149
12150    def test_single_starred_expr_for_loop(self):
12151        with self.assertRaisesRegex(RuntimeError, 'A Starred expression may only appear'):
12152            cu = torch.jit.CompilationUnit('''
12153            def test():
12154                x = 0
12155                for *a in [1, 2, 3]:
12156                    x = x + 1
12157                return x
12158            ''')
12159
12160    def test_call_ge(self):
12161        with self.assertRaisesRegex(RuntimeError, 'Expected at most 1 arguments but found 3'):
12162            @_trace(torch.zeros(1, 2, 3))
12163            def foo(x):
12164                return x
12165
12166            @torch.jit.script
12167            def test_fn():
12168                return foo(torch.full([1], 1), torch.full([1], 2), torch.full([1], 3))
12169
12170    def test_wrong_return_type(self):
12171        with self.assertRaisesRegex(RuntimeError, 'but instead got value of type tuple'):
12172            @torch.jit.ignore
12173            def somefunc():
12174                # type: () -> Tuple[Tuple[Tensor, Tensor]]
12175                return torch.zeros(3, 4), torch.zeros(4, 5)  # noqa: T484
12176
12177            @torch.jit.script
12178            def wrong_return_type():
12179                return somefunc()
12180            wrong_return_type()
12181
12182    # Tests for calling between different front-end modes
12183    def test_call_python_fn_from_tracing_fn(self):
12184        def python_fn(x):
12185            return torch.neg(x)
12186
12187        @_trace(torch.rand(3, 4))
12188        def traced_fn(x):
12189            return python_fn(x) + 1
12190
12191        # The neg op from the Python function should be properly inlined into the
12192        # traced graph
12193        FileCheck().check("aten::neg").run(str(traced_fn.graph))
12194
12195    def test_call_python_mod_from_tracing_fn(self):
12196        class PythonMod(torch.nn.Module):
12197            def __init__(self) -> None:
12198                super().__init__()
12199                self.param = torch.nn.Parameter(torch.rand(4, 3), requires_grad=False)
12200
12201            def forward(self, x):
12202                return torch.mm(x, self.param)
12203
12204        pm = PythonMod()
12205
12206        @_trace(torch.rand(3, 4))
12207        def traced_fn(x):
12208            return pm(x) + 1.0
12209
12210        # Note: the parameter self.param from the Python module is inlined
12211        # into the graph
12212        self.assertTrue(len(list(traced_fn.graph.inputs())) == 1)
12213        FileCheck().check("aten::mm").check("aten::add").run(str(traced_fn.graph))
12214
12215    @_tmp_donotuse_dont_inline_everything
12216    def test_call_traced_fn_from_tracing_fn(self):
12217        @_trace(torch.rand(3, 4))
12218        def traced_fn1(x):
12219            return torch.neg(x)
12220
12221        @_trace(torch.rand(3, 4))
12222        def traced_fn(x):
12223            return traced_fn1(x) + 1
12224
12225        FileCheck().check("traced_fn").check("prim::CallFunction").check("aten::add") \
12226            .run(str(traced_fn.graph))
12227
12228    @unittest.skip("error in first class mode")
12229    def test_call_traced_mod_from_tracing_fn(self):
12230        class TracedModule(torch.nn.Module):
12231            def __init__(self) -> None:
12232                super().__init__()
12233                self.param = torch.nn.Parameter(torch.rand(4, 3), requires_grad=False)
12234
12235            def forward(self, x):
12236                return torch.mm(x, self.param)
12237
12238        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
12239
12240        with self.assertRaisesRegex(RuntimeError, "must be registered as submodules"):
12241            @_trace(torch.rand(3, 4))
12242            def traced_fn(x):
12243                return tm(x) + 1.0
12244
12245    @_tmp_donotuse_dont_inline_everything
12246    def test_call_script_fn_from_tracing_fn(self):
12247        @torch.jit.script
12248        def script_fn(x):
12249            return torch.neg(x)
12250
12251        @_trace(torch.rand(3, 4))
12252        def traced_fn(x):
12253            return script_fn(x) + 1
12254
12255        FileCheck().check("prim::CallFunction").check("aten::add").run(str(traced_fn.graph))
12256
12257    @unittest.skip("error in first class mode")
12258    def test_call_script_mod_from_tracing_fn(self):
12259        with self.assertRaisesRegex(RuntimeError, "must be registered as submodules"):
12260            class ScriptMod(torch.jit.ScriptModule):
12261                def __init__(self) -> None:
12262                    super().__init__()
12263                    self.param = torch.nn.Parameter(torch.rand(3, 4), requires_grad=False)
12264
12265                @torch.jit.script_method
12266                def forward(self, x):
12267                    for _i in range(4):
12268                        x += self.param
12269                    return x
12270
12271            sm = ScriptMod()
12272
12273            @_trace(torch.rand(3, 4))
12274            def traced_fn(x):
12275                return sm(x) + 1.0
12276
12277
12278    def test_call_python_fn_from_traced_module(self):
12279        def python_fn(x):
12280            return torch.neg(x)
12281
12282        class TracedModule(torch.nn.Module):
12283            def __init__(self) -> None:
12284                super().__init__()
12285                self.param = torch.nn.Parameter(torch.rand(4, 3))
12286
12287            def forward(self, x):
12288                return torch.mm(python_fn(x), self.param)
12289
12290        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
12291
12292        # Note: parameter self.param from the traced module should appear as
12293        # an input to the graph and the neg op from the Python function should
12294        # be properly inlined
12295        self.assertTrue(len(list(tm.graph.inputs())) == 2)
12296        FileCheck().check("aten::neg").check("aten::mm").run(str(tm.graph))
12297
12298    def test_call_python_mod_from_traced_module(self):
12299        class PythonModule(torch.nn.Module):
12300            def __init__(self) -> None:
12301                super().__init__()
12302                self.param = torch.nn.Parameter(torch.rand(5, 7))
12303
12304            def forward(self, x):
12305                return torch.mm(x, self.param)
12306
12307        class TracedModule(torch.nn.Module):
12308            def __init__(self) -> None:
12309                super().__init__()
12310                self.param = torch.nn.Parameter(torch.rand(4, 5))
12311                self.mod = PythonModule()
12312
12313            def forward(self, x):
12314                return self.mod(torch.mm(x, self.param)) + 1.0
12315
12316        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
12317
12318        FileCheck().check_not("value=<Tensor>").check("aten::mm")\
12319            .check('prim::CallMethod[name="forward"]').check("aten::add") \
12320            .run(str(tm.graph))
12321        FileCheck().check("aten::mm").run(str(tm.mod.graph))
12322
12323    def test_op_dtype(self):
12324
12325        def check_equal_and_dtype(a, b):
12326            self.assertEqual(a, b)
12327            self.assertEqual(a.dtype, b.dtype)
12328
12329        def fn():
12330            a = torch.arange(10)
12331            b = torch.arange(10, dtype=torch.float)
12332            c = torch.arange(1, 10, 2)
12333            d = torch.arange(1, 10, 2, dtype=torch.float)
12334            e = torch.arange(1, 10., 2)
12335            f = torch.arange(1, 10., 2, dtype=torch.float)
12336            return a, b, c, d, e, f
12337
12338        scripted_fn = torch.jit.script(fn)
12339        eager_out = fn()
12340        script_out = scripted_fn()
12341        for a, b in zip(eager_out, script_out):
12342            check_equal_and_dtype(a, b)
12343
12344    def test_floor_div(self):
12345        @torch.jit.script
12346        def foo(a, b):
12347            # type: (int, int) -> int
12348            return a // b
12349        for i in range(-8, 8):
12350            for j in range(-8, 8):
12351                if j != 0:
12352                    self.assertEqual(foo(i, j), i // j)
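        # Like Python, TorchScript integer `//` floors toward negative infinity,
        # e.g. -7 // 2 == -4 and 7 // -2 == -4, which the exhaustive loop above
        # checks against the eager result.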
12353
12354    def test_floordiv(self):
12355        funcs_template = dedent('''
12356        def fn():
12357            ten = {a_construct}
12358            ten_or_scalar = {b_construct}
12359            return ten // ten_or_scalar, torch.floor_divide(ten, ten_or_scalar)
12360        ''')
12361
12362        lhs = ["torch.tensor([5.5, 3.2])", "torch.tensor([2, 2])", "torch.tensor([3, 2])"]
12363        rhs = ["1.5", "2", "4", "1.1"] + lhs
12364        for tensor in lhs:
12365            for tensor_or_scalar in rhs:
12366                funcs_str = funcs_template.format(a_construct=tensor, b_construct=tensor_or_scalar)
12367                scope = {}
12368                execWrapper(funcs_str, globals(), scope)
12369                cu = torch.jit.CompilationUnit(funcs_str)
12370                f_script = cu.fn
12371                f = scope['fn']
12372                self.assertEqual(f_script(), f())
12373
12374    def test_call_python_fn_from_script_fn(self):
12375        @torch.jit.ignore
12376        def python_fn(x):
12377            return torch.neg(x)
12378
12379        @torch.jit.script
12380        def script_fn(x):
12381            return python_fn(x) + 1
12382
12383        # Note: the call to python_fn appears as `^python_fn()` and is called
12384        # as a PythonOp in the interpreter
12385        a = torch.tensor(1)
12386        self.assertEqual(script_fn(a), torch.tensor(0))
12387        FileCheck().check("python_fn").run(str(script_fn.graph))
12388
12389    def test_call_python_mod_from_script_fn(self):
12390        class PythonModule(torch.nn.Module):
12391            def __init__(self) -> None:
12392                super().__init__()
12393                self.param = torch.nn.Parameter(torch.rand(5, 7))
12394
12395            def forward(self, x):
12396                return torch.mm(x, self.param)
12397
12398        pm = PythonModule()
12399
12400        @torch.jit.script
12401        def script_fn(x):
12402            return pm(x) + 1
12403
12404        # Note: the call to pm(x) appears as ^<python_value>() in the graph;
12405        # the module's parameters are NOT inlined.
12406        FileCheck().check("python_value").check("aten::add").run(str(script_fn.graph))
12407
12408    @_tmp_donotuse_dont_inline_everything
12409    def test_call_script_fn_from_script_fn(self):
12410        @torch.jit.script
12411        def script_fn1(x):
12412            return torch.neg(x)
12413
12414        @torch.jit.script
12415        def script_fn(x):
12416            return script_fn1(x) + 1
12417
12418        FileCheck().check("prim::CallFunction").run(str(script_fn.graph))
12419
12420    def test_call_script_mod_from_script_fn(self):
12421        with self.assertRaisesRegex(RuntimeError, "Cannot call a ScriptModule that is not a submodule of the caller"):
12422            class ScriptMod(torch.jit.ScriptModule):
12423                @torch.jit.script_method
12424                def forward(self, x):
12425                    return torch.mm(x, torch.zeros([4, 3]))
12426
12427            sm = ScriptMod()
12428
12429            @torch.jit.script
12430            def script_fn(x):
12431                return sm(x) + 1
12432
12433    def test_call_python_fn_from_script_module(self):
12434        @torch.jit.ignore
12435        def python_fn(x):
12436            return torch.neg(x)
12437
12438        class ScriptMod(torch.jit.ScriptModule):
12439            def __init__(self) -> None:
12440                super().__init__()
12441                self.param = torch.nn.Parameter(torch.rand(4, 3))
12442
12443            @torch.jit.script_method
12444            def forward(self, x):
12445                return python_fn(torch.mm(x, self.param))
12446
12447        sm = ScriptMod()
12448        FileCheck().check("aten::mm").check("python_fn") \
12449            .run(str(sm.forward.graph))
12450
12451    def test_call_python_mod_from_script_module(self):
12452        class PythonMod(torch.nn.Module):
12453            def __init__(self) -> None:
12454                super().__init__()
12455                self.param = torch.nn.Parameter(torch.rand(3, 5))
12456
12457            @torch.jit.ignore
12458            def forward(self, x):
12459                return torch.mm(x, self.param)
12460
12461        class ScriptMod(torch.jit.ScriptModule):
12462            def __init__(self) -> None:
12463                super().__init__()
12464                self.param = torch.nn.Parameter(torch.rand(4, 3))
12465                self.pm = PythonMod()
12466
12467            @torch.jit.script_method
12468            def forward(self, x):
12469                return self.pm(torch.mm(x, self.param))
12470
12471        sm = ScriptMod()
12472        # Note: the call into PythonMod appears as ^forward(). Parameters
12473        # are NOT inlined
12474        FileCheck().check("aten::mm").check("forward").run(str(sm.graph))
12475
12476    @_tmp_donotuse_dont_inline_everything
12477    def test_call_script_fn_from_script_module(self):
12478        @torch.jit.script
12479        def script_fn(x):
12480            return torch.neg(x)
12481
12482        class ScriptMod(torch.jit.ScriptModule):
12483            def __init__(self) -> None:
12484                super().__init__()
12485                self.param = torch.nn.Parameter(torch.rand(4, 3))
12486
12487            @torch.jit.script_method
12488            def forward(self, x):
12489                return script_fn(torch.mm(x, self.param))
12490
12491        sm = ScriptMod()
12492        graph = sm.forward.graph
12493        FileCheck().check("aten::mm").check("prim::CallFunction").run(str(graph))
12494
12495    @_tmp_donotuse_dont_inline_everything
12496    def test_call_script_mod_from_script_module(self):
12497        class ScriptMod1(torch.jit.ScriptModule):
12498            def __init__(self) -> None:
12499                super().__init__()
12500                self.param = torch.nn.Parameter(torch.rand(3, 5))
12501
12502            @torch.jit.script_method
12503            def forward(self, x):
12504                return torch.mm(x, self.param)
12505
12506        class ScriptMod(torch.jit.ScriptModule):
12507            def __init__(self) -> None:
12508                super().__init__()
12509                self.param = torch.nn.Parameter(torch.rand(4, 3))
12510                self.tm = ScriptMod1()
12511
12512            @torch.jit.script_method
12513            def forward(self, x):
12514                return self.tm(torch.mm(x, self.param))
12515
12516        sm = ScriptMod()
12517        # Note: since inlining is disabled here, the call into ScriptMod1 is kept
12518        # as a prim::CallMethod instead of being inlined, so the outer graph
12519        # should contain exactly one mm in its body.
12520        # 3 % values in graph input lists, one mm in body
12521        FileCheck().check_count('%', 3).check(":").check_count("mm", 1).check("prim::CallMethod").run(str(sm.graph))
12522
12523    def test_module_with_params_called_fails(self):
12524        with self.assertRaisesRegex(RuntimeError, "Cannot call a ScriptModule that is not a submodule of the caller"):
12525            class ScriptMod(torch.jit.ScriptModule):
12526                def __init__(self) -> None:
12527                    super().__init__()
12528                    self.param = torch.nn.Parameter(torch.rand(3, 3))
12529
12530                @torch.jit.script_method
12531                def forward(self, x):
12532                    return torch.mm(x, self.param)
12533
12534            sm = ScriptMod()
12535
12536            @torch.jit.script
12537            def some_func(x):
12538                return sm(x)
12539
12540    def test_tuple_index_to_list(self):
12541        def test_non_constant_input(a):
12542            # type: (bool) -> int
12543            if a:
12544                b = 1
12545            else:
12546                b = 0
12547            c = (0, 1)
12548            return c[b]
12549
12550        self.checkScript(test_non_constant_input, (True,))
12551        self.checkScript(test_non_constant_input, (False,))
12552
12553        with self.assertRaisesRegex(RuntimeError, "because we cannot resolve the output type"):
12554            @torch.jit.script
12555            def test_non_constant_input(a):
12556                # type: (bool) -> None
12557                if a:
12558                    b = 1
12559                else:
12560                    b = 0
12561                c = (0, 1.1)
12562                print(c[b])
12563
12564    def test_tuple_indexing(self):
12565        def tuple_index(a):
12566            if bool(a):
12567                b = (1, 2)
12568            else:
12569                b = (0, 2)
12570            return b[-2], b[1]
12571
12572        self.checkScript(tuple_index, (torch.tensor([0]),))
12573        self.checkScript(tuple_index, (torch.tensor([1]),))
12574        self.checkScript(tuple_index, (torch.tensor([1]),), optimize=True)
12575        tuple_comp = torch.jit.script(tuple_index)
12576        FileCheck().check_count("TupleIndex", 2, exactly=True).run(str(tuple_comp.graph))
12577
12578        with self.assertRaisesRegex(RuntimeError, "index must be an integer"):
12579            @torch.jit.script
12580            def test_indexing_float():
12581                c = (1, 2)
12582                return c[0.1]
12583
12584        def test_indexing_out_of_bounds_pos():
12585            c = (1, 2)
12586            return c[2]
12587
12588        self.checkScriptRaisesRegex(test_indexing_out_of_bounds_pos, (), Exception,
12589                                    "out of range")
12590
12591        def test_indexing_out_of_bounds_neg():
12592            c = (1, 2)
12593            return c[-3]
12594
12595        self.checkScriptRaisesRegex(test_indexing_out_of_bounds_neg, (), Exception,
12596                                    "out of range")
12597
12598        def negative_index():
12599            tup = (1, 2, 3, 4)
12600            return tup[-1]
12601
12602        self.checkScript(negative_index, [])
12603
12604        def really_negative_index():
12605            tup = (1, 2, 3, 4)
12606            return tup[-100]
12607
12608        self.checkScriptRaisesRegex(really_negative_index, [], Exception, "index out of range")
12609
12610        def negative_slice():
12611            tup = (1, 2, 3, 4)
12612            return tup[-3:4]
12613
12614        self.checkScript(negative_slice, [])
12615
12616        def really_slice_out_of_bounds():
12617            tup = (1, 2, 3, 4)
12618            return tup[-300:4000]
12619
12620        self.checkScript(really_slice_out_of_bounds, [])
12621
12622    def test_namedtuple_attr(self):
12623        def f(x):
12624            return x.max(dim=1).indices + torch.max(x, dim=1).indices
12625
12626        self.checkScript(f, (torch.rand(20, 20, 20),), optimize=True)
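        # torch.max(x, dim=1) returns a (values, indices) namedtuple, so the
        # `.indices` attribute access above must also resolve under scripting.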
12627
12628        with self.assertRaisesRegex(RuntimeError, "object has no attribute or method"):
12629            @torch.jit.script
12630            def g1(x):
12631                return x.max(dim=1).unknown_symbol
12632
12633        with self.assertRaisesRegex(RuntimeError, "object has no attribute or method"):
12634            @torch.jit.script
12635            def g2(x):
12636                print((x, x, x).__doc__)
12637                return x
12638
12639    def test_tuple_len(self):
12640        @torch.jit.script
12641        def foo():
12642            return len((1, "str", None))
12643
12644        self.assertEqual(foo(), 3)
12645
12646        @torch.jit.script
12647        def test_indexing_end_out_of_bounds():
12648            c = (1, 2)
12649            return c[2:10]
12650
12651        self.assertEqual(test_indexing_end_out_of_bounds(), ())
12652
12653    def test_lower_nested_tuples(self):
12654        @torch.jit.script
12655        def test():
12656            return ((1, 2), 3)
12657
12658        self.run_pass('constant_propagation', test.graph)
12659        FileCheck().check("prim::Constant").check_not("TupleConstruct").run(test.graph)
12660        # fails if a tuple can't be lowered
12661        self.run_pass('lower_all_tuples', test.graph)
12662
12663    def test_unwrap_optional_builtin(self):
12664        def test(x):
12665            # type: (Optional[int]) -> int
12666            x = torch.jit._unwrap_optional(x)
12667            x = x + x  # noqa: T484
12668            return x
12669
12670        self.checkScript(test, (3,))
12671
12672        with self.assertRaisesRegex(AssertionError, "Unwrapping null optional"):
12673            test(None)
12674
12675        test_script = torch.jit.script(test)
12676        with self.assertRaisesRegex(RuntimeError, "Unwrapping null optional"):
12677            test_script(None)
12678
12679        @torch.jit.script
12680        def test_test():
12681            return torch.jit._unwrap_optional(1)
12682
12683        with self.assertRaisesRegex(RuntimeError, r"could not be inferred from actual type None"):
12684            @torch.jit.script
12685            def test_no_type():
12686                # type: () -> int
12687                return torch.jit._unwrap_optional(None)
12688
12689    def test_indexing_error(self):
12690        with self.assertRaisesRegex(RuntimeError, "'int' object is not subscriptable"):
12691            @torch.jit.script
12692            def test_wrong_type():
12693                a = 8
12694                return a[0]
12695
12696    def test_unsupported_builtin_error(self):
12697        with self.assertRaisesRegex(RuntimeError,
12698                                    "Python builtin <built-in function hypot> is currently"):
12699            @torch.jit.script
12700            def test_unsupported(a):
12701                return math.hypot(a, 2.0)
12702
12703    def test_annotated_script_fn(self):
12704        @torch.jit.script
12705        def foo(x, y, z):
12706            # type: (Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tuple[Tensor, Tensor]]) -> Tensor
12707            return x
12708
12709        self.assertExpected(str(foo.schema))
12710
12711    def test_annotated_script_method(self):
12712        class SM(torch.jit.ScriptModule):
12713            @torch.jit.script_method
12714            def forward(self, x, y):
12715                # type: (Tuple[Tensor, Tensor], Tensor) -> Tuple[Tensor, Tensor, Tensor]
12716                return y, y, y
12717
12718        sm = SM()
12719
12720        self.assertExpectedStripMangled(str(sm.forward.schema))
12721
12722    def test_annotated_script_fn_return_mismatch(self):
12723        with self.assertRaisesRegex(RuntimeError, "but is actually of type"):
12724            @torch.jit.script
12725            def return_tup(x):
12726                # type: (Tensor) -> Tuple[Tuple[Tensor, Tensor], Tensor]
12727                return x, x  # noqa: T484
12728
12729    def test_annotated_script_fn_arg_mismatch(self):
12730        with self.assertRaisesRegex(RuntimeError, r"Arguments for call are not valid"):
12731            @torch.jit.script
12732            def tuple_arg(x):
12733                # type: (Tuple[Tensor, Tensor]) -> Tensor
12734                return x + 1  # noqa: T484
12735
12736    def test_script_non_tensor_args_outputs(self):
12737        @torch.jit.script
12738        def fn(x, y):
12739            # type: (Tensor, float) -> float
12740            return float((x + y).sum())
12741
12742        x = torch.ones(2, 2)
12743        z = fn(x, 1)
12744        self.assertIsInstance(z, float)
12745        self.assertEqual(z, 8.)
12746
12747    @unittest.skip('https://github.com/pytorch/pytorch/issues/9595')
12748    def test_inline_and_run_annotated_script_fn(self):
12749        @torch.jit.script
12750        def to_inline(x, y):
12751            # type: (Tuple[Tensor, Tensor], Tensor) -> Tensor
12752            return y
12753
12754        @torch.jit.script
12755        def some_func(x):
12756            return to_inline((x, x), x)
12757
12758        x = torch.rand(3, 4)
12759        self.assertEqual(some_func(x), x)
12760
12761    def _make_filereader_test_file(self):
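        # Writes 20 records of random size under keys "0".."19" plus a pickled
        # "meta" record listing those indices, then returns the file path together
        # with the raw buffers so tests can compare what is read back.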
12762        filename = tempfile.mktemp()
12763        writer = torch._C.PyTorchFileWriter(filename)
12764        buffers = [os.urandom(size) for size in [random.randint(1, 100) for i in range(20)]]
12765        offsets = []
12766        for i, buf in enumerate(buffers):
12767            writer.write_record(str(i), buf, len(buf))
12768            offsets.append(i)
12769        serialized_offsets = pickle.dumps(offsets)
12770        writer.write_record("meta", serialized_offsets, len(serialized_offsets))
12771        writer.write_end_of_file()
12772        return filename, buffers, serialized_offsets
12773
12774    def test_file_format_serialization(self):
12775        filename, buffers, serialized_offsets = self._make_filereader_test_file()
12776
12777        reader = torch._C.PyTorchFileReader(filename)
12778        serialized_offsets_read = reader.get_record("meta")
12779        parsed_serialized_offsets = pickle.loads(serialized_offsets)
12780
12781        for i, offset in enumerate(parsed_serialized_offsets):
12782            data = reader.get_record(str(offset))
12783            assert data == buffers[i]
12784
12785    def test_file_reader_no_memory_leak(self):
12786        num_iters = 10000
12787        filename, _, _ = self._make_filereader_test_file()
12788
12789        # Load from filename
12790        tracemalloc.start()
12791        for i in range(num_iters):
12792            torch._C.PyTorchFileReader(filename)
12793        _, peak_from_string = tracemalloc.get_traced_memory()
12794        tracemalloc.stop()
12795
12796        # Load from stream
12797        tracemalloc.start()
12798        with open(filename, 'rb') as f:
12799            for i in range(num_iters):
12800                f.seek(0)
12801                torch._C.PyTorchFileReader(f)
12802        _, peak_from_file = tracemalloc.get_traced_memory()
12803        tracemalloc.stop()
12804
12805        # Check that loading from a stream uses at most an empirically obtained factor
12806        # more peak memory than loading from a filename
12806        self.assertLess(peak_from_file, peak_from_string * 500)
12807
12808    # for each type, the input type annotation and corresponding return type annotation
12809    def type_input_return_pairs(self):
12810        return [
12811            ('Tensor', 'Tensor'),
12812            ('torch.Tensor', 'Tensor'),
12813            ('str', 'str'),
12814            ('int', 'int'),
12815            ('bool', 'bool'),
12816            ('BroadcastingList3[float]', 'List[float]'),
12817            ('BroadcastingList2[int]', 'List[int]'),
12818            ('List[int]', 'List[int]'),
12819            ('Optional[int]', 'Optional[int]'),
12820        ]
12821
12822    # fill the {input} / {output} placeholders of a code template with a type pair
12823    def format_code(self, code, pair):
12824        return code.format(input=pair[0], output=pair[1])
12825
12826    # ***** Type annotation tests *****
12827    # Test combinations of:
12828    # {String frontend, Python AST Frontend}
12829    # {Python 3-style type annotations, MyPy-style type comments}
12830    # {Script method, Script function}
12831
12832    #  String frontend , Python 3-style type annotations , Script function
12833    def test_annot_string_py3_fn(self):
12834        code = '''
12835            def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
12836                return x, x
12837        '''
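        # For example, the ('BroadcastingList2[int]', 'List[int]') pair from
        # type_input_return_pairs() expands the template to:
        #   def foo(x : BroadcastingList2[int], y : Tuple[Tensor, Tensor]) -> Tuple[List[int], List[int]]:
        #       return x, x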
12838        test_str = []
12839        for pair in self.type_input_return_pairs():
12840            cu = torch.jit.CompilationUnit(self.format_code(code, pair))
12841            test_str.append(str(cu.foo.schema))
12842        self.assertExpected("\n".join(test_str) + "\n")
12843
12844    # String frontend, Python 3-style type annotations, Script method
12845    def test_annot_string_py3_method(self):
12846        class TestModule(torch.jit.ScriptModule):
12847            def __init__(self) -> None:
12848                super().__init__()
12849
12850        code = '''
12851            def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
12852                return x, x
12853        '''
12854        test_str = []
12855        for pair in self.type_input_return_pairs():
12856            # clear the class registry as we will be defining foo multiple times
12857            jit_utils.clear_class_registry()
12858            tm = TestModule()
12859            tm.define(self.format_code(code, pair))
12860            test_str.append(str(tm.foo.schema))
12861        self.assertExpectedStripMangled("\n".join(test_str) + "\n")
12862
12863    # String frontend, MyPy-style type comments, Script function
12864    def test_annot_string_mypy_fn(self):
12865        code = '''
12866            def foo(x, y):
12867                # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
12868                return x, x
12869        '''
12870        test_str = []
12871        for pair in self.type_input_return_pairs():
12872            cu = torch.jit.CompilationUnit(self.format_code(code, pair))
12873            test_str.append(str(cu.foo.schema))
12874        self.assertExpectedStripMangled("\n".join(test_str) + "\n")
12875
12876    # String frontend, MyPy-style type comments, Script method
12877    def test_annot_string_mypy_method(self):
12878        class TestModule(torch.jit.ScriptModule):
12879            def __init__(self) -> None:
12880                super().__init__()
12881
12882        code = '''
12883        def foo(self, x, y):
12884            # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
12885            return x, x
12886        '''
12887
12888        test_str = []
12889        for pair in self.type_input_return_pairs():
12890            # clear the class registry as we will be defining foo multiple times
12891            jit_utils.clear_class_registry()
12892            tm = TestModule()
12893            tm.define(self.format_code(code, pair))
12894            test_str.append(str(tm.foo.schema))
12895        self.assertExpectedStripMangled("\n".join(test_str) + "\n")
12896
12897    # Python AST Frontend, Python 3-style type annotations, Script function
12898    def test_annot_ast_py3_fn(self):
12899        code = dedent('''
12900            from typing import Tuple, List, Optional
12901            from torch import Tensor
12902            from torch.jit.annotations import BroadcastingList2, BroadcastingList3
12903            import torch
12904            @torch.jit.script
12905            def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
12906                return x, x
12907        ''')
12908        test_str = []
12909        for pair in self.type_input_return_pairs():
12910            fn = jit_utils._get_py3_code(self.format_code(code, pair), 'foo')
12911            test_str.append(str(fn.schema))
12912        self.assertExpectedStripMangled("\n".join(test_str) + "\n")
12913
12914    def test_multiline_annot_ast_py3_fn(self):
12915        code = dedent('''
12916            from typing import Tuple, List, Optional
12917            from torch import Tensor
12918            from torch.jit.annotations import BroadcastingList2, BroadcastingList3
12919            import torch
12920            @torch.jit.script
12921            def foo(x,  # type: {input}
12922                    y   # type: Tuple[Tensor, Tensor]
12923                    ):
12924                # type: (...) -> Tuple[{output}, {output}]
12925                return x, x
12926        ''')
12927        test_str = []
12928
12929        for pair in self.type_input_return_pairs():
12930            fn = jit_utils._get_py3_code(self.format_code(code, pair), 'foo')
12931            args = fn.schema.arguments
12932            returns = fn.schema.returns
12933            self.assertEqual(str(args[0].type), pair[1])
12934            self.assertEqual(str(args[1].type), "Tuple[Tensor, Tensor]")
12935            self.assertEqual(str(returns[0].type), f"Tuple[{pair[1]}, {pair[1]}]")
12936
12937    def test_bad_multiline_annotations(self):
12938        with self.assertRaisesRegex(RuntimeError, "Return type line"):
12939            @torch.jit.script
12940            def bad_type_line(a,  # type: Tensor
12941                              b,  # type: Tensor
12942                              c   # type: Tensor
12943                              ):
12944                # type: (int, int, int) -> Tensor
12945                # type: bad type line  # noqa: F723
12946
12947                return a + b + c
12948
12949        with self.assertRaisesRegex(RuntimeError, "Return type line"):
12950            @torch.jit.script
12951            def bad_return_line(a,  # type: Tensor
12952                                b,
12953                                c   # type: Tensor
12954                                ):
12955                # type: (int, int, int) -> Tensor
12956                return a + b + c
12957
12958        # TODO: this should be supported but is difficult to parse
12959        with self.assertRaisesRegex(RuntimeError, "Number of type annotations"):
12960            @torch.jit.script
12961            def missing_type(a,  # type: Tensor
12962                             b,
12963                             c   # type: Tensor
12964                             ):
12965                # type: (...) -> Tensor
12966                return a + b + c
12967
12968    # Python AST Frontend, Python 3-style type annotations, Script method
12969    def test_annot_ast_py3_method(self):
12970        code = dedent('''
12971            from typing import Tuple, List, Optional
12972            from torch import Tensor
12973            from torch.jit.annotations import BroadcastingList2, \\
12974                BroadcastingList3
12975            import torch
12976            class FooModule(torch.jit.ScriptModule):
12977                @torch.jit.script_method
12978                def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
12979                    return x, x
12980            instance = FooModule()
12981        ''')
12982
12983        test_str = []
12984        for pair in self.type_input_return_pairs():
12985            fn = jit_utils._get_py3_code(self.format_code(code, pair), 'instance')
12986            test_str.append(str(fn.foo.schema))
12987        self.assertExpectedStripMangled("\n".join(test_str) + "\n")
12988
12989    # Python AST Frontend, MyPy-style type comments, Script function
12990    def test_annot_ast_mypy_fn(self):
12991        code = dedent('''
12992            import torch
12993            @torch.jit.script
12994            def foo(x, y):
12995                # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
12996                return x, x
12997        ''')
12998
12999        test_str = []
13000        for pair in self.type_input_return_pairs():
13001            fn = jit_utils._get_py3_code(self.format_code(code, pair), 'foo')
13002            test_str.append(str(fn.schema))
13003        self.assertExpected("\n".join(test_str) + "\n")
13004
13005    # Python AST Frontend, MyPy-style type comments, Script method
13006    def test_annot_ast_mypy_method(self):
13007        code = dedent('''
13008            import torch
13009            class FooModule(torch.jit.ScriptModule):
13010                @torch.jit.script_method
13011                def foo(self, x, y):
13012                    # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
13013                    return x, x
13014            instance = FooModule()
13015        ''')
13016
13017        test_str = []
13018        for pair in self.type_input_return_pairs():
13019            fn = jit_utils._get_py3_code(self.format_code(code, pair), 'instance')
13020            test_str.append(str(fn.foo.schema))
13021        self.assertExpectedStripMangled("\n".join(test_str) + "\n")
13022
13023    # Tests that "# type: ignore[*]" is supported in type lines and is
13024    # properly ignored.
13025    def test_mypy_type_ignore(self):
13026        @torch.jit.script
13027        def foo(x):  # type: ignore
13028            return x
13029
13030        @torch.jit.script
13031        def bar(x):  # type: ignore[no-redef]
13032            return x
13033
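    # Tensor cast methods (x.byte(), x.float(), ...) compiled through the string
    # frontend should match the corresponding eager calls.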
13034    def test_method_casts_script(self):
13035        cast_types = [
13036            'byte', 'char', 'double', 'float', 'int', 'long', 'short'
13037        ]
13038
13039        for cast_type in cast_types:
13040            cu = torch.jit.CompilationUnit(f'''
13041            def cast_to(x):
13042                return x.{cast_type}()
13043            ''')
13044
13045            x = torch.rand(3, 4, 5) * 128
13046            cu_result = cu.cast_to(x)
13047            reference = getattr(x, cast_type)()
13048            self.assertEqual(cu_result, reference)
13049
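    # An if/elif/else chain compiled through the string frontend, checked against eager.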
13050    def test_string_frontend_elif(self):
13051        code = '''
13052            def func(niter):
13053                # type: (int)
13054                rv = 0
13055                for i in range(niter):
13056                    if i % 3 == 0 and i % 5 == 0:
13057                        rv += 35
13058                    elif i % 3 == 0:
13059                        rv += 3
13060                    elif i % 5 == 0:
13061                        rv += 5
13062                    else:
13063                        rv += i
13064                return rv
13065        '''
13066
13067        self.checkScript(dedent(code), (101,))
13068
13069    def test_module_parameters_and_buffers(self):
13070        weights = torch.randn(10, 10)
13071        bias = torch.randn(10)
13072        weights2 = torch.randn(10, 10)
13073        bias2 = torch.randn(10)
13074
13075        class TestLinear(torch.nn.Module):
13076            def __init__(self, in_features, out_features):
13077                super().__init__()
13078                self.in_features = in_features
13079                self.out_features = out_features
13080                self.weight = torch.nn.Parameter(torch.empty(out_features, in_features))
13081                self.bias = torch.nn.Parameter(torch.empty(out_features))
13082                self.counter = nn.Buffer(torch.ones(out_features))
13083                self.reset_parameters()
13084
13085            def reset_parameters(self):
13086                torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
13087                if self.bias is not None:
13088                    fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
13089                    bound = 1 / math.sqrt(fan_in)
13090                    torch.nn.init.uniform_(self.bias, -bound, bound)
13091
13092            def forward(self, input):
13093                return F.linear(input, self.weight, self.bias) + self.counter
13094
13095        # Initialize a ScriptModule that uses the weak module above multiple times
13096        class Strong(torch.jit.ScriptModule):
13097            def __init__(self) -> None:
13098                super().__init__()
13099                self.fc1 = TestLinear(10, 10)
13100                self.fc1.weight = torch.nn.Parameter(weights)
13101                self.fc1.bias = torch.nn.Parameter(bias)
13102                self.fc2 = TestLinear(10, 10)
13103                self.fc2.weight = torch.nn.Parameter(weights2)
13104                self.fc2.bias = torch.nn.Parameter(bias2)
13105
13106            @torch.jit.script_method
13107            def forward(self, x):
13108                return x + self.fc1(x) + self.fc1(x) + self.fc2(x)
13109
13110        strong_mod = Strong()
13111
13112        # Run the same calculation with eager modules to get the expected result
13113        inp = torch.ones(10)
13114        lin = torch.nn.Linear(10, 10)
13115        lin.weight = torch.nn.Parameter(weights)
13116        lin.bias = torch.nn.Parameter(bias)
13117        lin2 = torch.nn.Linear(10, 10)
13118        lin2.weight = torch.nn.Parameter(weights2)
13119        lin2.bias = torch.nn.Parameter(bias2)
13120        expected_result = inp + (lin(inp) + torch.ones(10)) * 2 + lin2(inp) + torch.ones(10)
13121
13122        self.assertEqual(strong_mod(inp), expected_result)
13123        self.assertExportImportModule(strong_mod, (inp,))
13124
13125    def test_module_copying(self):
13126        class Submodule(torch.nn.Module):
13127            def forward(self, x):
13128                return x + 100
13129
13130        class Weak(torch.nn.Module):
13131            def __init__(self, in_features, out_features):
13132                super().__init__()
13133                self.weight = torch.nn.Parameter(torch.ones(out_features, in_features))
13134                self.bias = torch.nn.Parameter(torch.ones(out_features))
13135                self.buffer = nn.Buffer(torch.ones(out_features))
13136                self.submodule = Submodule()
13137
13138            def forward(self, x):
13139                return F.linear(x, self.weight, self.bias) \
13140                    + self.buffer + self.submodule(x)
13141
13142        class Strong(torch.jit.ScriptModule):
13143            def __init__(self, weak):
13144                super().__init__()
13145                self.weak = weak
13146
13147            @torch.jit.script_method
13148            def forward(self, x):
13149                return self.weak(x)
13150
13151        inp = torch.ones(5, 5) * 5
13152        weak_mod = Weak(5, 5)
13153        strong_mod = Strong(weak_mod)
13154
13155        self.assertTrue(isinstance(strong_mod.weak, torch.jit.ScriptModule))
13156        self.assertFalse(isinstance(weak_mod, torch.jit.ScriptModule))
13157
13158        self.assertIs(strong_mod.weak.weight, weak_mod.weight)
13159        self.assertIs(strong_mod.weak.buffer, weak_mod.buffer)
13160        # strong_mod.weak.submodule has been recursively scripted
13161        self.assertIsNot(strong_mod.weak.submodule, weak_mod.submodule)
13162
13163        weak_mod.weight.data += torch.ones(5, 5) * 100
13164        self.assertTrue(strong_mod(inp).allclose(weak_mod(inp)))
13165
13166        # Re-assignment is not tracked
13167        weak_mod.weight = torch.nn.Parameter(torch.ones(5, 5) * 100)
13168        self.assertFalse(strong_mod(inp).allclose(weak_mod(inp)))
13169
13170    def test_backend_cudnn_enabled(self):
13171        # Only test that this compiles
13172        @torch.jit.script
13173        def fn(x):
13174            if torch.backends.cudnn.enabled:
13175                x = x + 2
13176            else:
13177                x = x + 3
13178            return x
13179
13180    def test_inplace_add(self):
13181
13182        def foo(a, b):
13183            c = a + b
13184            c.add_(b)
13185            return c
13186        self.checkScript(foo, (torch.rand(3), torch.rand(3)))
13187
13188    def test_add_out(self):
13189        def foo(a, b):
13190            c = a + b
13191            e = 2 * a
13192            torch.add(c, b, out=e)
13193            return e
13194        self.checkScript(foo, (torch.rand(3), torch.rand(3)))
13195
13196    def test_tuple_error_msg(self):
13197        def fn(t: Any):
13198            if isinstance(t, tuple):
13199                a, b = t
13200            return a + b
13201        with self.assertRaisesRegexWithHighlight(RuntimeError, "Provided tuple is not fully defined/refined", "t"):
13202            torch.jit.script(fn)
13203
13204    def test_augmented_assign(self):
13205        def foo(a, b):
13206            a += b
13207            a -= b
13208            a /= b
13209            a *= b
13210            return a, b
13211        self.checkScript(foo, (torch.rand(3), torch.rand(3)))
13212
13213    def test_ignored_props(self):
13214        class A(nn.Module):
13215            __jit_ignored_attributes__ = ["ignored", "ignored_return_val"]
13216
13217            @property
13218            def ignored(self):
13219                raise ValueError("shouldn't be called")
13220
13221            @property
13222            def ignored_return_val(self):
13223                return 1
13224
13225            @torch.jit.ignore
13226            def call(self):
13227                return self.ignored_return_val
13228
13229        f = torch.jit.script(A())
13230        # quick-and-dirty way to check that scripting raised no error
13231        self.assertTrue(isinstance(f, torch.jit.ScriptModule))
13232        self.assertTrue(isinstance(f.call(), property))
13233
13235    def test_pass(self):
13236        def foo(x):
13237            # type: (bool) -> int
13238            for _i in range(3):
13239                pass
13240            if x:
13241                pass
13242            else:
13243                pass
13244            return 3
13245
13246        self.checkScript(foo, (True,))
13247
13248    def test_lhs_indexing(self):
13249        def foo(a, b):
13250            a = a.clone()
13251            a[0] = b
13252            return a
13253        self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
13254
13255    def test_lhs_advanced_indexing_assignment(self):
13256        def foo(x, y):
13257            a = torch.exp(x)
13258            b = x == 1
13259            a[b] = y[b]
13260            return a
13261        self.checkScript(foo, (torch.ones(4, 3), torch.ones(4, 3)))
13262
13263    def test_lhs_advanced_indexing_augmented_assignment(self):
13264        def foo(x, y):
13265            a = torch.exp(x)
13266            b = x == 1
13267            a[b] += y[b]
13268            return a
13269        self.checkScript(foo, (torch.ones(4, 3), torch.ones(4, 3)))
13270
13271    def test_lhs_indexing_list(self):
13272        def foo(a, b):
13273            ls = [a]
13274            ls[0] = b
13275            return ls
13276        self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
13277
13278    def test_inplace_copy_script(self):
13279        def foo(x):
13280            a = torch.rand(3, 4)
13281            a.copy_(x)
13282            return a
13283        self.checkScript(foo, (torch.rand(3, 4),))
13284
13285    def test_lhs_indexing_increment(self):
13286        def foo(a, b):
13287            a[0] += b
13288            return a
13289        self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
13290
13291    def test_lhs_indexing_increment_list(self):
13292        def foo(a, b):
13293            a = a.clone()
13294            ls = [a, b]
13295            ls[0] += b
13296            return ls
13297        self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
13298
13299    def test_lhs_indexing_increment_list_prim(self):
13300        def foo():
13301            ls = [1, 2, 3]
13302            ls[0] += 5
13303            return ls
13304        self.checkScript(foo, ())
13305
13306    def test_lhs_indexing_multi(self):
13307        def foo(a, b):
13308            a = a.clone()
13309            foo, a[0], bar = (1, b, 3)
13310            return foo, a, bar
13311        self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
13312
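    # F.max_pool1d is overloaded on the boolean return_indices argument; exercise both
    # keyword and positional forms with constant True/False values.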
13313    def test_bool_dispatch(self):
13314        with torch._jit_internal._disable_emit_hooks():  # TODO: Python print broadcasting list
13315            def kwarg_false(x):
13316                # type: (Tensor) -> Tensor
13317                return F.max_pool1d(x, 1, 1, return_indices=False)
13318            self.checkScript(kwarg_false, (torch.randn(3, 3, 3),))
13319
13320            def kwarg_true(x):
13321                # type: (Tensor) -> Tuple[Tensor, Tensor]
13322                return F.max_pool1d(x, 1, 1, return_indices=True)
13323            self.checkScript(kwarg_true, (torch.randn(3, 3, 3),))
13324
13325            def full_kwarg_false(x):
13326                # type: (Tensor) -> Tensor
13327                return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=False)
13328            self.checkScript(full_kwarg_false, (torch.randn(3, 3, 3),))
13329
13330            def full_kwarg_true(x):
13331                # type: (Tensor) -> Tuple[Tensor, Tensor]
13332                return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=True)
13333            self.checkScript(full_kwarg_true, (torch.randn(3, 3, 3),))
13334
13335            def use_default(x):
13336                # type: (Tensor) -> Tensor
13337                return F.max_pool1d(x, 1, 1)
13338            self.checkScript(use_default, (torch.randn(3, 3, 3),))
13339
13340            def arg_false(x):
13341                # type: (Tensor) -> Tensor
13342                return F.max_pool1d(x, 1, 1, 0, 1, False, False)
13343            self.checkScript(arg_false, (torch.randn(3, 3, 3),))
13344
13345            def arg_true(x):
13346                # type: (Tensor) -> Tuple[Tensor, Tensor]
13347                return F.max_pool1d(x, 1, 1, 0, 1, False, True)
13348            self.checkScript(arg_true, (torch.randn(3, 3, 3),))
13349
13350    def test_infer_size(self):
13351        from torch._C import _infer_size
13352
13353        def fn(x, y):
13354            # type: (Tensor, Tensor) -> List[int]
13355            return _infer_size(x.size(), y.size())
13356
13357        self.checkScript(fn, (torch.ones(2, 4, 2), torch.ones(2, 4, 2)))
13358
13359    def test_hash(self):
13360        def tester(fn, inputs):
13361            for x in inputs:
13362                for y in inputs:
13363                    if x == y:
13364                        self.assertEqual(fn(x), fn(y))
13365                    else:
13366                        self.assertNotEqual(fn(x), fn(y))
13367
13368        @torch.jit.script
13369        def int_hash(x):
13370            # type: (int) -> int
13371            return hash(x)
13372
13373        @torch.jit.script
13374        def float_hash(x):
13375            # type: (float) -> int
13376            return hash(x)
13377
13378        @torch.jit.script
13379        def str_hash(x):
13380            # type: (str) -> int
13381            return hash(x)
13382
13383        tester(int_hash, (20, 21, 22))
13384        tester(float_hash, (20.0, 21.00001, 22.443))
13385        tester(str_hash, ("", "hello", "a"))
13386
13387    def test_id(self):
13388        with self.assertRaisesRegex(RuntimeError, "Expected a value"):
13389            @torch.jit.script
13390            def test_id_scalars():
13391                return id(2) == id(None)
13392
13393        @torch.jit.script
13394        class FooTest:
13395            def __init__(self, x):
13396                self.foo = x
13397
13398            def getFooTest(self):
13399                return self.foo
13400
13401        @torch.jit.script
13402        def test_id_class_types():
13403            obj1 = FooTest(torch.tensor(3))
13404            obj2 = FooTest(torch.tensor(2))
13405            assert obj1 is not obj2
13406            assert id(obj1) != id(obj2)
13407            assert id(obj1) != id(None)
13408            return True
13409
13410        self.assertTrue(test_id_class_types())
13411
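    # Mutation-aware dead code elimination: values that are mutated, aliased, or may
    # alias a wildcard must not be eliminated even when they do not reach the output.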
13412    def test_mutable_dce(self):
13413        @torch.jit.script
13414        def foo():
13415            a = torch.rand(2, 3)
13416            a += torch.rand(2, 3)
13417            b = torch.rand(2, 3)
13418            b += torch.rand(2, 3)
13419            # b should be cleaned up but not a
13420            return a
13421
13422        FileCheck().check_count("aten::rand", 2, exactly=True) \
13423            .check_count("aten::add", 1, exactly=True).run(str(foo.graph))
13424
13425    def test_mutable_dce_block(self):
13426        @torch.jit.script
13427        def foo():
13428            a = torch.rand(2, 3)
13429            a += torch.rand(2, 3)
13430            b = torch.rand(2, 3)
13431            if bool(a > torch.zeros(2, 3)):
13432                b += torch.rand(2, 3)
13433                a += torch.rand(2, 3)
13434            # a should be cleaned up but not b
13435            return b
13436
13437        FileCheck().check("prim::If").check_count("aten::rand", 1, exactly=True) \
13438            .run(str(foo.graph))
13439
13440    def test_mutable_dce_graph_input(self):
13441        @torch.jit.script
13442        def foo(a):
13443            a += torch.rand(2, 3)
13444            # shouldn't clean up `a` even though it's not used in the output
13445
13446        FileCheck().check("aten::rand").check("aten::add").run(str(foo.graph))
13447
13448    def test_mutable_dce_list(self):
13449        @torch.jit.script
13450        def foo(a):
13451            l = []
13452            l.append(a)
13453            c = l[0]
13454            b = torch.rand(2, 3)
13455            c += torch.rand(2, 3)
13456            return b
13457
13458        # c does not get cleaned up because there is a wildcard + mutation
13459        FileCheck().check_count("aten::rand", 2, exactly=True).run(str(foo.graph))
13460
13461    def test_mutable_dce_loop(self):
13462        @torch.jit.script
13463        def foo(a):
13464            l = []
13465            l.append(a)
13466            i = 0
13467            b = torch.rand(2, 3)
13468            while i < 1:
13469                dead = torch.rand(2, 3)
13470                c = l[0]
13471                c += torch.rand(2, 3)
13472                i += 1
13473            return b
13474
13475        FileCheck().check("prim::Loop").check_not("aten::rand").check("aten::__getitem__") \
13476            .check_count("aten::rand", 1, exactly=True).run(str(foo.graph))
13477
13478    def test_mutable_dce_indirect_wildcards(self):
13479        def fn():
13480            x = torch.ones(2, 3)
13481            x_1 = x.view(-1)
13482            l = []
13483            l.append(x_1)
13484            x_view = l[0]
13485            x.add_(torch.ones(2, 3))
13486            return x_view
13487        self.checkScript(fn, ())
13488
13489    def test_mutable_dce_indirect_wildcard_write(self):
13490        def fn():
13491            indexes = torch.jit.annotate(List[Tensor], [])
13492            word_ids = torch.zeros(10, dtype=torch.int32)
13493            word_ids[1] = 1
13494            indexes.append(word_ids)
13495
13496            return word_ids
13497        self.checkScript(fn, ())
13498
13499    def test_mutable_dce_wildcards(self):
13500        def fn():
13501            x = torch.ones(2, 3)
13502            l = []
13503            l.append(x)
13504            x_view = l[0]
13505            x.add_(torch.ones(2, 3))
13506            return x_view
13507
13508        self.checkScript(fn, (), profiling=ProfilingMode.SIMPLE)
13509
13510    def test_cpp_function_tensor_str(self):
13511        x = torch.randn(2, 2)
13512        scale = torch.randn(2, 2, requires_grad=True)
13513        shift = torch.randn(2, 2, requires_grad=True)
13514
13515        @torch.jit.script
13516        def fn(x, scale, shift):
13517            return scale * x + shift
13518
13519        with self.capture_stdout() as captured:
13520            print(fn(x, scale, shift))
13521
13522    def test_string_index(self):
13523        def fn(x):
13524            # type: (str)
13525            return x[2], x[-1]
13526
13527        self.checkScript(fn, ("abcde",))
13528
13529    def test_ord(self):
13530        def fn(x):
13531            # type: (str) -> int
13532            return ord(x)
13533
13534        self.checkScript(fn, ("h",))
13535        self.checkScript(fn, ("y",))
13536
13537        def index_str_to_tensor(s):
13538            # type: (str) -> Tensor
13539            return torch.tensor(ord(s))  # noqa: T484
13540
13541        s = '\u00a3'.encode()[:1]
13542        self.checkScript(index_str_to_tensor, (s,))
13543
13544    def test_chr(self):
13545        def fn(x):
13546            # type: (int) -> str
13547            return chr(x)
13548
13549        self.checkScript(fn, (1,))
13550        self.checkScript(fn, (97,))
13551
13552    def test_round(self):
13553        def round_float(x):
13554            # type: (float) -> float
13555            return round(x)
13556
13557        def round_int(x):
13558            # type: (int) -> float
13559            return round(x)
13560
13561        self.checkScript(round_float, (1.5,))
13562        self.checkScript(round_int, (2,))
13563
13564    def test_convert_base(self):
13565        def test_hex(x):
13566            # type: (int) -> str
13567            return hex(x)
13568
13569        def test_oct(x):
13570            # type: (int) -> str
13571            return oct(x)
13572
13573        def test_bin(x):
13574            # type: (int) -> str
13575            return bin(x)
13576
13577        numbers = [-1000, -10, 0, 1, 10, 2343]
13578        for n in numbers:
13579            self.checkScript(test_bin, (n,))
13580            self.checkScript(test_oct, (n,))
13581            self.checkScript(test_hex, (n,))
13582
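    # __getstate__/__setstate__ round-trip through save/load; __setstate__ adds 10 to
    # each buffer, so the loaded buffers are expected to differ from the originals.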
13583    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle")
13584    def test_get_set_state(self):
13585        class Root(torch.jit.ScriptModule):
13586            __constants__ = ['number']
13587
13588            def __init__(self, number):
13589                super().__init__()
13590                self.buffer1 = nn.Buffer(torch.ones(2, 2))
13591                self.buffer2 = nn.Buffer(torch.ones(2, 2))
13592                self.number = number
13593
13594            @torch.jit.script_method
13595            def __getstate__(self):
13596                return (self.buffer1, self.buffer2, 74, self.training)
13597
13598            @torch.jit.script_method
13599            def __setstate__(self, state):
13600                self.buffer1 = state[0] + 10
13601                self.buffer2 = state[1] + 10
13602                self.training = state[3]
13603
13604        class M(torch.jit.ScriptModule):
13605            __constants__ = ['number']
13606
13607            def __init__(self, number, submodule):
13608                super().__init__()
13609                self.buffer1 = nn.Buffer(torch.ones(2, 2))
13610                self.buffer2 = nn.Buffer(torch.ones(2, 2))
13611                self.number = number
13612                self.submodule = submodule
13613
13614            @torch.jit.script_method
13615            def __getstate__(self):
13616                return (self.buffer1, self.buffer2, 74, self.submodule, self.training)
13617
13618            @torch.jit.script_method
13619            def __setstate__(self, state):
13620                self.buffer1 = state[0] + 10
13621                self.buffer2 = state[1] + 10
13622                self.submodule = state[3]
13623                self.training = state[4]
13624
13625        with TemporaryFileName() as fname:
13626            m = M(23, submodule=Root(99))
13627            m.save(fname)
13628            loaded = torch.jit.load(fname)
13629
13630        # Check original module
13631        self.assertEqual(m.buffer1, torch.ones(2, 2))
13632        self.assertEqual(m.buffer2, torch.ones(2, 2))
13633
13634        # Check top level module
13635        self.assertEqual(loaded.buffer1, torch.ones(2, 2) + 10)
13636        self.assertEqual(loaded.buffer2, torch.ones(2, 2) + 10)
13637
13638        # Check submodule
13639        self.assertEqual(loaded.submodule.buffer1, torch.ones(2, 2) + 10)
13640        self.assertEqual(loaded.submodule.buffer2, torch.ones(2, 2) + 10)
13641
13642        # Check simpler module
13643        class NoArgState(torch.nn.Module):
13644            def __init__(self) -> None:
13645                super().__init__()
13646                self.buffer1 = nn.Buffer(torch.ones(2, 2))
13647                self.buffer2 = nn.Buffer(torch.ones(2, 2))
13648
13649            def forward(self):
13650                pass
13651
13652            @torch.jit.export
13653            def __getstate__(self):
13654                return 5, self.training
13655
13656            @torch.jit.export
13657            def __setstate__(self, state):
13658                self.buffer1 = torch.ones(2, 2) + state[0]
13659                self.buffer2 = torch.ones(2, 2) + 10
13660                self.training = state[1]
13661
13662        with TemporaryFileName() as fname:
13663            m = torch.jit.script(NoArgState())
13664            m.save(fname)
13665            loaded = torch.jit.load(fname)
13666            self.assertEqual(loaded.buffer1, torch.ones(2, 2) + 5)
13667            self.assertEqual(loaded.buffer2, torch.ones(2, 2) + 10)
13668
13671    def test_string_slicing(self):
13672        def fn1(x):
13673            # type: (str) -> str
13674            return x[1:3]
13675
13676        def fn2(x):
13677            # type: (str) -> str
13678            return x[-1:3]
13679
13680        def fn3(x):
13681            # type: (str) -> str
13682            return x[3:1]
13683
13684        def fn4(x):
13685            # type: (str) -> str
13686            return x[3:100]
13687
13688        self.checkScript(fn1, ("abcdefghi",))
13689        self.checkScript(fn2, ("abcdefghi",))
13690        self.checkScript(fn3, ("abcdefghi",))
13691        self.checkScript(fn4, ("abcdefghi",))
13692
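    # Early returns inside closures are rewritten by the exit transform; these cases
    # inspect the emitted graph structure rather than executing the closures.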
13693    def test_early_return_closure(self):
13694        code = dedent('''
13695            def tanh(self):
13696                output = torch.tanh(self)
13697                def backward(grad_output):
13698                    pass
13699                return output, backward
13700        ''')
13701        cu = torch.jit.CompilationUnit(code)
13702        g = cu.tanh.graph
13703        FileCheck().check_count("prim::Closure_0", 2).check("NoneType = prim::Constant") \
13704                   .check_next("return").run(g)
13705
13706        code = dedent('''
13707            def tanh(self):
13708                output = torch.tanh(self)
13709                def backward(grad_output):
13710                    a = 1
13711                    if output:
13712                        return 1
13713                    else:
13714                        a = 2
13715                    return a
13716                return output, backward
13717        ''')
13718        cu = torch.jit.CompilationUnit(code)
13719        g = cu.tanh.graph
13720        FileCheck().check_count("prim::Closure_0", 2).check("int = prim::If") \
13721                   .run(g)
13722
13723        code = dedent('''
13724            def loop_in_closure(self):
13725                output = torch.tanh(self)
13726                def backward(grad_output):
13727                    for i in range(3):
13728                        return 1
13729                    return 4
13730                return output, backward
13731        ''')
13732        cu = torch.jit.CompilationUnit(code)
13733        fc = FileCheck()
13734        fc.check("prim::Closure").check("(Tensor, NoneType) = prim::TupleConstruct")
13735        # Loop, then two ifs added by the exit transform
13736        fc.check("prim::Closure").check("prim::Loop").check_count("prim::If", 2)
13737        fc.run(cu.loop_in_closure.graph)
13738
13739        code = dedent('''
13740            def tanh(self):
13741                output = torch.tanh(self)
13742                def backward(grad_output):
13743                    if 1 == 1:
13744                        return 1
13745                    else:
13746                        return 1.
13747                return output, backward
13748        ''')
13749        with self.assertRaisesRegex(RuntimeError, "returned a value of type int but"):
13750            cu = torch.jit.CompilationUnit(code)
13751
13752    @_inline_everything
13753    def test_early_return_fork_join(self):
13754        @torch.jit.script
13755        def foo(x):
13756            if x.dim() == 2:
13757                return torch.neg(x), x
13758            else:
13759                return torch.neg(x), x + 1
13760
13761        x = torch.rand(3, 4)
13762
13763        @torch.jit.script
13764        def wait_script(x):
13765            fut = torch.jit._fork(foo, x)
13766            y_hat = foo(x)
13767            y = torch.jit._wait(fut)
13768            return y, y_hat
13769
13770        FileCheck().check("with prim::fork").check("prim::If").check("return")\
13771                   .run(wait_script.graph)
13772
13773    def test_early_return_type_refinement(self):
13774        @torch.jit.script
13775        def test(x):
13776            # type: (Optional[int]) -> int
13777            if x is None:
13778                return 1
13779            else:
13780                return x
13781        self.assertEqual(test(None), 1)
13782        self.assertEqual(test(2), 2)
13783
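    # A raise terminates its block, so prim::If guards are only needed where execution
    # can continue afterwards; the expected guard counts are asserted below.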
13784    def test_exceptions_with_control_flow(self):
13785        def test_num_ifs(func, num_ifs):
13786            g = torch.jit.script(func).graph
13787            FileCheck().check_count("prim::If", num_ifs, exactly=True).run(g)
13788
13789        def no_guard_ifs_added(x):
13790            # type: (int) -> int
13791            if x == 1:
13792                return 1
13793            else:
13794                if x == 2:
13795                    raise RuntimeError("hi")
13796                else:
13797                    raise RuntimeError("hi")
13798
13799        self.checkScript(no_guard_ifs_added, (1,))
13800        self.checkScriptRaisesRegex(no_guard_ifs_added, (2,), Exception, "")
13801        test_num_ifs(no_guard_ifs_added, 2)
13802
13803        # FUNCTION LOOKS LIKE:
13804        # graph(%x.1 : int):
13805        #   %7 : str = prim::Constant[value="Exception"]()
13806        #   %2 : int = prim::Constant[value=1]()
13807        #   %5 : int = prim::Constant[value=2]()
13808        #   %19 : int = prim::Uninitialized()
13809        #   %3 : bool = aten::eq(%x.1, %2)
13810        #   %20 : int = prim::If(%3)
13811        #     block0():
13812        #       -> (%2)
13813        #     block1():
13814        #       %6 : bool = aten::eq(%x.1, %5)
13815        #        = prim::If(%6)
13816        #         block0():
13817        #            = prim::RaiseException(%7)
13818        #           -> ()
13819        #         block1():
13820        #            = prim::RaiseException(%7)
13821        #           -> ()
13822        #       -> (%19)
13823        #   return (%20)
13824
13825        def no_ifs_added(x):
13826            # type: (int) -> int
13827            if x < 0:
13828                raise RuntimeError("hi")
13829            return x
13830
13831        self.checkScript(no_ifs_added, (1,))
13832        self.checkScriptRaisesRegex(no_ifs_added, (-2,), Exception, "")
13833        test_num_ifs(no_ifs_added, 1)
13834
13835        def test_if_might(x):
13836            # type: (int)
13837            if x > 0:
13838                if x == 1:
13839                    return 1
13840                else:
13841                    a = 2
13842            else:
13843                raise RuntimeError("hi")
13844            return a + 2
13845
13846        self.checkScript(test_if_might, (1,))
13847        self.checkScript(test_if_might, (3,))
13848        self.checkScriptRaisesRegex(no_ifs_added, (-2,), Exception, "")
13849        test_num_ifs(test_if_might, 3)  # one if added to guard a + 2
13850
13851        def test_loop_no_escape(x):
13852            # type: (int)
13853            if x >= 0:
13854                for i in range(x):
13855                    raise RuntimeError("hi")
13856            else:
13857                return 5
13858            return x + 3
13859
13860        self.checkScript(test_loop_no_escape, (0,))
13861        self.checkScript(test_loop_no_escape, (-1,))
13862        self.checkScriptRaisesRegex(test_loop_no_escape, (1,), Exception, "")
13863
13864        # if guard gets optimized away
13865        test_num_ifs(test_loop_no_escape, 1)
13866
13867        def test_loop_exception_with_continue(x):
13868            # type: (int)
13869            i = 0
13870            for i in range(5):
13871                if i == x:
13872                    raise RuntimeError("hi")
13873                else:
13874                    continue
13875                print(i)
13876            return i + 5
13877
13878        self.checkScript(test_loop_exception_with_continue, (-1,))
13879        self.checkScriptRaisesRegex(test_loop_exception_with_continue, (1,), Exception, "")
13880        test_num_ifs(test_loop_exception_with_continue, 1)  # no ifs added to guard print
13881
13882
13883    def test_exception_exits_closure(self):
13884        code = dedent('''
13885            def no_return_func(self):
13886                # type: (Tensor) -> Tensor
13887                output = torch.tanh(self)
13888                def backward(grad_output):
13889                    raise RuntimeError("Hi")
13890        ''')
13891        with self.assertRaisesRegex(RuntimeError, "does not return along all"):
13892            cu = torch.jit.CompilationUnit(code)
13893
13894        code = dedent('''
13895            def test_exit_pair_reset(x):
13896                # type: (int) -> int
13897                if x > 0:
13898                    a = 0
13899                    def backward(grad_output):
13900                        raise RuntimeError("Hi")
13901                    a = a + 1
13902                else:
13903                    return x
13904                return a + 1
13905        ''')
13906        func = torch.jit.CompilationUnit(code).test_exit_pair_reset
13907        self.assertEqual(func(1,), 2)
13908        self.assertEqual(func(-1,), -1)
13909        # final a + 1 gets inlined into the first branch and optimized away
13910        FileCheck().check_count("prim::If", 1, exactly=True).run(func.graph)
13911
13912    def test_non_final_return(self):
13913        def simple(x):
13914            if bool(x > 3):
13915                return x + 1
13916            else:
13917                return x + 2
13918            raise RuntimeError("nope")
13919
13920        def nest(x):
13921            x = x + 1
13922            if bool(x > 3):
13923                if bool(x > 4):
13924                    x += 1
13925                return x + 1
13926            else:
13927                return x + 2
13928
13929        def early_ret(x):
13930            x = x + 1
13931            if bool(x > 3):
13932                return x + 1
13933            x = x + 1
13934            return x + 2
13935
13936        def nest_early_ret(x):
13937            x = x + 1
13938            if bool(x > 3):
13939                if bool(x > 4):
13940                    return x + 2
13941                return x + 1
13942            x = x + 1
13943            return x + 2
13944
13945        def not_early_ret(x):
13946            s = ""
13947            if bool(x > 3):
13948                if bool(x > 4):
13949                    return 1, s
13950                s += "foo"
13951            else:
13952                s += "5"
13953            s += "hi"
13954            return 7, s
13955
13956        def not_total_ret(x):
13957            s = ""
13958            if bool(x > 3):
13959                if bool(x > 4):
13960                    return 1, s
13961                else:
13962                    return 2, s
13963            else:
13964                s += "5"
13965            return 7, s
13966
13967        for i in range(3):
13968            for func in [simple, nest, early_ret, nest_early_ret, not_early_ret,
13969                         not_total_ret]:
13970                self.checkScript(func, (torch.tensor(2.5 + i),))
13971
13972        def vars_used_after_ret(x):
13973            # type: (int) -> int
13974            if x == 0:
13975                return x
13976            else:
13977                y = 2
13978                z = 3
13979            return x + y * z
13980
13981        self.checkScript(vars_used_after_ret, (1,))
13982        self.checkScript(vars_used_after_ret, (0,))
13983
13984        def complicated(x):
13985            # type: (int) -> int
13986            if x:
13987                if x == 2:
13988                    return 1
13989                    assert 1 == 2
13990                else:
13991                    if x == 3:
13992                        return 2
13993                        assert 1 == 2
13994                    else:
13995                        a = 2
13996                        b = 3
13997            else:
13998                a = 4
13999                b = 1
14000            return a + b
14001            assert 1 == 2
14002
14003        for i in range(4):
14004            self.checkScript(complicated, (i,))
14005
14006    def test_partial_returns(self):
14007        with self.assertRaisesRegex(RuntimeError, "does not return along all"):
14008            @torch.jit.script
14009            def no_ret():
14010                # type: () -> int
14011                pass
14012
14013        with self.assertRaisesRegex(RuntimeError, "does not return along all"):
14014            @torch.jit.script
14015            def partial(x):
14016                # type: (Tensor) -> int
14017                if x:
14018                    return 1
14019
14020        with self.assertRaisesRegex(RuntimeError, "does not return along all"):
14021            @torch.jit.script
14022            def typed_none():
14023                # type: () -> Optional[int]
14024                pass
14025
14026        @torch.jit.script
14027        def none_ret():
14028            pass
14029
14030        self.assertIs(none_ret(), None)
14031        FileCheck().check(": None").run(none_ret.graph)
14032
14033    def test_early_returns_loops(self):
14034        def nest_while_ret(x):
14035            # type: (int) -> int
14036            y = 4
14037            while x < 4:
14038                if x < 3:
14039                    return y
14040                else:
14041                    y = y + 1
14042                    break
14043                y = y + 2
14044            y = y + 1
14045            return y
14046
14047        self.checkScript(nest_while_ret, (2,))
14048        self.checkScript(nest_while_ret, (3,))
14049        self.checkScript(nest_while_ret, (4,))
14050
14051        def loop_ret(x, y):
14052            # type: (int, int) -> (int)
14053            i = 0
14054            for i in range(x):
14055                if x == y:
14056                    return x + y
14057                i = i + y
14058            i = i - 1
14059            return i
14060
14061        self.checkScript(loop_ret, (3, 3))
14062        self.checkScript(loop_ret, (2, 3))
14063        self.checkScript(loop_ret, (3, 1))
14064
14065        def test_will_ret(y):
14066            # type: (int) -> int
14067            for i in range(y):
14068                return 2
14069            return 1
14070
14071        self.checkScript(test_will_ret, (0,))
14072        self.checkScript(test_will_ret, (1,))
14073
14074        def test_loop_nest_ret(y):
14075            # type: (int) -> int
14076            for i in range(y):
14077                for i in range(y - 2):
14078                    return 10
14079                return 5
14080            return 0
14081
14082        self.checkScript(test_loop_nest_ret, (0,))
14083        self.checkScript(test_loop_nest_ret, (1,))
14084        self.checkScript(test_loop_nest_ret, (2,))
14085
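    # Each torch.nn.init function should compile without falling back to prim::PythonOp
    # and should match the eager result when run under the same RNG state.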
14086    def test_nn_init(self):
14087        tests = (
14088            ('constant_', (lambda: (torch.ones(2, 2), 2.5)), "Tensor, float"),
14089            ('ones_', (lambda: (torch.ones(2, 2),)), "Tensor"),
14090            ('zeros_', (lambda: (torch.ones(2, 2),)), "Tensor"),
14091            ('uniform_', (lambda: (torch.ones(2, 2),)), "Tensor"),
14092            ('normal_', (lambda: (torch.ones(2, 2),)), "Tensor"),
14093            ('xavier_normal_', (lambda: (torch.ones(2, 2),)), "Tensor"),
14094            ('xavier_uniform_', (lambda: (torch.ones(2, 2),)), "Tensor"),
14095        )
14096
14097        for name, args_fn, type_str in tests:
14098            # Build test code
14099            arg_str = ', '.join([chr(i + ord('a')) for i in range(len(args_fn()))])
14100
14101            code = dedent('''
14102                def test({arg_str}):
14103                    # type: ({type_str})
14104                    return torch.nn.init.{name}({arg_str})
14105            ''').format(arg_str=arg_str, type_str=type_str, name=name)
14106            cu = torch.jit.CompilationUnit(code)
14107
14108            # Compare functions
14109            init_fn = getattr(torch.nn.init, name)
14110            script_out = self.runAndSaveRNG(cu.test, args_fn())
14111            eager_out = self.runAndSaveRNG(init_fn, args_fn())
14112            self.assertEqual(script_out, eager_out)
14113
14114            FileCheck().check_not("prim::PythonOp").run(cu.test.graph)
14115
14116    def test_nn_init_generator(self):
14117        init_fns = (
14118            'uniform_', 'normal_', 'xavier_normal_', 'xavier_uniform_',
14119        )
14120
14121        for name in init_fns:
14122            # Build test code
14123            code = dedent('''
14124                def test(tensor, generator):
14125                    # type: (Tensor, Generator)
14126                    return torch.nn.init.{name}(tensor, generator=generator)
14127            ''').format(name=name)
14128            cu = torch.jit.CompilationUnit(code)
14129
14130            # Compare functions
14131            init_fn = getattr(torch.nn.init, name)
14132
14133            torch.manual_seed(1)
14134
14135            g = torch.Generator()
14136            g.manual_seed(2023)
14137            script_out = cu.test(torch.ones(2, 2), g)
14138
14139            # Change the seed of the default generator to make
14140            # sure that we're using the provided generator
14141            torch.manual_seed(2)
14142
14143            g = torch.Generator()
14144            g.manual_seed(2023)
14145            eager_out = init_fn(torch.ones(2, 2), generator=g)
14146
14147            self.assertEqual(script_out, eager_out)
14148
14149            FileCheck().check_not("prim::PythonOp").run(cu.test.graph)
14150
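    # Early returns are rewritten into prim::If nodes; assert the exact counts to make
    # sure the rewrite does not introduce redundant branching.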
14151    def test_early_return_rewrite(self):
14152        def test_foo(x: bool):
14153            if x:
14154                return 1
14155            return 2
14156
14157        self.checkScript(test_foo, (True,))
14158        self.checkScript(test_foo, (False,))
14159        FileCheck().check_count("prim::If", 1, exactly=True).run(torch.jit.script(test_foo).graph)
14160
14161        def test_multiple(x: int):
14162            if x == 5:
14163                return x * x
14164            else:
14165                y = 2 * x
14166
14167            z = y * 2
14168            if z == 8:
14169                return 1
14170
14171            if z != 16:
14172                z = z - 2
14173                abc = 4
14174            else:
14175                return 3
14176
14177            z = z * abc
14178            return z * z * z
14179
14180        self.checkScript(test_multiple, (5,))
14181        self.checkScript(test_multiple, (2,))
14182        self.checkScript(test_multiple, (4,))
14183        self.checkScript(test_multiple, (3,))
14184        self.checkScript(test_multiple, (10,))
14185
14186        graph = torch.jit.script(test_multiple).graph
14187        FileCheck().check_count("prim::If", 3, exactly=True).run(graph)
14188
14189    def test_is_scripting_metacompile(self):
14190        @torch.jit.script
14191        def foo():
14192            if torch.jit.is_scripting():
14193                return 1
14194            else:
14195                print("hello") + 2  # will not be compiled
14196
14197        self.assertEqual(foo(), 1)
14198
14199    def test_boolean_literal_constant_metacompile(self):
14200        class Mod(torch.nn.Module):
14201            __constants__ = ['val']
14202
14203            def __init__(self, val):
14204                super().__init__()
14205                self.val = val
14206
14207            def forward(self):
14208                if self.val:
14209                    return 1
14210                else:
14211                    return "2"
14212
14213        self.checkModule(Mod(True), ())
14214        self.checkModule(Mod(False), ())
14215
14216        @torch.jit.script
14217        def foo():
14218            if True:
14219                return 1
14220            else:
14221                return "2"
14222
14223        self.assertEqual(foo(), 1)
14224
14225    def test_assert_is_scripting_metacompile(self):
14226        def foo():
14227            assert not torch.jit.is_scripting(), "TestErrorMsg"
14228            print("hello") + 2  # will not be compiled
14229
14230        f = torch.jit.script(foo)
14231        with self.assertRaisesRegex(torch.jit.Error, "TestErrorMsg"):
14232            f()
14233
14234    def test_isinstance_metacompile(self):
14235        @torch.jit.script
14236        def test_primitive_type(x):
14237            # type: (int) -> int
14238            if isinstance(x, int):
14239                return x + 1
14240            else:
14241                return x - 1
14242
14243        self.assertEqual(test_primitive_type(1), 2)
14244        with self.assertRaisesRegex(Exception, "Expected a value of type"):
14245            test_primitive_type(1.5)
14246
14247        _MyNamedTuple = namedtuple('_MyNamedTuple', ['value'])
14248
14249        @torch.jit.script
14250        def test_non_primitive_types(x):
14251            # type: (_MyNamedTuple) -> Tensor
14252            if isinstance(1, _MyNamedTuple):
14253                return 10
14254
14255            if isinstance(x, _MyNamedTuple):
14256                return x.value + 1
14257            else:
14258                return 1
14259
14260        out = test_non_primitive_types(_MyNamedTuple(value=torch.tensor(5.0)))
14261        self.assertEqual(out, torch.tensor(6.0))
14262
14263    def test_namedtuple_type_inference(self):
14264        _AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('value', int)])  # noqa: UP014
14265        _UnannotatedNamedTuple = namedtuple('_NamedTupleUnAnnotated', ['value'])
14266
14267        def test_check_named_tuple_value():
14268            named_tuple = _AnnotatedNamedTuple(1)
14269            return named_tuple.value
14270
14271        self.checkScript(test_check_named_tuple_value, ())
14272
14273        def test_error():
14274            return _UnannotatedNamedTuple(1)
14275
14276        with self.assertRaisesRegex(RuntimeError, r"Expected a value of type \'Tensor \(inferred\)\' "
14277                                                  r"for argument \'value\' but instead found type \'int\'."):
14278            torch.jit.script(test_error)
14279
14280    def test_namedtuple_default_values_simple_type(self):
14281
14282        class Point(NamedTuple):
14283            x: Optional[int] = None
14284            y: int = 2
14285
14286        make_global(Point)
14287
14288        class M(torch.nn.Module):
14289            def forward(self, point: Point):
14290                return point
14291
14292        p = Point(x=3, y=2)
14293
14294        self.checkModule(M(), (p,))
14295        self.checkModule(M(), (Point(),))
14296
14297        m = torch.jit.script(M())
14298
14299        FileCheck().check(r"NamedTuple(x : int? = None, y : int = 2))")   \
14300                   .run(m.graph)
14301
14302    def test_namedtuple_default_values_missing(self):
14303
14304        class Point(NamedTuple):
14305            x: Optional[int]
14306            y: int
14307            z: int = 3
14308
14309        make_global(Point)
14310
14311        class M(torch.nn.Module):
14312            def forward(self, point: Point):
14313                return point
14314
14315        p1 = Point(x=3, y=2)
14316        p2 = Point(x=3, y=2, z=1)
14317
14318        self.checkModule(M(), (p1,))
14319        self.checkModule(M(), (p2,))
14320
14321        m = torch.jit.script(M())
14322
14323        FileCheck().check(r"NamedTuple(x : int?, y : int, z : int = 3))")   \
14324                   .run(m.graph)
14325
14326    def test_namedtuple_default_values_container_type(self):
14327
14328        class Point(NamedTuple):
14329            x: Optional[List[int]] = None
14330            y: List[int] = [1, 2, 3]
14331            z: Optional[Dict[str, int]] = {"a": 1}
14332
14333        make_global(Point)
14334
14335        class M(torch.nn.Module):
14336            def forward(self, point: Point):
14337                return point
14338
14339        p = Point(x=[4, 5, 6], y=[3, 2, 1], z={"b": 2})
14340
14341        self.checkModule(M(), (p,))
14342        self.checkModule(M(), (Point(),))
14343
14344        m = torch.jit.script(M())
14345
14346        first_line = r"NamedTuple(x : int[]? = None, y : int[] = "    \
14347                     r"[1, 2, 3], z : Dict(str, int)? = {a: 1}))"
14348
14349        FileCheck().check(first_line)   \
14350                   .run(m.graph)
14351
14352    def test_namedtuple_default_values_Tensor_type(self):
14353
14354        class Point(NamedTuple):
14355            x: torch.Tensor = torch.rand(2, 3)
14356
14357        make_global(Point)
14358
14359        class M(torch.nn.Module):
14360            def forward(self, point: Point):
14361                return point
14362
14363        p = Point(x=torch.rand(2, 3))
14364
14365        with self.assertRaisesRegex(RuntimeError, "Tensors are not "
14366                                    "supported as default NamedTuple "
14367                                    "fields"):
14368            m = torch.jit.script(M())
14369            m(p)
14370
14371    def test_namedtuple_default_values_using_factory_constructor(self):
14372        Pair = namedtuple("Pair", ["x", "y"], defaults=(1, 2))
14373
14374        make_global(Pair)
14375
14376        @torch.jit.script
14377        def fn(x: Pair) -> Pair:
14378            return x
14379
14380        # TODO: We can't use `checkScript` with the NamedTuple factory
14381        # constructor. Using the factory constructor with TorchScript
14382        # creates an anonymous `NamedTuple` class instead of
14383        # preserving the actual name. For example, the actual generated
14384        # signature in this case is:
14385        #   graph(%x.1 : NamedTuple(x : Tensor, y : Tensor))
14386        # It looks like similar test cases have had this issue as well
14387        # (see: `test_namedtuple_python`).
14388        FileCheck().check(r"NamedTuple(x : Tensor = 1, y : Tensor = 2))")   \
14389                   .check_next(r"return (%x.1)")    \
14390                   .run(fn.graph)
14391
14392    def test_isinstance_dynamic(self):
14393        @torch.jit.script
14394        def foo(a):
14395            # type: (Optional[List[int]]) -> int
14396            b = 0
14397            if isinstance(a, (int, (float,), list, str)):
14398                b += 1
14399            if isinstance(a, (int, str)):
14400                b += 1
14401            if isinstance(a, List[int]):
14402                b += 1
14403            return b
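        # A list argument matches `list` in the first check and `List[int]` in the
        # third (but not `(int, str)`), so b == 2; None matches none of them, so b == 0.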
14404        self.assertEqual(foo([3, 4]), 2)
14405        self.assertEqual(foo(None), 0)
14406
14407    def test_function_overloads(self):
14408        # TODO: pyflakes currently does not compose @overload annotation with other
14409        # decorators. This is fixed on master but not on version 2.1.1.
14410        # On the next version update, remove the noqa and add the @typing.overload annotation
14411
14412        @torch.jit._overload  # noqa: F811
14413        def test_simple(x1):  # noqa: F811
14414            # type: (int) -> int
14415            pass
14416
14417        @torch.jit._overload  # noqa: F811
14418        def test_simple(x1):  # noqa: F811
14419            # type: (float) -> float
14420            pass
14421
14422        def test_simple(x1):  # noqa: F811
14423            return x1
14424
14425        def invoke_function():
14426            return test_simple(1.0), test_simple(.5)
14427
14428        self.checkScript(invoke_function, ())
14429
14430        # testing that the functions are cached
14431        compiled_fns_1 = torch.jit._script._get_overloads(test_simple)
14432        compiled_fns_2 = torch.jit._script._get_overloads(test_simple)
14433        for a, b in zip(compiled_fns_1, compiled_fns_2):
14434            self.assertIs(a.graph, b.graph)
14435
14436        old_func = test_simple
14437
14438        # testing that new functions added work with caching
14439        @torch.jit._overload  # noqa: F811
14440        def test_simple(x1):  # noqa: F811
14441            # type: (str) -> str
14442            pass
14443
14444        @torch.jit.script
14445        def my_func():
14446            return old_func("hi")
14447
14448        # testing a new function with the same qualified name
14449        @torch.jit._overload  # noqa: F811
14450        def test_simple(a, b):  # noqa: F811
14451            # type: (int, int) -> int
14452            pass
14453
14454        def test_simple(a, b):
14455            return a + b
14456
14457        @torch.jit.script
14458        def fn():
14459            return test_simple(3, 4)
14460
14461        self.assertEqual(fn(), 7)
14462
14463        # currently the default values have to be specified in the
14464        # overload as well - TODO: take them from the implementation and apply
14465        # them where the type is valid.
14466        @torch.jit._overload  # noqa: F811
14467        def identity(x1):  # noqa: F811
14468            # type: (str) -> str
14469            pass
14470
14471        @torch.jit._overload  # noqa: F811
14472        def identity(x1):  # noqa: F811
14473            # type: (float) -> float
14474            pass
14475
14476        def identity(x1=1.0):  # noqa: F811
14477            return x1
14478
14479        def invoke():
14480            return identity(), identity(.5), identity("hi")
14481
14482        self.checkScript(invoke, ())
14483
14484        def schema_match_failure():
14485            return identity((1, 2))
14486
14487        thrown = False
14488        try:
14489            torch.jit.script(schema_match_failure)
14490        except Exception as e:
14491            thrown = True
14492            self.assertTrue(r"of type 'str'" in str(e) and r"of type 'float" in str(e))
14493        self.assertTrue(thrown)
14494
14495        with self.assertRaisesRegex(Exception, "cannot be directly compiled"):
14496            torch.jit.script(identity)
14497
14498        @torch.jit._overload  # noqa: F811
14499        def impl_compile_failure(x, y):  # noqa: F811
14500            # type: (str, str) -> (str)
14501            pass
14502
14503        @torch.jit._overload  # noqa: F811
14504        def impl_compile_failure(x, y):  # noqa: F811
14505            # type: (int, int) -> (int)
14506            pass
14507
14508        def impl_compile_failure(x, y):  # noqa: F811
14509            return x - y
14510
14511        def test():
14512            impl_compile_failure("one", "two")
14513
14514
14515        with self.assertRaisesRegex(Exception, "Arguments for call are not valid"):
14516            torch.jit.script(test)
14517
14518        @torch.jit._overload  # noqa: F811
14519        def good_overload(x=1):  # noqa: F811
14520            # type: (int) -> (int)
14521            pass
14522
14523        def good_overload(x=1):  # noqa: F811
14524            return x
14525
14526        @torch.jit.script
14527        def foo():
14528            return good_overload()
14529
14530        self.assertEqual(foo(), 1)
14531
14532
14533        with self.assertRaisesRegex(Exception, "must equal to the default parameter"):
14534            @torch.jit._overload  # noqa: F811
14535            def bad_default_on_overload(x, y=2):  # noqa: F811
14536                # type: (int, int) -> (int)
14537                pass
14538
14539            def bad_default_on_overload(x, y=1):  # noqa: F811
14540                # type: (int, int) -> (int)
14541                pass
14542
14543            @torch.jit.script
14544            def test():
14545                return bad_default_on_overload(1, 2)
14546
14547        @torch.jit._overload  # noqa: F811
14548        def diff_default(x):  # noqa: F811
14549            # type: (int) -> int
14550            pass
14551
14552        @torch.jit._overload  # noqa: F811
14553        def diff_default(x):  # noqa: F811
14554            # type: (str) -> str
14555            pass
14556
14557        def diff_default(x="hi"):  # noqa: F811
14558            return x
14559
14560        def test():
14561            return diff_default(), diff_default(2), diff_default("abc")
14562
14563        self.assertEqual(test(), torch.jit.script(test)())
14564
14565        @torch.jit._overload  # noqa: F811
14566        def diff_num_params(x):  # noqa: F811
14567            # type: (float) -> float
14568            pass
14569
14570        @torch.jit._overload  # noqa: F811
14571        def diff_num_params(x, y):  # noqa: F811
14572            # type: (int, int) -> int
14573            pass
14574
14575        def diff_num_params(x, y=2, z=3):  # noqa: F811
14576            # type: (Union[float, int], int, int)
14577            return x + y + z
14578
14579        def test():
14580            return diff_num_params(1.0), diff_num_params(1, 2), diff_num_params(1), diff_num_params(1, 2, 3)
14581
14582        self.assertEqual(test(), torch.jit.script(test)())
14583
14584        @torch.jit._overload  # noqa: F811
14585        def diff_num_params_no_annot():
14586            # type: () -> int
14587            pass
14588
14589        def diff_num_params_no_annot(x=1):    # noqa: F811
14590            return x
14591
14592        def test():
14593            return diff_num_params_no_annot(1.0)
14594
14595        with self.assertRaisesRegex(Exception, "Parameters not specified"):
14596            torch.jit.script(test)
14597
14598    def test_function_overload_misuse(self):
14599        with self.assertRaisesRegex(RuntimeError, "Only `pass` statement or `...` can be the body"):
14600            @torch.jit._overload
14601            def wrong_decl_body(x: str) -> str:
14602                return x + "0"
14603
14604        with self.assertRaisesRegex(RuntimeError, "Only `pass` statement or `...` can be the body"):
14605            class MyClass:
14606                @torch.jit._overload_method
14607                def method(self):
14608                    return 0
14609
14610        @torch.jit._overload
14611        def null_overload(x: int) -> int: ...  # noqa: E704
14612
14613        @torch.jit._overload  # noqa: F811
14614        def null_overload(x: str) -> str:  # noqa: F811
14615            pass
14616
14617        def null_overload_driver():
14618            return null_overload(0)
14619
14620        with self.assertRaisesRegex(RuntimeError, 'Implementation for the function ".+" is missing.'):
14621            torch.jit.script(null_overload_driver)
14622
14623        class OverloadMisuse(torch.nn.Module):
14624            @torch.jit._overload_method
14625            def forward(self, x: int):
14626                pass
14627
14628            @torch.jit._overload_method  # noqa: F811
14629            def forward(self, x: Tensor):  # noqa: F811
14630                pass
14631
14632        with self.assertRaisesRegex(RuntimeError, 'Implementation for the method ".+" is missing.'):
14633            m = torch.jit.script(OverloadMisuse())
14634
14635
14636    def test_script_method_torch_function_overload(self):
14637        class MyCustomTensor(torch.Tensor):
14638            pass
14639
14640        class MyCustomModule(torch.nn.Module):
14641            def forward(self, x):
14642                return torch.relu(x)
14643
14644        scripted_mod = torch.jit.script(MyCustomModule())
14645        t = torch.tensor([3.0])
14646        ref_out = scripted_mod(t)
14647
14648        t_custom = MyCustomTensor([3.0])
14649        out1 = scripted_mod(t_custom)
14650        self.assertEqual(out1, ref_out)
14651
14652        out2 = scripted_mod.forward(t_custom)
14653        self.assertEqual(out2, ref_out)
14654
14655    def test_function_overloading_isinstance(self):
14656        @torch.jit._overload  # noqa: F811
14657        def my_conv(x, y):  # noqa: F811
14658            # type: (float, str) -> (float)
14659            pass
14660
14661        @torch.jit._overload  # noqa: F811
14662        def my_conv(x, y):  # noqa: F811
14663            # type: (float, float) -> (float)
14664            pass
14665
14666        def my_conv(x, y=2.0):  # noqa: F811
14667            if isinstance(y, str):
14668                if y == "hi":
14669                    return 4.0 - x
14670                else:
14671                    return 5.0 - x
14672            else:
14673                return 2.0 + x
14674
14675        def test_uses():
14676            return my_conv(1.5), my_conv(1.5, "hi"), my_conv(1.5, 5.0)
14677
14678        self.checkScript(test_uses, ())
14679
14680    def test_method_overloading(self):
14681        class Over(torch.nn.Module):
14682            @torch.jit._overload_method  # noqa: F811
14683            def forward(self, x):  # noqa: F811
14684                # type: (Tuple[Tensor, Tensor]) -> Tensor
14685                pass
14686
14687            @torch.jit._overload_method  # noqa: F811
14688            def forward(self, x):  # noqa: F811
14689                # type: (Tensor) -> Tensor
14690                pass
14691
14692            def forward(self, x):  # noqa: F811
14693                if isinstance(x, Tensor):
14694                    return x + 20
14695                else:
14696                    return x[0] + 5
14697
14698        class S(torch.jit.ScriptModule):
14699            def __init__(self) -> None:
14700                super().__init__()
14701                self.weak = Over()
14702
14703            @torch.jit.script_method
14704            def forward(self, x):
14705                return self.weak(x) + self.weak((x, x))
14706
14707        s_mod = S()
14708        x = torch.ones(1)
14709        self.assertEqual(s_mod(x), x + 20 + 5 + x)
14710
14711        over = Over()
14712        self.assertEqual(over((x, x)), x + 5)
14713        self.assertEqual(over(x), x + 20)
14714
14715        class Unannotated(torch.nn.Module):
14716            @torch.jit._overload_method  # noqa: F811
14717            def hello(self, x):  # noqa: F811
14718                pass
14719
14720            @torch.jit._overload_method  # noqa: F811
14721            def hello(self, x):  # noqa: F811
14722                # type: (int) -> (int)
14723                pass
14724
14725            def hello(self, x):  # noqa: F811
14726                return x + 3
14727
14728            def forward(self):
14729                return self.hello(1), self.hello(.5)
14730
14731        w = Unannotated()
14732        with self.assertRaisesRegex(Exception, "explicitly add type annotations to overloaded functions"):
14733            torch.jit.script(w)
14734
14735        class CompileOverloadError(torch.nn.Module):
14736            @torch.jit._overload_method  # noqa: F811
14737            def hello(self, x):  # noqa: F811
14738                # type: (str) -> (int)
14739                pass
14740
14741            @torch.jit._overload_method  # noqa: F811
14742            def hello(self, x):  # noqa: F811
14743                # type: (int) -> (int)
14744                pass
14745
14746            def hello(self, x):  # noqa: F811
14747                return x + 1
14748
14749            def forward(self):
14750                return self.hello("hi"), self.hello(.5)
14751
14752        w = CompileOverloadError()
14753        with self.assertRaisesRegex(Exception, "but instead found type 'str'"):
14754            torch.jit.script(w)
14755
14756        # testing overload declared first, then non-overload
14757        with self.assertRaisesRegex(Exception, "Overloads are not useable when a module"):
14758            class W3(torch.nn.Module):
14759                @torch.jit._overload_method  # noqa: F811
14760                def forward(self, x):  # noqa: F811
14761                    # type: (int) -> int
14762                    pass
14763
14764                @torch.jit._overload_method  # noqa: F811
14765                def forward(self, x):  # noqa: F811
14766                    # type: (Tensor) -> Tensor
14767                    pass
14768
14769                def forward(self, x):  # noqa: F811
14770                    return x + 5
14771
14772            a = W3()
14773            b = torch.jit.script(a)
14774
14775            class W3(torch.nn.Module):
14776                def forward(self, x):  # noqa: F811
14777                    return x + 5 + 10
14778
14779            a = W3()
14780            b = torch.jit.script(a)
14781
14782        # testing non-overload declared first, then overload
14783        class W2(torch.nn.Module):
14784            def hello(self, x1, x2):
14785                return x1 + x2
14786
14787            def forward(self, x):
14788                return self.hello(x, x)
14789
14790        a = torch.jit.script(W2())
14791        self.assertEqual(a(torch.tensor(1)), torch.tensor(2))
14792
14793        class W2(torch.nn.Module):
14794            @torch.jit._overload_method  # noqa: F811
14795            def hello(self, x):  # noqa: F811
14796                pass
14797
14798            @torch.jit._overload_method  # noqa: F811
14799            def hello(self, x):  # noqa: F811
14800                # type: (int) -> (int)
14801                pass
14802
14803            def hello(self, x):  # noqa: F811
14804                return x + 5 + 10
14805
14806            def forward(self, x):
14807                return self.hello(1), self.hello(x)
14808
14809        with self.assertRaisesRegex(Exception, "Overloads are not useable when a module"):
14810            a = torch.jit.script(W2())
14811
14812    def test_narrow_copy(self):
14813        def foo(a):
14814            return a.narrow_copy(0, 0, 5)
14815
14816        self.checkScript(foo, [torch.rand(10)])
14817
14818    def test_select_after_chunk(self):
14819        def foo(x):
14820            chunked = torch.chunk(x, 1)
14821            foo = chunked[0]
14822            foo.add_(5)
14823            return x
14824
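        # chunk() returns views of `x`, so the in-place add_ on chunked[0] must also
        # show up in the returned `x` in both eager and scripted execution.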
14825        self.checkScript(foo, [torch.rand(2, 3)])
14826
14827    def test_nn_LSTM_with_layers(self):
14828        class M(torch.jit.ScriptModule):
14829            def __init__(self) -> None:
14830                super().__init__()
14831                self.rnn = nn.LSTM(2, 3, 2, dropout=0)
14832
14833            @torch.jit.script_method
14834            def forward(self, x, lengths, h0, c0):
14835                return self.rnn(x, (h0, c0))[0]
14836
14837        class Eager(torch.nn.Module):
14838            def __init__(self) -> None:
14839                super().__init__()
14840                self.rnn = nn.LSTM(2, 3, 2, dropout=0)
14841
14842            def forward(self, x, lengths, h0, c0):
14843                return self.rnn(x, (h0, c0))[0]
14844
14845        inputs = (torch.randn(1, 1, 2), torch.LongTensor([7]), torch.randn(2, 1, 3), torch.randn(2, 1, 3))
14846        eager_out = self.runAndSaveRNG(lambda: Eager()(*inputs), ())[0]
14847        script_out = self.runAndSaveRNG(lambda: M()(*inputs), ())[0]
14848
14849        self.assertEqual(eager_out, script_out)
14850
14851    def test_nn_LSTM(self):
14852        input = torch.nn.utils.rnn.pack_sequence([torch.randn(5, 5)])
14853
14854        class S(torch.jit.ScriptModule):
14855            def __init__(self) -> None:
14856                super().__init__()
14857                self.x = torch.nn.LSTM(5, 5)
14858
14859            @torch.jit.script_method
14860            def forward(self, input: PackedSequence) -> Tuple[PackedSequence, Tuple[torch.Tensor, torch.Tensor]]:
14861                return self.x(input)
14862
14863        eager_out = self.runAndSaveRNG(lambda x: torch.nn.LSTM(5, 5)(x), (input,))[0]
14864        script_out = self.runAndSaveRNG(lambda x: S()(x), (input,))[0]
14865
14866        self.assertEqual(eager_out, script_out)
14867
14868    def test_nn_GRU(self):
14869        seq_input = torch.nn.utils.rnn.pack_sequence([torch.randn(5, 5)])
14870        tensor_input = torch.randn(5, 5, 5)
14871
14872        class SeqLengthGRU(torch.jit.ScriptModule):
14873            def __init__(self) -> None:
14874                super().__init__()
14875                self.x = torch.nn.GRU(5, 5)
14876
14877            @torch.jit.script_method
14878            def forward(self, input: PackedSequence) -> Tuple[PackedSequence, torch.Tensor]:
14879                return self.x(input)
14880
14881        class TensorGRU(torch.jit.ScriptModule):
14882            def __init__(self) -> None:
14883                super().__init__()
14884                self.x = torch.nn.GRU(5, 5)
14885
14886            @torch.jit.script_method
14887            def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
14888                return self.x(input)
14889
14890        seq_eager_out = self.runAndSaveRNG(lambda x: torch.nn.GRU(5, 5)(x), (seq_input,))[0]
14891        seq_script_out = self.runAndSaveRNG(lambda x: SeqLengthGRU()(x), (seq_input,))[0]
14892        tensor_eager_out = self.runAndSaveRNG(lambda x: torch.nn.GRU(5, 5)(x), (tensor_input,))[0]
14893        tensor_script_out = self.runAndSaveRNG(lambda x: TensorGRU()(x), (tensor_input,))[0]
14894
14895        self.assertEqual(seq_eager_out, seq_script_out)
14896        self.assertEqual(tensor_eager_out, tensor_script_out)
14897
14898    def test_torchscript_memoryformat(self):
14899        @torch.jit.script
14900        def fn(x):
14901            return x.contiguous(memory_format=torch.channels_last)
14902        x = torch.randn(4, 3, 6, 6)
14903        y = fn(x)
14904        self.assertTrue(y.is_contiguous(memory_format=torch.channels_last))
14905
14906    def test_torchscript_multi_head_attn(self):
14907        @torch.jit.script
14908        def jit_multihead_attn_forward(query,                   # type: Tensor
14909                                       key,                     # type: Tensor
14910                                       value,                   # type: Tensor
14911                                       embed_dim_to_check,      # type: int
14912                                       num_heads,               # type: int
14913                                       in_proj_weight,          # type: Tensor
14914                                       in_proj_bias,            # type: Tensor
14915                                       bias_k,                  # type: Optional[Tensor]
14916                                       bias_v,                  # type: Optional[Tensor]
14917                                       add_zero_attn,           # type: bool
14918                                       dropout,                 # type: float
14919                                       out_proj_weight,         # type: Tensor
14920                                       out_proj_bias,           # type: Tensor
14921                                       training=True,           # type: bool
14922                                       key_padding_mask=None,   # type: Optional[Tensor]
14923                                       need_weights=True,       # type: bool
14924                                       attn_mask=None           # type: Optional[Tensor]
14925                                       ):
14926            # type: (...) -> Tuple[Tensor, Optional[Tensor]]
14927            return torch.nn.functional.multi_head_attention_forward(query, key, value,
14928                                                                    embed_dim_to_check, num_heads,
14929                                                                    in_proj_weight, in_proj_bias,
14930                                                                    bias_k, bias_v,
14931                                                                    add_zero_attn, dropout,
14932                                                                    out_proj_weight, out_proj_bias,
14933                                                                    training, key_padding_mask,
14934                                                                    need_weights, attn_mask)
14935
14936        src_l = 3
14937        bsz = 5
14938        embed_size = 8
14939        nhead = 2
14940        multi_head_attn = torch.nn.MultiheadAttention(embed_size, nhead)
14941        query = torch.rand((src_l, bsz, embed_size))
14942        key = torch.rand((src_l, bsz, embed_size))
14943        value = torch.rand((src_l, bsz, embed_size))
14944
14945        mask = (torch.triu(torch.ones(src_l, src_l)) == 1).transpose(0, 1)
14946        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, 0.0).to(torch.get_default_dtype())
14947
14948        jit_out = jit_multihead_attn_forward(query, key, value,
14949                                             embed_size, nhead,
14950                                             multi_head_attn.in_proj_weight,
14951                                             multi_head_attn.in_proj_bias,
14952                                             multi_head_attn.bias_k, multi_head_attn.bias_v,
14953                                             multi_head_attn.add_zero_attn, multi_head_attn.dropout,
14954                                             multi_head_attn.out_proj.weight,
14955                                             multi_head_attn.out_proj.bias, attn_mask=mask)[0]
14956
14957        py_out = torch.nn.functional.multi_head_attention_forward(query, key, value,
14958                                                                  embed_size, nhead,
14959                                                                  multi_head_attn.in_proj_weight,
14960                                                                  multi_head_attn.in_proj_bias,
14961                                                                  multi_head_attn.bias_k,
14962                                                                  multi_head_attn.bias_v,
14963                                                                  multi_head_attn.add_zero_attn,
14964                                                                  multi_head_attn.dropout,
14965                                                                  multi_head_attn.out_proj.weight,
14966                                                                  multi_head_attn.out_proj.bias,
14967                                                                  attn_mask=mask)[0]
14968        # print("rel. error: ")
14969        # print(jit_out / py_out - 1)
14970        self.assertEqual(jit_out, py_out, atol=5e-4, rtol=1e-4)
14971
14972    def test_torchscript_multi_head_attn_fast_path(self):
14973        src_l = 3
14974        bsz = 5
14975        embed_size = 8
14976        nhead = 2
14977        multi_head_attn = torch.nn.MultiheadAttention(embed_size, nhead, batch_first=True)
14978        multi_head_attn = multi_head_attn.eval()
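        # eval mode (together with torch.no_grad() below) lets MultiheadAttention take
        # its inference fast path; the scripted module should still match eager output.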
14979
14980        query = key = value = torch.rand((bsz, src_l, embed_size))
14981
14982        with torch.no_grad():
14983            py_out = multi_head_attn(query, key, value)
14984            mha = torch.jit.script(multi_head_attn)
14985            jit_out = mha(query, key, value)
14986        torch.testing.assert_close(jit_out, py_out)
14987
14988    @unittest.skipIf(not RUN_CUDA, "no CUDA")
14989    def test_scriptmodule_multi_head_attn_cuda(self):
14990
14991        class MyModule(torch.jit.ScriptModule):
14992            def __init__(self, embed_dim, num_heads):
14993                super().__init__()
14994                sample_q = torch.randn(3, 2, embed_dim)
14995                sample_kv = torch.randn(3, 2, embed_dim)
14996                attention = nn.MultiheadAttention(embed_dim, num_heads)
14997                attention.eval()
14998
14999                self.mod = torch.jit.trace(attention,
15000                                           (sample_q, sample_kv, sample_kv))
15001
15002            @torch.jit.script_method
15003            def forward(self, q, k, v):
15004                return self.mod(q, k, v)
15005
15006        embed_dim = 8
15007        num_heads = 2
15008        sl = 3
15009        bs = 2
15010        model = MyModule(embed_dim, num_heads).cuda()
15011        q = torch.randn(sl, bs, embed_dim, device="cuda")
15012        kv = torch.randn(sl, bs, embed_dim, device="cuda")
15013
15014        jit_out = model(q, kv, kv)[0]
15015        py_out = torch.nn.functional.multi_head_attention_forward(q, kv, kv,
15016                                                                  embed_dim, num_heads,
15017                                                                  model.mod.in_proj_weight,
15018                                                                  model.mod.in_proj_bias,
15019                                                                  None, None, None, 0.0,
15020                                                                  model.mod.out_proj.weight,
15021                                                                  model.mod.out_proj.bias)[0]
15022        self.assertEqual(jit_out, py_out, atol=5e-4, rtol=1e-4)
15023
15024    @unittest.skipIf(not RUN_CUDA, "no CUDA")
15025    def test_scriptmodule_transformer_cuda(self):
15026
15027        class MyModule(torch.jit.ScriptModule):
15028            def __init__(self, transformer, sample_q, sample_kv):
15029                super().__init__()
15030                transformer.eval()
15031
15032                self.mod = torch.jit.trace(transformer,
15033                                           (sample_q, sample_kv))
15034
15035            @torch.jit.script_method
15036            def forward(self, q, k):
15037                return self.mod(q, k)
15038
15039        d_model = 8
15040        nhead = 2
15041        num_encoder_layers = 2
15042        num_decoder_layers = 2
15043        dim_feedforward = 16
15044        bsz = 2
15045        seq_length = 5
15046        tgt_length = 3
15047
15048        with torch.no_grad():
15049            src = torch.randn(seq_length, bsz, d_model)
15050            tgt = torch.randn(tgt_length, bsz, d_model)
15051            transformer = nn.Transformer(d_model, nhead, num_encoder_layers,
15052                                         num_decoder_layers, dim_feedforward, dropout=0.0)
15053            model = MyModule(transformer, tgt, src)
15054
15055            src = torch.randn(seq_length, bsz, d_model)
15056            tgt = torch.randn(tgt_length, bsz, d_model)
15057            jit_out = model(tgt, src)
15058            py_out = transformer(tgt, src)
15059
15060            # print(jit_out/py_out-1)
15061            # print(torch.allclose(jit_out, py_out, atol=5e-4, rtol=1e-4))
15062        self.assertEqual(jit_out, py_out, atol=5e-4, rtol=1e-4)
15063
15064    def test_list_python_op(self):
15065        def python_list_op(lst):
15066            # type: (List[Tensor]) -> Tensor
15067            return lst[0]
15068
15069        def fn(lst):
15070            # type: (List[Tensor]) -> Tensor
15071            return python_list_op(lst)
15072
15073        self.checkScript(fn, ([torch.ones(2) + 2, torch.ones(2)],))
15074
15075    @unittest.skipIf(not RUN_CUDA, "no CUDA")
15076    def test_weak_cuda(self):
15077        class M(torch.jit.ScriptModule):
15078            def __init__(self) -> None:
15079                super().__init__()
15080                self.lstm = torch.nn.LSTM(5, 5)
15081                self.lstm.cuda()
15082
15083            @torch.jit.script_method
15084            def forward(self, x):
15085                return self.lstm(x)
15086
15087        m = M()
15088        m.cuda()
15089        out = m(torch.ones(5, 5, 5).cuda())
15090        self.assertTrue(out[0].is_cuda)
15091
15092    def test_ignore_decorator(self):
15093        with warnings.catch_warnings(record=True) as warns:
15094            class M(torch.jit.ScriptModule):
15095                def __init__(self) -> None:
15096                    super().__init__()
15097                    tensor = torch.zeros(1, requires_grad=False)
15098                    self.some_state = nn.Buffer(torch.nn.Parameter(tensor))
15099
15100                @torch.jit.script_method
15101                def forward(self, x):
15102                    self.ignored_code(x)
15103                    return x
15104
15105                @torch.jit.ignore(drop_on_export=True)
15106                def ignored_code(self, x):
15107                    self.some_state = torch.tensor((100,))
15108
15109        FileCheck().check("TorchScript will now drop the function").run(str(warns[0]))
15110
15111        # After export/import, the ignored method should be dropped from the code
15112        m = M()
15113
15114        m2 = self.getExportImportCopy(m)
15115        pp = str(m2.forward.code)
15116        self.assertNotIn('ignored_code', pp)
15117
15118        with self.assertRaisesRegex(torch.jit.Error, "annotated to be ignored and cannot be run"):
15119            m2.forward(torch.ones(1))
15120
15121    def test_ignored_as_value(self):
15122        class Model(nn.Module):
15123            @torch.jit.unused
15124            def tuple_ignored(self, x):
15125                # type: (Tensor) -> Tuple[Tensor, Tensor]
15126                return x, x
15127
15128            @torch.jit.unused
15129            def single_val_ignored(self, x, y):
15130                # type: (Tensor, Tensor) -> Tensor
15131                return x
15132
15133            def forward(self, x, use_ignore_path):
15134                # type: (Tensor, bool) -> Tuple[Tensor, Tensor]
15135                if 1 == 2:
15136                    return self.tuple_ignored(x)
15137                if use_ignore_path:
15138                    return self.single_val_ignored(x, x), self.single_val_ignored(x, x)
15139                return x, x
15140
15141        original = Model()
15142        scripted = torch.jit.script(original)
15143        self.assertEqual(scripted(torch.tensor(.5), False), (torch.tensor(.5), torch.tensor(.5)))
15144
15145        buffer = io.BytesIO()
15146        torch.jit.save(scripted, buffer)
15147        buffer.seek(0)
15148        loaded = torch.jit.load(buffer)
15149
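        # Methods marked @torch.jit.unused are serialized as stubs, so hitting the
        # ignored path after save/load must raise rather than silently run.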
15150        with self.assertRaisesRegex(torch.jit.Error, "annotated to be ignored and cannot be run"):
15151            loaded(torch.tensor(.5), True)
15152
15153    def test_module_error(self):
15154        class MyModule(torch.nn.Module):
15155            def forward(self, foo):
15156                return foo
15157
15158        with self.assertRaisesRegex(RuntimeError, "cannot be compiled since it inherits from nn.Module"):
15159            torch.jit.script(MyModule)
15160
15161    def test_view_write(self):
15162        def fn(x, y):
15163            l = []
15164            l.append(x)
15165            x_view = l[0]
15166            a = x + x
15167            x_view.add_(y)
15168            b = x + x
15169            return a == b
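        # x_view aliases x (it came from the list), so the add_ must stay ordered
        # between computing `a` and `b`; scripting has to preserve that aliasing.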
15170        self.checkScript(fn, (torch.rand(2, 3), torch.rand(2, 3)))
15171
15172    def test_module_attrs(self):
15173        class M(torch.jit.ScriptModule):
15174            def __init__(self, table):
15175                super().__init__()
15176                self.table = torch.jit.Attribute(table, Dict[str, torch.Tensor])
15177                self.x = torch.nn.Parameter(torch.tensor([100.0]))
15178
15179            @torch.jit.script_method
15180            def forward(self, key):
15181                # type: (str) -> Tensor
15182                return self.table[key] + self.x
15183
15184        with torch._jit_internal._disable_emit_hooks():
15185            # TODO: re-enable module hook when Python printing of attributes is
15186            # supported
15187            m = M({char : torch.ones(1) + ord(char) - ord("a") for char in "abcdefg"})
15188            self.assertEqual(m("c"), torch.tensor([103.]))
15189
15190    def test_module_none_attrs(self):
15191        class MyMod(torch.jit.ScriptModule):
15192            def __init__(self) -> None:
15193                super().__init__()
15194                self.optional_value = None
15195
15196            @torch.jit.script_method
15197            def forward(self):
15198                return self.optional_value
15199
15200        graph = MyMod().forward.graph
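        # The None-valued attribute is initially read via prim::GetAttr; the peephole
        # pass should constant-fold it away, as checked below.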
15201        FileCheck().check("prim::GetAttr").run(graph)
15202        self.run_pass('peephole', graph)
15203        FileCheck().check_not("prim::GetAttr").run(graph)
15204
15205    def test_tensor_import_export(self):
15206        @torch.jit.script
15207        def foo(x):
15208            a = torch.tensor(1)
15209            b = torch.tensor([1, 2])
15210            c = [a, b]
15211            return c
15212
15213        self.run_pass('constant_propagation', foo.graph)
15214        m = self.createFunctionFromGraph(foo.graph)
15215        self.getExportImportCopy(m)
15216
15217    def get_pickle_values(self):
15218        return (('dict', {"I": "am", "a test": "test"}, Dict[str, str]),
15219                ('float', 2.3, float),
15220                ('int', 99, int),
15221                ('bool', False, bool),
15222                ('tuple', (1, 2, 3, 4), Tuple[int, int, int, int]),
15223                ('list', [(1, 2), (3, 4)], List[Tuple[int, int]]),
15224                ('tensor', torch.randn(2, 2), torch.Tensor),
15225                ('int_list', [1, 2, 3, 4], List[int]),
15226                ('tensor_list', [torch.ones(2, 2) + i for i in range(4)], List[torch.Tensor]),
15227                ('bool_list', [True, True, False, True], List[bool]),
15228                ('float_list', [1., 2., 3., 4.], List[float]),
15229                ('str_list', ['hello', 'bye'], List[str]),
15230                ('none', None, Optional[int]),
15231                ('a_device', torch.device('cpu'), torch.device),
15232                ('another_device', torch.device('cuda:1'), torch.device))
15233
15234    def test_attribute_serialization(self):
15235        tester = self
15236
15237        class M(torch.jit.ScriptModule):
15238            def __init__(self) -> None:
15239                super().__init__()
15240                for name, value, the_type in tester.get_pickle_values():
15241                    setattr(self, name, torch.jit.Attribute(value, the_type))
15242
15243            @torch.jit.script_method
15244            def forward(self):
15245                return (self.dict, self.float, self.int, self.bool, self.tuple,
15246                        self.list, self.int_list, self.tensor_list, self.bool_list,
15247                        self.float_list, self.str_list, self.none)
15248
15249        m = M()
15250        imported_m = self.getExportImportCopy(m)
15251        self.assertEqual(m(), imported_m())
15252
15253    def test_string_len(self):
15254        def fn(x):
15255            # type: (str) -> int
15256            return len(x)
15257
15258        self.checkScript(fn, ("",))
15259        self.checkScript(fn, ("h",))
15260        self.checkScript(fn, ("hello",))
15261
15262    def test_multiline_optional_future_refinement(self):
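        # An Optional annotation split across several source lines should still be
        # parsed correctly by the script frontend.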
15263        @torch.jit.script
15264        def fun() -> int:
15265            future: Optional[
15266                torch.jit.Future[Tuple[torch.Tensor]]
15267            ] = None
15268
15269            return 1
15270        self.assertEqual(fun(), 1)
15271
15272    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle")
15273    def test_attribute_unpickling(self):
15274        tensor = torch.randn(2, 2)
15275        tester = self
15276
15277        class M(torch.jit.ScriptModule):
15278            def __init__(self) -> None:
15279                super().__init__()
15280                for name, value, the_type in tester.get_pickle_values():
15281                    setattr(self, "_" + name, torch.jit.Attribute(value, the_type))
15282
15283            @torch.jit.script_method
15284            def forward(self):
15285                return (self._dict, self._float, self._int, self._bool, self._tuple,
15286                        self._list, self._int_list, self._tensor_list, self._bool_list,
15287                        self._float_list, self._str_list, self._none)
15288
15289        with TemporaryFileName() as fname:
15290            M().save(fname)
15291            loaded = torch.jit.load(fname)
15292
15293            def is_tensor_value(item):
15294                if isinstance(item, torch.Tensor):
15295                    return True
15296                if isinstance(item, list):
15297                    return is_tensor_value(item[0])
15298                return False
15299            for name, value, the_type in self.get_pickle_values():
15300                if is_tensor_value(value):
15301                    continue
15302                self.assertEqual(value, getattr(loaded, "_" + name))
15303
15304
15305    def test_submodule_attribute_serialization(self):
15306        class S(torch.jit.ScriptModule):
15307            def __init__(self, list_data):
15308                super().__init__()
15309                self.table = torch.jit.Attribute({"I": "am", "a test": "test"}, Dict[str, str])
15310                self.list = torch.jit.Attribute(list_data, List[Tuple[int, int]])
15311
15312            @torch.jit.script_method
15313            def forward(self):
15314                return (self.table, self.list)
15315
15316        class M(torch.jit.ScriptModule):
15317            def __init__(self) -> None:
15318                super().__init__()
15319                self.table = torch.jit.Attribute({"this": "is", "a different": "dict"}, Dict[str, str])
15320                self.tensor = torch.jit.Attribute(torch.randn(2, 2), torch.Tensor)
15321                self.s1 = S([(1, 2)])
15322                self.s2 = S([(4, 5)])
15323
15324            @torch.jit.script_method
15325            def forward(self):
15326                return (self.table, self.tensor, self.s1.table, self.s2.list, self.s1.list)
15327
15328        m = M()
15329        imported_m = self.getExportImportCopy(m)
15330        self.assertEqual(m(), imported_m())
15331
15332    def test_serialization_big_ints(self):
15333        class M(torch.jit.ScriptModule):
15334            def __init__(self) -> None:
15335                super().__init__()
15336                self.int32_max = torch.jit.Attribute(2**31 - 1, int)
15337                self.int32_min = torch.jit.Attribute(-2**31, int)
15338                self.uint32_max = torch.jit.Attribute(2**32, int)
15339
15340                self.int64_max = torch.jit.Attribute(2**63 - 1, int)
15341                self.int64_min = torch.jit.Attribute(-2**63, int)
15342
15343                self.tensor = torch.nn.Parameter(torch.ones(2, 2))
15344
15345            @torch.jit.script_method
15346            def forward(self, x):
15347                # type: (int) -> (int)
15348                return x + (self.int32_max + self.int32_min) + (self.int64_max + self.int64_min)
15349
15350        m = M()
15351        imported = self.getExportImportCopy(m)
15352        self.assertEqual(m(10), imported(10))
15353
15354        self.assertEqual(m.int32_max, imported.int32_max)
15355        self.assertEqual(m.int32_min, imported.int32_min)
15356        self.assertEqual(m.uint32_max, imported.uint32_max)
15357        self.assertEqual(m.int64_max, imported.int64_max)
15358        self.assertEqual(m.int64_min, imported.int64_min)
15359
15360    def test_script_scope(self):
15361        scripted = torch.jit.script(torch.nn.functional.triplet_margin_loss)
15362
15363    @unittest.skipIf(IS_WINDOWS, "NYI: TemporaryFileName on Windows")
15364    def test_serialization_sharing(self):
15365        class M(torch.jit.ScriptModule):
15366            def __init__(self) -> None:
15367                super().__init__()
15368                self.list = torch.jit.Attribute([], List[str])
15369
15370            @torch.jit.script_method
15371            def forward(self, key):
15372                # type: (str) -> List[str]
15373                self.list.append(key)
15374                self.list.append(key)
15375                self.list.append(key)
15376                return self.list
15377
15378        # the text of each string should appear only once in the pickled data
15379        m = M()
15380        s1 = "a long string"
15381        s2 = "a different, even longer string"
15382        self.assertEqual(m(s1), [s1] * 3)
15383        self.assertEqual(m(s2), [s1] * 3 + [s2] * 3)
15384        with TemporaryFileName() as fname:
15385            m.save(fname)
15386            archive_name = os.path.basename(os.path.normpath(fname))
15387            archive = zipfile.ZipFile(fname, 'r')
15388            pickled_data = archive.read(os.path.join(archive_name, 'data.pkl'))
15389
15390            out = io.StringIO()
15391            pickletools.dis(pickled_data, out=out)
15392            disassembled = out.getvalue()
15393
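            # Each string should be written once to the pickle stream; the repeated
            # list entries are emitted as BINGET references into the pickler's memo.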
15394            FileCheck().check_count(s1, 1, exactly=True) \
15395                .check_count("BINGET", 2, exactly=True) \
15396                .check_count(s2, 1, exactly=True) \
15397                .check_count("BINGET", 2, exactly=True).run(out.getvalue())
15398
15399    def test_sys_stdout_override(self):
15400        @torch.jit.script
15401        def foo():
15402            print('foo')
15403
15404        class Redirect:
15405            def __init__(self) -> None:
15406                self.s = ''
15407
15408            def write(self, s):
15409                self.s += s
15410
15411        old_stdout = sys.stdout
15412        redirect = Redirect()
15413        try:
15414            sys.stdout = redirect
15415            foo()
15416        finally:
15417            sys.stdout = old_stdout
15418
15419        FileCheck().check('foo').run(redirect.s)
15420
15421    def test_dtype_attr(self):
15422        class Foo(torch.nn.Module):
15423            def __init__(self) -> None:
15424                super().__init__()
15425                self.dtype = torch.zeros([]).dtype
15426
15427            def forward(self):
15428                return torch.zeros(3, 4, dtype=self.dtype)
15429
15430        f = Foo()
15431        torch.jit.script(f)
15432
15433
15434    def test_named_buffers_are_iterable(self):
15435        class MyMod(torch.nn.Module):
15436            def __init__(self) -> None:
15437                super().__init__()
15438                self.mod = (torch.nn.ReLU())
15439                self.mod2 = (torch.nn.ReLU())
15440                self.mod3 = torch.nn.Sequential(torch.nn.Sequential(torch.nn.ReLU()))
15441                self.x = nn.Buffer(torch.zeros(3))
15442                self.y = nn.Buffer(torch.zeros(3))
15443                self.z = torch.zeros(3)
15444
15445            def bleh(self):
15446                return self.z + 4
15447
15448            @torch.jit.export
15449            def method(self):
15450                names = [""]
15451                vals = []
15452                for name, buffer in self.named_buffers():
15453                    names.append(name)
15454                    vals.append(buffer + 2)
15455
15456                return names, vals
15457
15458            def forward(self, x):
15459                return x
15460
15461        model = MyMod()
15462        x = torch.jit.script(model)
15463        z = self.getExportImportCopy(x)
15464
15465        self.assertEqual(z.method(), x.method())
15466        self.assertEqual(z.method(), model.method())
15467        self.assertEqual(x.method(), model.method())
15468        names = x.method()
15469        for name in names:
15470            self.assertNotEqual('z', name)
15471
15472
15473    def test_static_if_prop(self):
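        # hasattr(self, ...) on a module attribute is resolved statically while
        # scripting, so the branch guarded by it is pruned at compile time.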
15474        class MaybeHasAttr(torch.nn.Module):
15475            def __init__(self, add_attr):
15476                super().__init__()
15477                if add_attr:
15478                    self.maybe_attr = 1
15479
15480            def forward(self):
15481                if hasattr(self, "maybe_attr") and True:
15482                    return self.maybe_attr
15483                else:
15484                    return 0
15485
15486        class MaybeHasAttr2(torch.nn.Module):
15487            def __init__(self, add_attr):
15488                super().__init__()
15489                if add_attr:
15490                    self.maybe_attr = 1
15491
15492            def forward(self):
15493                if not hasattr(self, "maybe_attr") or False:
15494                    return 0
15495                else:
15496                    return self.maybe_attr
15497
15498        torch.jit.script(MaybeHasAttr(True))
15499        torch.jit.script(MaybeHasAttr(False))
15500        torch.jit.script(MaybeHasAttr2(True))
15501        torch.jit.script(MaybeHasAttr2(False))
15502
15503        class MyMod(torch.nn.Module):
15504            def forward(self):
15505                if hasattr(self, "foo"):
15506                    return 1
15507                else:
15508                    return 0
15509
15510            @torch.jit.export
15511            def fee(self):
15512                return 1
15513
15514        self.checkModule(MyMod(), ())
15515
15516        class HasAttrMod(torch.nn.Module):
15517            __constants__ = ["fee"]
15518
15519            def __init__(self) -> None:
15520                super().__init__()
15521                self.fee = 3
15522
15523            def forward(self):
15524                a = hasattr(self, "fee")
15525                b = hasattr(self, "foo")
15526                c = hasattr(self, "hi")
15527                d = hasattr(self, "nonexistent")
15528                return (a, b, c, d)
15529
15530            def foo(self):
15531                return 1
15532
15533            @torch.jit._overload_method
15534            def hi(self, x: Tensor): ...  # noqa: E704
15535
15536            def hi(self, x):  # noqa: F811
15537                return 2
15538
15539        self.checkModule(HasAttrMod(), ())
15540
15541        @torch.jit.script
15542        class FooTest:
15543            def __init__(self) -> None:
15544                self.x = 1
15545
15546            def foo(self, y):
15547                return self.x + y
15548
15549        def foo():
15550            a = FooTest()
15551            val1 = hasattr(a, "foo"), hasattr(a, "x"), hasattr(a, "bla")
15552            val2 = hasattr(FooTest, "foo"), hasattr(FooTest, "a")
15553            return val1, val2
15554
15555        self.assertEqual(foo(), torch.jit.script(foo)())
15556
15557    def _test_pickle_checkpoint(self, device):
15558        with TemporaryFileName() as fname:
15559            class M(torch.jit.ScriptModule):
15560                __constants__ = ['fname']
15561
15562                def __init__(self, tensor):
15563                    super().__init__()
15564                    self.fname = fname
15565                    self.tensor = torch.nn.Parameter(tensor)
15566
15567                @torch.jit.script_method
15568                def forward(self, x):
15569                    y = self.tensor + x
15570                    torch.save(y, self.fname)
15571                    return y
15572
15573            param = torch.randn(2, 2).to(device)
15574            input = torch.randn(2, 2).to(device)
15575            m = M(param)
15576            m(input)
15577            with open(fname, "rb") as handle:
15578                loaded_tensor = torch.load(fname)
15579                self.assertEqual(loaded_tensor, input + param)
15580
15581    def _test_pickle_checkpoint_views(self, device):
15582        with TemporaryFileName() as fname:
15583            class M(torch.jit.ScriptModule):
15584                __constants__ = ['fname']
15585
15586                def __init__(self, tensor):
15587                    super().__init__()
15588                    self.fname = fname
15589                    self.tensor = torch.nn.Parameter(tensor)
15590
15591                @torch.jit.script_method
15592                def forward(self, x):
15593                    y = self.tensor + x
15594                    y_view = y.view(4)
15595                    torch.save((y, y_view, y), self.fname)
15596                    return y
15597
15598            param = torch.randn(2, 2).to(device)
15599            input = torch.randn(2, 2).to(device)
15600            m = M(param)
15601            m(input)
15602            with open(fname, "rb") as handle:
15603                loaded_y, loaded_y_view, loaded_y_2 = torch.load(fname)
15604                self.assertEqual(loaded_y, input + param)
15605                with torch.no_grad():
15606                    loaded_y_view[1] += 20
15607                    # assert that loaded_y changed as well
15608                    self.assertEqual(loaded_y.view(4), loaded_y_view)
15609                    self.assertEqual(loaded_y_2.view(4), loaded_y_view)
15610
15611    @unittest.skipIf(not RUN_CUDA, "no CUDA")
15612    def test_pickle_checkpoint_cuda(self):
15613        self._test_pickle_checkpoint('cuda')
15614        self._test_pickle_checkpoint_views('cuda')
15615
15616    def test_pickle_checkpoint(self):
15617        self._test_pickle_checkpoint('cpu')
15618        self._test_pickle_checkpoint_views('cpu')
15619
15620    def test_pickle_checkpoint_tup(self):
15621        @torch.jit.script
15622        def foo(fname):
15623            # type: (str) -> None
15624            torch.save((3, 4), fname)
15625        with TemporaryFileName() as name:
15626            foo(name)
15627            self.assertEqual(torch.load(name), (3, 4))
15628
15629    def test_string_list(self):
15630        def fn(string):
15631            # type: (str) -> List[str]
15632            return list(string)
15633
15634        self.checkScript(fn, ("abcdefgh",))
15635
15636    def test_unicode_comments(self):
15637        @torch.jit.script
15638        def test(self, a):
15639            # ��������
15640            return torch.nn.functional.relu(a)
15641
15642    def test_get_set_state_with_tensors(self):
15643        class M(torch.nn.Module):
15644            def __init__(self) -> None:
15645                super().__init__()
15646                self.tensor = torch.randn(2, 2)
15647
15648            @torch.jit.export
15649            def __getstate__(self):
15650                return (self.tensor, self.training)
15651
15652            @torch.jit.export
15653            def __setstate__(self, state):
15654                self.tensor = state[0]
15655                self.training = state[1]
15656
15657            def forward(self, x):
15658                return x + self.tensor
15659
15660        with TemporaryFileName() as fname:
15661            m = torch.jit.script(M())
15662            m.save(fname)
15663            loaded = torch.jit.load(fname)
15664            self.assertEqual(loaded.tensor, m.tensor)
15665
15666    def test_in_for_and_comp_expr(self):
15667        def fn(d):
15668            # type: (Dict[str, int]) -> List[int]
15669            out = [1]
15670            for i in range(d["hi"] if "hi" in d else 6):
15671                out.append(i)  # noqa: PERF402
15672            return out
15673
15674        self.checkScript(fn, ({'hi': 2, 'bye': 3},))
15675        self.checkScript(fn, ({'bye': 3},))
15676
15677    def test_for_else(self):
15678        def fn():
15679            c = 0
15680            for i in range(4):
15681                c += 10
15682            else:
15683                print("In else block of for...else")
15684
15685        with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "else branches of for loops aren't supported"):
15686            torch.jit.script(fn)
15687
15688    def test_split(self):
15689        def split_two(tensor):
15690            a, b, c = torch.split(tensor, 2, dim=1)
15691            return a, b, c
15692        x = torch.randn(3, 6)
15693        y = torch.randn(3, 6)
15694        self.checkScript(split_two, [(x + y)])
15695
15696    def test_conv_error(self):
15697        @torch.jit.script
15698        def fn(x, y):
15699            return F.conv2d(x, y)
15700
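        # The shape-mismatched conv2d call should fail with an error message that
        # does not include 'frame' (i.e. no interpreter frame dump leaks through).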
15701        try:
15702            fn(torch.ones(2, 2), torch.ones(4, 4))
15703        except RuntimeError as e:
15704            self.assertFalse('frame' in str(e))
15705
15706    def test_python_op_name(self):
15707        import random
15708
15709        with self.assertRaisesRegex(RuntimeError, "randint"):
15710            @torch.jit.script
15711            def fn():
15712                return random.randint()
15713
15714    def test_dir(self):
15715        class M(torch.jit.ScriptModule):
15716            def forward(self, t):
15717                return t
15718
15719        self.assertTrue('forward' in dir(M()))
15720
15721    def test_kwarg_expansion_error(self):
15722        @torch.jit.ignore
15723        def something_else(h, i):
15724            pass
15725
15726        def fn(x):
15727            something_else(**x)
15728
15729        with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "keyword-arg expansion is not supported"):
15730            torch.jit.script(fn)
15731
15732    def test_kwargs_error_msg(self):
15733        def other(**kwargs):
15734            print(kwargs)
15735
15736        def fn():
15737            return other()
15738
15739        with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, 'variable number'):
15740            torch.jit.script(fn)
15741
15742        def another_other(*args):
15743            print(args)
15744
15745        def another_fn():
15746            return another_other()
15747
15748        with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, 'variable number'):
15749            torch.jit.script(another_fn)
15750
15751    def test_inferred_error_msg(self):
15752        """
15753        Test that a good error message is given when there is a type mismatch
15754        on a function whose argument type was inferred to be Tensor.
15755        """
15756        @torch.jit.script
15757        def foo(a):
15758            return a
15759
15760        with self.assertRaisesRegex(RuntimeError, (r"Expected a value of type \'Tensor \(inferred\)\'"
15761                                                   r"[\S\s]*Inferred \'a\' to be of type \'Tensor\'")):
15762            foo("1")
15763
15764    def test_type_comments_in_body(self):
15765        @torch.jit.script
15766        def foo(a,  # type: int
15767                b,  # type: int
15768                ):
15769            # type: (...) -> int
15770            # type: int
15771            return a + b
15772
15773        class M(torch.nn.Module):
15774            def __init__(self,
15775                         a,  # type: int
15776                         b   # type: int
15777                         ):
15778                # type: (...) -> None
15779                super().__init__()
15780                self.a = a  # type: int
15781                self.b = b  # type: int
15782
15783        torch.jit.script(M(2, 3))
15784
15785    def test_input_keyword_in_schema(self):
15786        def f(x):
15787            return torch.ceil(input=x)
15788
15789        inp = torch.randn(10)
15790        self.checkScript(f, (inp, ))
15791
15792    def test_module_method_reassignment(self):
15793        class Foo(torch.nn.Module):
15794            def _forward(self, x):
15795                return x
15796
15797            forward = _forward
15798
15799        sm = torch.jit.script(Foo())
15800        input = torch.ones(2, 2)
15801        self.assertEqual(input, sm(input))
15802
15803    # Tests the case where a torch.Tensor subclass (like Parameter) is used as
15804    # input.
15805    def test_script_module_tensor_subclass_argument(self):
15806        @torch.jit.script
15807        def parameter_script(x: torch.nn.Parameter):
15808            return x
15809
15810        input = torch.ones(2, 2)
15811        self.assertEqual(input, parameter_script(input))
15812
15813    def test_save_load_attr_error(self):
15814        class Inner(nn.Module):
15815            def forward(self, x):
15816                return x
15817
15818        class Wrapper(nn.Module):
15819            def __init__(self, inner):
15820                super().__init__()
15821                self.inner = inner
15822
15823            def forward(self, x):
15824                # this attribute doesn't exist on `Inner`
15825                return self.inner.b(x)
15826
15827        inner_module = torch.jit.script(Inner())
15828        inner_module = self.getExportImportCopy(inner_module)
15829        wrapped = Wrapper(inner_module)
15830        # This should properly complain that `self.inner` doesn't have the attribute `b`
15831        with self.assertRaisesRegex(RuntimeError, 'has no attribute'):
15832            torch.jit.script(wrapped)
15833
15834    def test_rescripting_loaded_modules(self):
15835        class InnerSubmod(nn.Module):
15836            __constants__ = ['my_constant']
15837
15838            def __init__(self) -> None:
15839                super().__init__()
15840                self.foo = torch.nn.Buffer(torch.ones(1))
15841                self.register_parameter("bar", torch.nn.Parameter(torch.ones(1)))
15842                self.baz = torch.ones(1)
15843                self.my_constant = 1
15844
15845            def forward(self, x):
15846                return x + x
15847
15848        class Inner(nn.Module):
15849            def __init__(self) -> None:
15850                super().__init__()
15851                self.submod = InnerSubmod()
15852
15853            def forward(self, x):
15854                return self.submod(x)
15855
15856        class Wrapper(nn.Module):
15857            def __init__(self, inner):
15858                super().__init__()
15859                self.inner = inner
15860
15861            def forward(self, x):
15862                # access inner elements
15863                ret = self.inner.submod(x) + self.inner.submod.foo + self.inner.submod.bar + self.inner.submod.baz
15864                ret = ret + self.inner.submod.my_constant
15865                return ret
15866
15867        inner_module = torch.jit.script(Inner())
15868        wrapped = Wrapper(inner_module)
15869        self.checkModule(wrapped, torch.ones(1))
15870
15871        inner_module_loaded = self.getExportImportCopy(inner_module)
15872        wrapped_loaded = Wrapper(inner_module_loaded)
15873        self.assertEqual(wrapped(torch.ones(1)), wrapped_loaded(torch.ones(1)))
15874
15875    def test_interpret_graph(self):
15876        def fn(x):
15877            return x.unfold(0, 1, 1)
15878
15879        graph_str = """
15880        graph(%a : Tensor, %b : Tensor):
15881          %c : Tensor = aten::mul(%a, %b)
15882          return (%c)
15883        """
15884        graph = parse_ir(graph_str)
15885        a = torch.rand(10)
15886        b = torch.rand(10)
15887        test = torch._C._jit_interpret_graph(graph, (a, b))
15888        ref = a * b
15889        self.assertEqual(test, ref)
15890
15891    def test_signed_float_zero(self):
15892
15893        class MyModule(torch.nn.Module):
15894            def forward(self, x):
15895                return torch.div(x, -0.)
15896
15897        inp = torch.ones(1)
15898        self.checkModule(MyModule(), inp)
15899
15900    def test_index_with_tuple(self):
15901        class MyModule(torch.nn.Module):
15902            def forward(self, x):
15903                return x[(1,)]
15904
15905        self.checkModule(MyModule(), (torch.ones(2, 3),))
15906
15907    def test_context_manager(self):
15908        class MyModule(torch.nn.Module):
15909            def forward(self, x, y):
15910                p = x + y
15911                q = p + 2.0
15912                return q
15913
15914        x = torch.randn(3, 2, dtype=torch.float)
15915        y = torch.randn(3, 2, dtype=torch.float)
15916        for fuser_name in ['fuser0', 'fuser1', 'none']:
15917            with torch.jit.fuser(fuser_name):
15918                self.checkModule(MyModule(), (x, y))
15919
15920    def test_zero_dimension_tensor_trace(self):
15921        def f(x):
15922            return x[x > 0]
15923        jf = torch.jit.trace(f, torch.tensor(2., device="cpu"))
15924
15925# known to be failing in tracer
15926EXCLUDE_TRACED = {
15927    # The following fail due to #12024.
15928    # A prim::ListConstruct is involved and the indices get traced as TensorType,
15929    # which always has requires_grad=True, causing a crash in autodiff (sketch below the set).
15930    'test___getitem___adv_index',
15931    'test___getitem___adv_index_beg',
15932    'test___getitem___adv_index_comb',
15933    'test___getitem___adv_index_dup',
15934    'test___getitem___adv_index_sub',
15935    'test___getitem___adv_index_sub_2',
15936    'test___getitem___adv_index_sub_3',
15937    'test___getitem___adv_index_var',
15938
15939    # jit doesn't support sparse tensors.
15940    'test_to_sparse',
15941    'test_to_sparse_dim',
15942}
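# A minimal sketch of the excluded pattern above (hypothetical function, kept
# as a comment so nothing is traced at import time): advanced indexing like the
# test___getitem___adv_index* cases builds a prim::ListConstruct of index
# tensors that always require grad, which is what crashes autodiff (#12024).
#
#     def adv_index(x):
#         return x[[0, 2], [1, 0]]
#
#     traced = torch.jit.trace(adv_index, torch.rand(3, 3))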
15943
15944EXCLUDE_TYPE_CHECK = {
15945# slogdet tests use itemgetter to select slogdet's only differentiable
15946# output, but this happens outside of the graph we handle, so there are
15947# fewer reference outputs than graph outputs.
15948    'test_slogdet_1x1_neg_det',
15949    'test_slogdet_1x1_pos_det',
15950    'test_slogdet_distinct_singular_values',
15951    'test_slogdet_neg_det',
15952    'test_slogdet_pos_det',
15953    'test_slogdet_symmetric',
15954    'test_slogdet_symmetric_pd',
15955    'test_slogdet_batched_1x1_neg_det',
15956    'test_slogdet_batched_pos_det',
15957    'test_slogdet_batched_symmetric',
15958    'test_slogdet_batched_symmetric_pd',
15959    'test_slogdet_batched_distinct_singular_values'
15960}
15961
15962# chunk returns a list in scripting and we don't unpack the list, so it
15963# won't be replaced by ConstantChunk and run through AD; this is checked
15964# explicitly in test_chunk_constant_script_ad (see also the sketch below this set).
15965# Similarly for split: in tracing it's replaced by split_with_sizes, but
15966# there is no AD formula for aten::split(Tensor, int[], int), the op
15967# registered in JIT, so AD is not triggered in scripting either.
15968EXCLUDE_SCRIPT_AD_CHECK = {
15969    'test_chunk',
15970    'test_chunk_dim',
15971    'test_chunk_dim_neg0',
15972    'test_split_size_list',
15973    'test_split_size_list_dim',
15974    'test_split_size_list_dim_neg0',
15975    'test_tensor_indices_sections',
15976    'test_tensor_indices_sections_dim',
15977    'test_tensor_indices_sections_dim_neg0',
15978    'test_tensor_split_sections',
15979    'test_tensor_split_sections_dim',
15980    'test_tensor_split_sections_dim_neg0'
15981}
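# Sketch of the chunk behavior referenced in the comment above this set
# (hypothetical functions, kept as comments so nothing is compiled at import
# time): returning the chunks as a list keeps aten::chunk in the graph, while
# unpacking lets it be rewritten into prim::ConstantChunk, which autodiff handles.
#
#     @torch.jit.script
#     def chunk_as_list(x):
#         return torch.chunk(x, 2, dim=0)    # stays a List[Tensor]
#
#     @torch.jit.script
#     def chunk_unpacked(x):
#         a, b = torch.chunk(x, 2, dim=0)    # eligible for prim::ConstantChunk
#         return a + b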
15982
15983EXCLUDE_PYTHON_PRINT = {
15984    # no support for BroadcastingList in python printer
15985    'test_nn_max_unpool1d',
15986    'test_nn_max_unpool2d',
15987    'test_nn_max_unpool3d',
15988    'test_nn_max_pool1d',
15989    'test_nn_max_pool2d',
15990    'test_nn_max_pool3d',
15991    'test_nn_max_pool1d_with_indices',
15992}
15993
15994EXCLUDE_ALIAS = {
15995    # aliases, which may appear in method_tests but are tested elsewhere
15996    'true_divide',
15997
15998    # Disable tests for lu from common_methods_invocations.py
15999# TODO(@nikitaved) Enable jit tests once autograd.Function supports scripting
16000    'lu'
16001}
16002
16003
16004class TestJitGeneratedModule(JitTestCase):
16005    pass
16006
16007
16008class TestJitGeneratedFunctional(JitTestCase):
16009    pass
16010
16011# UBSAN per-function exclusions don't seem to work with OpenMP pragmas,
16012# so we have to disable the failing tests here instead.
16013UBSAN_DISABLED_TESTS = [
16014    "test___rdiv___constant",
16015    "test___rdiv___scalar_constant",
16016    "test_addcdiv",
16017    "test_addcdiv_broadcast_all",
16018    "test_addcdiv_broadcast_rhs",
16019    "test_addcdiv_scalar",
16020    "test_addcdiv_scalar_broadcast_lhs",
16021    "test_addcdiv_scalar_broadcast_rhs",
16022    "test_addcdiv_scalar_scale",
16023    "test_addcdiv_scalar_scale_broadcast_lhs",
16024    "test_addcdiv_scalar_scale_broadcast_rhs",
16025    "test_addcdiv_scale",
16026    "test_addcdiv_scale_broadcast_all",
16027    "test_addcdiv_scale_broadcast_rhs",
16028    "test_add_broadcast_all",
16029    "test_add_broadcast_lhs",
16030    "test_add_broadcast_rhs",
16031    "test_add_constant",
16032    "test_add_scalar",
16033    "test_add_scalar_broadcast_lhs",
16034    "test_add_scalar_broadcast_rhs",
16035    "test_div",
16036    "test_div_broadcast_all",
16037    "test_div_broadcast_lhs",
16038    "test_div_broadcast_rhs",
16039    "test_div_scalar",
16040    "test_div_scalar_broadcast_lhs",
16041    "test_div_scalar_broadcast_rhs",
16042    "test_rsqrt",
16043    "test_rsqrt_scalar",
16044    "test_add",
16045    "test_reciprocal",
16046    "test_reciprocal_scalar",
16047]
16048
16049L = 20
16050M = 10
16051S = 5
16052
16053def add_nn_module_test(*args, **kwargs):
16054    no_grad = kwargs.get('no_grad', False)
16055
16056    if 'desc' in kwargs and 'eval' in kwargs['desc']:
16057        # eval() is not supported, so skip these tests
16058        return
16059
16060    test_name = get_nn_mod_test_name(**kwargs)
16061
16062    @suppress_warnings
16063    def do_test(self):
16064        if test_name in EXCLUDE_SCRIPT_MODULES:
16065            return
16066        if not kwargs.get('check_jit', True):
16067            raise unittest.SkipTest('module test skipped on JIT')
16068
16069        default_dtype = torch.get_default_dtype()
16070        if 'default_dtype' in kwargs and kwargs['default_dtype'] is not None:
16071            default_dtype = kwargs['default_dtype']
16072
16073        module_name = get_nn_module_name_from_kwargs(**kwargs)
16074
16075        if 'constructor' in kwargs:
16076            nn_module = kwargs['constructor']
16077        else:
16078            nn_module = getattr(torch.nn, module_name)
16079
16080        if "FunctionalModule" in str(nn_module):
16081            return
16082
16083        with set_default_dtype(default_dtype):
16084            if 'constructor_args_fn' in kwargs:
16085                constructor_args = kwargs['constructor_args_fn']()
16086            else:
16087                constructor_args = kwargs.get('constructor_args', ())
16088
16089            def create_script_module(*args, **kwargs):
16090                """Construct a script module that passes arguments through to self.submodule"""
16091                formals, tensors, actuals = get_script_args(args)
16092
16093                method_args = ', '.join(['self'] + actuals)
16094                call_args_str = ', '.join(actuals)
16095                call = f"self.submodule({call_args_str})"
16096                script = script_method_template.format(method_args, call)
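                # Roughly, for two tensor actuals the generated text looks like
                # (a sketch; the exact formatting comes from
                # script_method_template and get_script_args):
                #   def forward(self, input1, input2):
                #       return self.submodule(input1, input2)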
16097
16098                submodule_constants = []
16099                if kwargs.get('is_constant'):
16100                    submodule_constants = ['submodule']
16101
16102                # Create module to use the script method
16103                class TheModule(torch.jit.ScriptModule):
16104                    __constants__ = submodule_constants
16105
16106                    def __init__(self) -> None:
16107                        super().__init__()
16108                        self.submodule = nn_module(*constructor_args)
16109
16110                def make_module(script):
16111                    module = TheModule()
16112                    # check __repr__
16113                    str(module)
16114                    module.define(script)
16115                    return module
16116
16117                module = make_module(script)
16118                self.assertExportImportModule(module, tensors)
16119                create_script_module.last_graph = module.graph
16120                mod = module(*args)
16121                return mod
16122
16123            # Construct a normal nn module to stay consistent with create_script_module
16124            # and make use of a single global rng_state in module initialization
16125            def create_nn_module(*args, **kwargs):
16126                module = nn_module(*constructor_args)
16127                return module(*args)
16128
16129            # Set up inputs from tuple of sizes or constructor fn
16130            dtype = torch.get_default_dtype()
16131            if 'input_fn' in kwargs:
16132                input = kwargs['input_fn']()
16133                if isinstance(input, Tensor):
16134                    input = (input,)
16135
16136                if all(tensor.is_complex() for tensor in input):
16137                    if dtype == torch.float:
16138                        dtype = torch.cfloat
16139                    elif dtype == torch.double:
16140                        dtype = torch.cdouble
16141                    else:
16142                        raise AssertionError(f"default_dtype {default_dtype} is not supported")
16143
16144            else:
16145                input = (kwargs['input_size'],)
16146
16147            if 'target_size' in kwargs:
16148                input = input + (kwargs['target_size'],)
16149            elif 'target_fn' in kwargs:
16150                if torch.is_tensor(input):
16151                    input = (input,)
16152                input = input + (kwargs['target_fn'](),)
16153            elif 'target' in kwargs:
16154                input = input + (kwargs['target'],)
16155
16156            # Extra parameters to forward()
16157            if 'extra_args' in kwargs:
16158                input = input + kwargs['extra_args']
16159
16160            args_variable, kwargs_variable = create_input(input, dtype=dtype)
16161            f_args_variable = deepcopy(unpack_variables(args_variable))
16162
16163            # TODO(issue#52052) Neither this nor no_grad should be required
16164            # if check_against_reference() is updated to check gradients
16165            # w.r.t. weights and then only check w.r.t. inputs if any
16166            # inputs require it.
16167            any_requires_grad = any(input.requires_grad for input in f_args_variable)
16168
16169            # Check against Python module as reference
16170            check_against_reference(self, create_script_module, create_nn_module,
16171                                    lambda x: x, f_args_variable,
16172                                    no_grad=no_grad or not any_requires_grad)
16173
16174    if 'slowTest' in kwargs:
16175        do_test = slowTest(do_test)
16176
16177    post_add_test(test_name, (), do_test, TestJitGeneratedModule)
16178
16179
16180def post_add_test(test_name, skipTestIf, do_test, test_class):
16181    assert not hasattr(test_class, test_name), 'Two tests have the same name: ' + test_name
16182
16183    for skip in skipTestIf:
16184        do_test = skip(do_test)
16185
16186    if not (TEST_WITH_UBSAN and test_name in UBSAN_DISABLED_TESTS):
16187        setattr(test_class, test_name, do_test)
16188
16189
16190def normalize_check_ad(check_ad, name):
16191    # normalized check_ad is a 3-element tuple: (bool, List[str], List[str])
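    # For example (hypothetical inputs, traced through the branches below):
    #   normalize_check_ad((), 'mul')                  -> [False, ['aten::mul'], []]
    #   normalize_check_ad((True,), 'mul')             -> [True, ['aten::mul'], []]
    #   normalize_check_ad((True, 'aten::add'), 'mul') -> [True, ['aten::add'], []]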
16192    if len(check_ad) == 0:
16193        check_ad = [False, ['aten::' + name], []]
16194    elif len(check_ad) == 1:
16195        check_ad = [check_ad[0], ['aten::' + name], []]
16196    elif len(check_ad) == 2:
16197        check_ad = [check_ad[0], check_ad[1], []]
16198    elif len(check_ad) == 3:
16199        check_ad = list(check_ad)
16200    else:
16201        raise Exception('Invalid check_ad, requires (bool, str|List[str], str|List[str])')  # noqa: TRY002
16202
16203    check_ad = [[t] if isinstance(t, str) else t for t in check_ad]
16204
16205    return check_ad
16206
16207
16208class TestProducerVersion(TestCase):
16209
16210    def test_version(self):
16211        # issue gh-32561
16212        self.assertTrue(torch.__version__.startswith(torch.onnx.producer_version))
16213
16214for test in module_tests + new_module_tests + additional_module_tests:
16215    add_nn_module_test(**test)
16216
16217for test in criterion_tests:
16218    test['no_grad'] = True
16219    add_nn_module_test(**test)
16220
16221if __name__ == '__main__':
16222    TestCase._default_dtype_check_enabled = True
16223    run_tests()
16224    import jit.test_module_interface
16225    suite = unittest.findTestCases(jit.test_module_interface)
16226    unittest.TextTestRunner().run(suite)
16227