xref: /aosp_15_r20/external/pytorch/test/test_jit.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport torch
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Worker# This is how we include tests located in test/jit/...
6*da0073e9SAndroid Build Coastguard Worker# They are included here so that they are invoked when you call `test_jit.py`,
7*da0073e9SAndroid Build Coastguard Worker# do not run these test files directly.
8*da0073e9SAndroid Build Coastguard Workerfrom jit.test_tracer import TestTracer, TestMixTracingScripting  # noqa: F401
9*da0073e9SAndroid Build Coastguard Workerfrom jit.test_recursive_script import TestRecursiveScript  # noqa: F401
10*da0073e9SAndroid Build Coastguard Workerfrom jit.test_type_sharing import TestTypeSharing  # noqa: F401
11*da0073e9SAndroid Build Coastguard Workerfrom jit.test_logging import TestLogging  # noqa: F401
12*da0073e9SAndroid Build Coastguard Workerfrom jit.test_backends import TestBackends, TestBackendsWithCompiler  # noqa: F401
13*da0073e9SAndroid Build Coastguard Workerfrom jit.test_backend_nnapi import TestNnapiBackend  # noqa: F401
14*da0073e9SAndroid Build Coastguard Workerfrom jit.test_list_dict import TestList, TestDict, TestNamedTuple, TestScriptDict, TestScriptList  # noqa: F401
15*da0073e9SAndroid Build Coastguard Workerfrom jit.test_async import TestAsync  # noqa: F401
16*da0073e9SAndroid Build Coastguard Workerfrom jit.test_await import TestAwait  # noqa: F401
17*da0073e9SAndroid Build Coastguard Workerfrom jit.test_data_parallel import TestDataParallel  # noqa: F401
18*da0073e9SAndroid Build Coastguard Workerfrom jit.test_models import TestModels  # noqa: F401
19*da0073e9SAndroid Build Coastguard Workerfrom jit.test_modules import TestModules  # noqa: F401
20*da0073e9SAndroid Build Coastguard Workerfrom jit.test_autodiff import TestAutodiffJit  # noqa: F401
21*da0073e9SAndroid Build Coastguard Workerfrom jit.test_autodiff_subgraph_slicing import TestAutodiffSubgraphSlicing  # noqa: F401
22*da0073e9SAndroid Build Coastguard Workerfrom jit.test_custom_operators import TestCustomOperators  # noqa: F401
23*da0073e9SAndroid Build Coastguard Workerfrom jit.test_graph_rewrite_passes import TestGraphRewritePasses  # noqa: F401
24*da0073e9SAndroid Build Coastguard Workerfrom jit.test_class_type import TestClassType  # noqa: F401
25*da0073e9SAndroid Build Coastguard Workerfrom jit.test_builtins import TestBuiltins, TestTensorBuiltins  # noqa: F401
26*da0073e9SAndroid Build Coastguard Workerfrom jit.test_ignore_context_manager import TestIgnoreContextManager  # noqa: F401
27*da0073e9SAndroid Build Coastguard Workerfrom jit.test_symbolic_shape_analysis import TestSymbolicShapeAnalysis  # noqa: F401
28*da0073e9SAndroid Build Coastguard Workerfrom jit.test_op_decompositions import TestOpDecompositions  # noqa: F401
29*da0073e9SAndroid Build Coastguard Workerfrom jit.test_unsupported_ops import TestUnsupportedOps  # noqa: F401
30*da0073e9SAndroid Build Coastguard Workerfrom jit.test_freezing import TestFreezing, TestFrozenOptimizations, TestMKLDNNReinplacing  # noqa: F401
31*da0073e9SAndroid Build Coastguard Workerfrom jit.test_peephole import TestPeephole  # noqa: F401
32*da0073e9SAndroid Build Coastguard Workerfrom jit.test_alias_analysis import TestAliasAnalysis  # noqa: F401
33*da0073e9SAndroid Build Coastguard Workerfrom jit.test_save_load import TestSaveLoad, TestSaveLoadFlatbuffer  # noqa: F401
34*da0073e9SAndroid Build Coastguard Workerfrom jit.test_save_load_for_op_version import TestSaveLoadForOpVersion  # noqa: F401
35*da0073e9SAndroid Build Coastguard Workerfrom jit.test_module_containers import TestModuleContainers  # noqa: F401
36*da0073e9SAndroid Build Coastguard Workerfrom jit.test_python_bindings import TestPythonBindings  # noqa: F401
37*da0073e9SAndroid Build Coastguard Workerfrom jit.test_python_ir import TestPythonIr  # noqa: F401
38*da0073e9SAndroid Build Coastguard Workerfrom jit.test_functional_blocks import TestFunctionalBlocks  # noqa: F401
39*da0073e9SAndroid Build Coastguard Workerfrom jit.test_remove_mutation import TestRemoveMutation  # noqa: F401
40*da0073e9SAndroid Build Coastguard Workerfrom jit.test_torchbind import TestTorchbind  # noqa: F401
41*da0073e9SAndroid Build Coastguard Workerfrom jit.test_module_interface import TestModuleInterface  # noqa: F401
42*da0073e9SAndroid Build Coastguard Workerfrom jit.test_with import TestWith  # noqa: F401
43*da0073e9SAndroid Build Coastguard Workerfrom jit.test_enum import TestEnum  # noqa: F401
44*da0073e9SAndroid Build Coastguard Workerfrom jit.test_string_formatting import TestStringFormatting  # noqa: F401
45*da0073e9SAndroid Build Coastguard Workerfrom jit.test_profiler import TestProfiler  # noqa: F401
46*da0073e9SAndroid Build Coastguard Workerfrom jit.test_slice import TestSlice  # noqa: F401
47*da0073e9SAndroid Build Coastguard Workerfrom jit.test_ignorable_args import TestIgnorableArgs  # noqa: F401
48*da0073e9SAndroid Build Coastguard Workerfrom jit.test_hooks import TestHooks  # noqa: F401
49*da0073e9SAndroid Build Coastguard Workerfrom jit.test_warn import TestWarn  # noqa: F401
50*da0073e9SAndroid Build Coastguard Workerfrom jit.test_isinstance import TestIsinstance  # noqa: F401
51*da0073e9SAndroid Build Coastguard Workerfrom jit.test_cuda import TestCUDA  # noqa: F401
52*da0073e9SAndroid Build Coastguard Workerfrom jit.test_python_builtins import TestPythonBuiltinOP  # noqa: F401
53*da0073e9SAndroid Build Coastguard Workerfrom jit.test_typing import TestTyping  # noqa: F401
54*da0073e9SAndroid Build Coastguard Workerfrom jit.test_hash import TestHash  # noqa: F401
55*da0073e9SAndroid Build Coastguard Workerfrom jit.test_complex import TestComplex  # noqa: F401
56*da0073e9SAndroid Build Coastguard Workerfrom jit.test_jit_utils import TestJitUtils  # noqa: F401
57*da0073e9SAndroid Build Coastguard Workerfrom jit.test_scriptmod_ann import TestScriptModuleInstanceAttributeTypeAnnotation  # noqa: F401
58*da0073e9SAndroid Build Coastguard Workerfrom jit.test_types import TestTypesAndAnnotation  # noqa: F401
59*da0073e9SAndroid Build Coastguard Workerfrom jit.test_misc import TestMisc  # noqa: F401
60*da0073e9SAndroid Build Coastguard Workerfrom jit.test_upgraders import TestUpgraders  # noqa: F401
61*da0073e9SAndroid Build Coastguard Workerfrom jit.test_pdt import TestPDT  # noqa: F401
62*da0073e9SAndroid Build Coastguard Workerfrom jit.test_tensor_creation_ops import TestTensorCreationOps  # noqa: F401
63*da0073e9SAndroid Build Coastguard Workerfrom jit.test_module_apis import TestModuleAPIs  # noqa: F401
64*da0073e9SAndroid Build Coastguard Workerfrom jit.test_script_profile import TestScriptProfile  # noqa: F401
65*da0073e9SAndroid Build Coastguard Workerfrom jit.test_convert_activation import TestFunctionalToInplaceActivation, TestInplaceToFunctionalActivation  # noqa: F401
66*da0073e9SAndroid Build Coastguard Workerfrom jit.test_parametrization import TestParametrization  # noqa: F401
67*da0073e9SAndroid Build Coastguard Workerfrom jit.test_attr import TestGetDefaultAttr  # noqa: F401
68*da0073e9SAndroid Build Coastguard Workerfrom jit.test_aten_pow import TestAtenPow  # noqa: F401
69*da0073e9SAndroid Build Coastguard Workerfrom jit.test_optimize_for_mobile_preserve_debug_info import TestOptimizeForMobilePreserveDebugInfo  # noqa: F401
70*da0073e9SAndroid Build Coastguard Workerfrom jit.test_union import TestUnion  # noqa: F401
71*da0073e9SAndroid Build Coastguard Workerfrom jit.test_batch_mm import TestBatchMM  # noqa: F401
72*da0073e9SAndroid Build Coastguard Workerfrom jit.test_dtype_analysis import TestDtypeAnalysis, TestDtypeCustomRulesCPU  # noqa: F401
73*da0073e9SAndroid Build Coastguard Workerfrom jit.test_device_analysis import TestDeviceAnalysis  # noqa: F401
74*da0073e9SAndroid Build Coastguard Workerfrom jit.test_dce import TestDCE  # noqa: F401
75*da0073e9SAndroid Build Coastguard Workerfrom jit.test_sparse import TestSparse  # noqa: F401
76*da0073e9SAndroid Build Coastguard Workerfrom jit.test_tensor_methods import TestTensorMethods  # noqa: F401
77*da0073e9SAndroid Build Coastguard Workerfrom jit.test_dataclasses import TestDataclasses  # noqa: F401
78*da0073e9SAndroid Build Coastguard Workerfrom jit.test_generator import TestGenerator  # noqa: F401
79*da0073e9SAndroid Build Coastguard Worker
80*da0073e9SAndroid Build Coastguard Worker# Torch
81*da0073e9SAndroid Build Coastguard Workerfrom torch import Tensor
82*da0073e9SAndroid Build Coastguard Workerfrom torch._C import TensorType, BoolType, parse_ir, _propagate_shapes
83*da0073e9SAndroid Build Coastguard Workerfrom torch.autograd import Variable
84*da0073e9SAndroid Build Coastguard Workerfrom torch.jit.annotations import BroadcastingList2, BroadcastingList3, Any  # noqa: F401
85*da0073e9SAndroid Build Coastguard Workerfrom torch.nn.utils.rnn import PackedSequence
86*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import FileCheck, make_tensor
87*da0073e9SAndroid Build Coastguard Workerimport torch.autograd.profiler
88*da0073e9SAndroid Build Coastguard Workerimport torch.cuda
89*da0073e9SAndroid Build Coastguard Workerimport torch.jit
90*da0073e9SAndroid Build Coastguard Workerimport torch.jit._logging
91*da0073e9SAndroid Build Coastguard Workerimport torch.jit.frontend
92*da0073e9SAndroid Build Coastguard Workerimport torch.nn as nn
93*da0073e9SAndroid Build Coastguard Workerimport torch.nn.functional as F
94*da0073e9SAndroid Build Coastguard Worker
95*da0073e9SAndroid Build Coastguard Worker# Testing utils
96*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal import jit_utils
97*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_jit import check_against_reference
98*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
99*da0073e9SAndroid Build Coastguard Worker    suppress_warnings, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, TestCase, \
100*da0073e9SAndroid Build Coastguard Worker    freeze_rng_state, slowTest, TemporaryFileName, \
101*da0073e9SAndroid Build Coastguard Worker    enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \
102*da0073e9SAndroid Build Coastguard Worker    skipIfCrossRef, skipIfTorchDynamo
103*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, disable_autodiff_subgraph_inlining, \
104*da0073e9SAndroid Build Coastguard Worker    _trace, do_input_map, get_execution_plan, make_global, \
105*da0073e9SAndroid Build Coastguard Worker    execWrapper, _inline_everything, _tmp_donotuse_dont_inline_everything, \
106*da0073e9SAndroid Build Coastguard Worker    RUN_CUDA
107*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_metaprogramming_utils import (
108*da0073e9SAndroid Build Coastguard Worker    get_script_args,
109*da0073e9SAndroid Build Coastguard Worker    create_input, unpack_variables,
110*da0073e9SAndroid Build Coastguard Worker    additional_module_tests, EXCLUDE_SCRIPT_MODULES,
111*da0073e9SAndroid Build Coastguard Worker    get_nn_module_name_from_kwargs, get_nn_mod_test_name, script_method_template)
112*da0073e9SAndroid Build Coastguard Worker
113*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_nn import module_tests, new_module_tests, criterion_tests
114*da0073e9SAndroid Build Coastguard Worker
# Division helpers from modules with and without a `__future__` division import (used to test true-division semantics, a Python 2-era distinction)
116*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.test_module.future_div import div_int_future, div_float_future
117*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.test_module.no_future_div import div_int_nofuture, div_float_nofuture
118*da0073e9SAndroid Build Coastguard Worker
119*da0073e9SAndroid Build Coastguard Worker# Standard library
120*da0073e9SAndroid Build Coastguard Workerfrom collections import defaultdict, namedtuple, OrderedDict
121*da0073e9SAndroid Build Coastguard Workerfrom copy import deepcopy
122*da0073e9SAndroid Build Coastguard Workerfrom itertools import product
123*da0073e9SAndroid Build Coastguard Workerfrom textwrap import dedent
124*da0073e9SAndroid Build Coastguard Workerfrom typing import List, Dict, NamedTuple, Optional, Tuple, Union
125*da0073e9SAndroid Build Coastguard Workerimport copy
126*da0073e9SAndroid Build Coastguard Workerimport functools
127*da0073e9SAndroid Build Coastguard Workerimport inspect
128*da0073e9SAndroid Build Coastguard Workerimport io
129*da0073e9SAndroid Build Coastguard Workerimport itertools
130*da0073e9SAndroid Build Coastguard Workerimport math
131*da0073e9SAndroid Build Coastguard Workerimport numpy as np
132*da0073e9SAndroid Build Coastguard Workerimport os
133*da0073e9SAndroid Build Coastguard Workerimport pickle
134*da0073e9SAndroid Build Coastguard Workerimport pickletools
135*da0073e9SAndroid Build Coastguard Workerimport random
136*da0073e9SAndroid Build Coastguard Workerimport re
137*da0073e9SAndroid Build Coastguard Workerimport shutil
138*da0073e9SAndroid Build Coastguard Workerimport string
139*da0073e9SAndroid Build Coastguard Workerimport sys
140*da0073e9SAndroid Build Coastguard Workerimport tempfile
141*da0073e9SAndroid Build Coastguard Workerimport types
142*da0073e9SAndroid Build Coastguard Workerimport typing
143*da0073e9SAndroid Build Coastguard Workerimport unittest
144*da0073e9SAndroid Build Coastguard Workerimport warnings
145*da0073e9SAndroid Build Coastguard Workerimport zipfile
146*da0073e9SAndroid Build Coastguard Workerimport tracemalloc
147*da0073e9SAndroid Build Coastguard Worker
148*da0073e9SAndroid Build Coastguard Worker
def canonical(graph):
    """Return the canonicalized textual form of a TorchScript ``graph``.

    The graph is passed through ``torch._C._jit_pass_canonicalize`` and the
    resulting graph is rendered with ``str(False)``. The ``False`` flag
    presumably disables source-location annotations; confirm against the
    ``Graph.str`` binding.
    """
    canonical_graph = torch._C._jit_pass_canonicalize(graph)
    return canonical_graph.str(False)
151*da0073e9SAndroid Build Coastguard Worker
def LSTMCellF(input, hx, cx, *params):
    """Flat-argument adapter around ``LSTMCell``.

    Packs ``hx`` and ``cx`` into the ``hidden`` pair that ``LSTMCell`` takes.
    The remaining positional ``params`` (weights, then optional biases) are
    forwarded unchanged.
    """
    state = (hx, cx)
    return LSTMCell(input, state, *params)
154*da0073e9SAndroid Build Coastguard Worker
# Generated test names whose autodiff checks are skipped under the profiling
# executor: BailOut nodes inserted by ProfilingExecutor interfere with
# subgraph slicing of Differentiable Graphs.
_AUTODIFF_CHECK_EXCEPTIONS = frozenset({
    # functional
    'test_nn_dropout',
    'test_nn_log_softmax',
    'test_nn_relu',
    'test_nn_softmax',
    'test_nn_threshold',
    'test_nn_lp_pool2d',
    'test_nn_lp_pool1d',
    'test_nn_gumbel_softmax_hard',
    'test_nn_gumbel_softmax',
    'test_nn_multilabel_soft_margin_loss',
    'test_nn_batch_norm',
    'test_nn_max_pool2d_with_indices',
    # AutogradJitGenerated
    'test___rdiv___constant',
    'test___rdiv___scalar_constant',
    'test_split',
    'test_split_dim',
    'test_split_dim_neg0',
    'test_split_size_list',
    'test_split_size_list_dim',
    'test_split_size_list_dim_neg0',
    'test_split_with_sizes',
    'test_split_with_sizes_dim',
    'test_split_with_sizes_dim_neg0',
    'test_split_with_sizes_size_0',
})


def doAutodiffCheck(testname):
    """Return whether autodiff checks should run for the test ``testname``.

    Decision order:
      * ``False`` for a test named ``test_t`` or any name containing ``test_t_``;
      * ``False`` when the graph executor is ``ProfilingMode.SIMPLE``;
      * ``True`` when the graph executor is ``ProfilingMode.LEGACY``;
      * otherwise ``True`` unless ``testname`` is in ``_AUTODIFF_CHECK_EXCEPTIONS``.
    """
    # TODO: setting false on test itself is not working
    if "test_t_" in testname or testname == "test_t":
        return False

    if GRAPH_EXECUTOR == ProfilingMode.SIMPLE:
        return False

    if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
        return True

    return testname not in _AUTODIFF_CHECK_EXCEPTIONS
201*da0073e9SAndroid Build Coastguard Worker
202*da0073e9SAndroid Build Coastguard Worker
# Import-time global JIT configuration, driven by the graph executor mode
# selected for this run (GRAPH_EXECUTOR):
#  - the tensor-expression (texpr) fuser is enabled only for
#    ProfilingMode.PROFILING;
#  - the profiling executor is used for every mode except ProfilingMode.LEGACY.
# TODO: enable TE in PE when all tests are fixed
torch._C._jit_set_texpr_fuser_enabled(GRAPH_EXECUTOR == ProfilingMode.PROFILING)
torch._C._jit_set_profiling_executor(GRAPH_EXECUTOR != ProfilingMode.LEGACY)
206*da0073e9SAndroid Build Coastguard Worker
def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    """Reference single-step LSTM cell built from ``F.linear``.

    Args:
        input: input for this step.
        hidden: ``(hx, cx)`` pair holding the previous hidden and cell state.
        w_ih, w_hh: input-to-hidden and hidden-to-hidden weights.
        b_ih, b_hh: optional biases forwarded to ``F.linear``.

    Returns:
        ``(hy, cy)``: the next hidden state and the next cell state.
    """
    h_prev, c_prev = hidden
    projected = F.linear(input, w_ih, b_ih) + F.linear(h_prev, w_hh, b_hh)

    # The projection is split into four equal chunks along dim 1:
    # input, forget, cell-candidate and output pre-activations.
    in_pre, forget_pre, cell_pre, out_pre = projected.chunk(4, 1)
    in_gate = in_pre.sigmoid()
    forget_gate = forget_pre.sigmoid()
    cell_cand = cell_pre.tanh()
    out_gate = out_pre.sigmoid()

    c_next = (forget_gate * c_prev) + (in_gate * cell_cand)
    h_next = out_gate * c_next.tanh()
    return h_next, c_next
220*da0073e9SAndroid Build Coastguard Worker
221*da0073e9SAndroid Build Coastguard Worker
def LSTMCellC(*args, **kwargs):
    """Run ``LSTMCellF`` and return ``hy`` and ``cy`` concatenated (dim 0).

    All arguments are forwarded to ``LSTMCellF`` unchanged. The result is a
    single tensor rather than a pair.
    """
    hidden_next, cell_next = LSTMCellF(*args, **kwargs)
    joined = torch.cat((hidden_next, cell_next))
    return joined
225*da0073e9SAndroid Build Coastguard Worker
226*da0073e9SAndroid Build Coastguard Worker
def LSTMCellS(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
    """Single-step LSTM cell written with explicit ``mm`` and bias adds.

    Takes the hidden and cell state as separate arguments (``hx``, ``cx``)
    and requires both biases. Returns ``(hy, cy)``.
    """
    gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh
    # Four equal chunks along dim 1: input, forget, cell, output.
    i_pre, f_pre, g_pre, o_pre = gates.chunk(4, 1)
    i_act = i_pre.sigmoid()
    f_act = f_pre.sigmoid()
    g_act = g_pre.tanh()
    o_act = o_pre.sigmoid()
    c_new = (f_act * cx) + (i_act * g_act)
    h_new = o_act * c_new.tanh()
    return h_new, c_new
237*da0073e9SAndroid Build Coastguard Worker
238*da0073e9SAndroid Build Coastguard Worker
239*da0073e9SAndroid Build Coastguard Worker# Code reference: https://github.com/pytorch/translate/blob/master/pytorch_translate/rnn_cell.py#L27:44
def MiLSTMCell(x, hx, cx, w_ih, w_hh, alpha, beta_i, beta_h, bias):
    """Multiplicative-integration LSTM cell step; returns ``(hy, cy)``.

    Gate pre-activations combine the input and hidden projections as
    ``alpha * Wx * Uz + beta_i * Wx + beta_h * Uz + bias``. The rest of the
    cell is the standard LSTM update.
    """
    Wx = x.mm(w_ih.t())
    Uz = hx.mm(w_hh.t())
    # Section 2.1 in https://arxiv.org/pdf/1606.06630.pdf
    gates = alpha * Wx * Uz + beta_i * Wx + beta_h * Uz + bias
    # From here on this matches LSTMCell.
    i_pre, f_pre, g_pre, o_pre = gates.chunk(4, 1)
    i_gate = torch.sigmoid(i_pre)
    f_gate = torch.sigmoid(f_pre)
    g_gate = torch.tanh(g_pre)
    o_gate = torch.sigmoid(o_pre)
    c_out = (f_gate * cx) + (i_gate * g_gate)
    h_out = o_gate * torch.tanh(c_out)
    return h_out, c_out
254*da0073e9SAndroid Build Coastguard Worker
255*da0073e9SAndroid Build Coastguard Worker
256*da0073e9SAndroid Build Coastguard Worker
def get_lstm_inputs(device, training=False, seq_length=None):
    """Build random float32 inputs for an LSTM cell with 10 inputs, 20 hidden units.

    Returns ``(input, hx, cx) + params``:
      * ``input`` is ``(3, 10)``, or ``(seq_length, 3, 10)`` when
        ``seq_length`` is given;
      * ``hx`` and ``cx`` are ``(3, 20)``;
      * ``params`` are the parameters of a freshly initialised
        ``nn.LSTMCell(10, 20)`` moved to ``device``.
    Every returned tensor has ``requires_grad == training``.
    """
    batch, in_features, hidden_features = 3, 10, 20
    if seq_length is None:
        inp_shape = (batch, in_features)
    else:
        inp_shape = (seq_length, batch, in_features)

    opts = dict(dtype=torch.float, device=device, requires_grad=training)
    inp = torch.randn(*inp_shape, **opts)
    hx = torch.randn(batch, hidden_features, **opts)
    cx = torch.randn(batch, hidden_features, **opts)

    # The module only serves to allocate weights with the correct sizes.
    cell = nn.LSTMCell(in_features, hidden_features).to(device, torch.float)
    weights = tuple(cell.parameters())
    if not training:
        weights = tuple(w.requires_grad_(False) for w in weights)
    return (inp, hx, cx) + weights
268*da0073e9SAndroid Build Coastguard Worker
269*da0073e9SAndroid Build Coastguard Worker
def get_milstm_inputs(device, training=False):
    """Build random float32 inputs for ``MiLSTMCell`` (batch 3, 10 in, 20 hidden).

    Returns ``(x, hx, cx, ih, hh, alpha, ibeta, hbeta, bias)``:
      * ``x`` is ``(3, 10)``; ``hx`` and ``cx`` are ``(3, 20)``;
      * ``ih`` is ``(80, 10)`` and ``hh`` is ``(80, 20)``;
      * ``alpha``, ``ibeta``, ``hbeta`` and ``bias`` are ``(80,)``.
    Only the last six (the parameters) get ``requires_grad == training``.
    ``x``, ``hx`` and ``cx`` never require grad.
    """
    minibatch, input_size, hidden_size = 3, 10, 20
    gate_size = 4 * hidden_size

    def rand(*shape, requires_grad=False):
        return torch.randn(*shape, device=device, dtype=torch.float,
                           requires_grad=requires_grad)

    x = rand(minibatch, input_size)
    hx = rand(minibatch, hidden_size)
    cx = rand(minibatch, hidden_size)

    ih = rand(gate_size, input_size, requires_grad=training)
    hh = rand(gate_size, hidden_size, requires_grad=training)
    # Drawn in the order alpha, ibeta, hbeta, bias.
    alpha, ibeta, hbeta, bias = (rand(gate_size, requires_grad=training) for _ in range(4))
    return x, hx, cx, ih, hh, alpha, ibeta, hbeta, bias
285*da0073e9SAndroid Build Coastguard Worker
286*da0073e9SAndroid Build Coastguard Worker
def get_fn(file_name, script_path):
    """Execute the Python file at ``script_path`` and return its ``fn`` attribute.

    The file is loaded as a module named ``file_name`` via ``importlib``. The
    module is not inserted into ``sys.modules``. Raises ``AttributeError`` if
    the loaded module defines no ``fn``.
    """
    import importlib.util

    spec = importlib.util.spec_from_file_location(file_name, script_path)
    loaded = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(loaded)
    return loaded.fn
294*da0073e9SAndroid Build Coastguard Worker
def get_grad_executor(plan_state, diff_graph_idx=None, skip_check=False):
    """Return one gradient-executor state from an execution-plan debug state.

    When ``diff_graph_idx`` is None and ``skip_check`` is False, the plan's
    graph is validated first. ``prim::BailOut`` and ``prim::BailoutTemplate``
    nodes are ignored, and the remaining nodes must be one of:
      * exactly one node;
      * two nodes whose second is ``prim::TupleConstruct``;
      * ``prim::RequiresGradCheck`` followed by ``prim::If``.

    Args:
        plan_state: object exposing ``graph.nodes()`` and
            ``code.grad_executor_states()``.
        diff_graph_idx: index into the grad executor states (None -> 0).
        skip_check: when True, skip the graph validation.

    Raises:
        RuntimeError: the validation ran and the graph shape was rejected.
    """
    if diff_graph_idx is None:
        graph_nodes = list(plan_state.graph.nodes())
        if not skip_check:
            ignored = ("prim::BailOut", "prim::BailoutTemplate")
            kinds = [node.kind() for node in graph_nodes if node.kind() not in ignored]
            accepted = (
                len(kinds) == 1
                or (len(kinds) == 2 and kinds[1] == "prim::TupleConstruct")
                or kinds == ["prim::RequiresGradCheck", "prim::If"]
            )
            if not accepted:
                raise RuntimeError("Can't get a grad_executor for a non-differentiable graph")

    executor_states = list(plan_state.code.grad_executor_states())
    return executor_states[diff_graph_idx or 0]
309*da0073e9SAndroid Build Coastguard Worker
310*da0073e9SAndroid Build Coastguard Worker
def all_backward_graphs(script_module, diff_graph_idx=None):
    """Return copies of every backward-plan graph for ``script_module``.

    Takes the forward execution plan from the module's debug state and picks
    a gradient executor with ``get_grad_executor`` (which may raise
    ``RuntimeError``). It returns one graph copy per entry in that
    executor's ``execution_plans``.

    NOTE: the order follows ``execution_plans`` iteration order, which has
    been observed to be unstable; don't rely on it.
    """
    forward_plan = get_execution_plan(script_module.get_debug_state())
    grad_state = get_grad_executor(forward_plan, diff_graph_idx=diff_graph_idx)
    # Copy each graph: the debug-state structs do not own their graphs.
    return [plan.graph.copy() for plan in grad_state.execution_plans.values()]
318*da0073e9SAndroid Build Coastguard Worker
319*da0073e9SAndroid Build Coastguard Worker
def backward_graph(script_module, diff_graph_idx=None, skip_check=False):
    """Return a copy of the backward graph associated with ``script_module``.

    Selects a gradient executor from the module's forward execution plan
    (see ``get_grad_executor`` for ``diff_graph_idx``/``skip_check``). It then
    takes that executor's execution plan and returns a copy of its graph.
    """
    forward_plan = get_execution_plan(script_module.get_debug_state())
    grad_state = get_grad_executor(
        forward_plan, diff_graph_idx=diff_graph_idx, skip_check=skip_check
    )
    backward_plan = get_execution_plan(grad_state)
    # Running JIT passes requires owning the graph (via a shared_ptr); the
    # debug state struct does not own its graph, so hand back a copy.
    return backward_plan.graph.copy()
328*da0073e9SAndroid Build Coastguard Worker
329*da0073e9SAndroid Build Coastguard Worker
330*da0073e9SAndroid Build Coastguard Worker# helper function to get sum of List[Tensor]
331*da0073e9SAndroid Build Coastguard Workerdef _sum_of_list(tensorlist):
332*da0073e9SAndroid Build Coastguard Worker    s = 0
333*da0073e9SAndroid Build Coastguard Worker    for t in tensorlist:
334*da0073e9SAndroid Build Coastguard Worker        s += t.sum()
335*da0073e9SAndroid Build Coastguard Worker    return s
336*da0073e9SAndroid Build Coastguard Worker
337*da0073e9SAndroid Build Coastguard Worker
# Must be defined at module top level: pickle locates classes by their
# importable qualified name.
class FooToPickle(torch.nn.Module):
    """Minimal module whose only child is an empty ``torch.jit.ScriptModule``."""

    def __init__(self) -> None:
        super(FooToPickle, self).__init__()
        self.bar = torch.jit.ScriptModule()
343*da0073e9SAndroid Build Coastguard Worker
344*da0073e9SAndroid Build Coastguard Worker
class TestJitProfiler(JitTestCase):
    """
    Tests that need to change global JIT state, such as torch._C._set_graph_executor_optimize,
    and restore the original values afterwards (e.g. test_profiler).

    This addresses the flakiness reported in https://github.com/pytorch/pytorch/issues/91483:
    test_profiler was flaky and could fail midway without restoring
    torch._C._set_graph_executor_optimize to its original value. That broke every test
    that ran after it.

    A separate test class is used so that this setUp/tearDown does not need to run for
    every test in TestJit.
    """

    def setUp(self):
        super().setUp()
        # Remember the global graph-executor "optimize" flag so tearDown can restore it.
        self.graph_executor_optimize_opt = torch._C._get_graph_executor_optimize()

    def tearDown(self):
        # Restore the global flag even if the base-class tearDown raises. Otherwise the
        # modified value would leak into every test that runs afterwards.
        try:
            super().tearDown()
        finally:
            torch._C._set_graph_executor_optimize(
                self.graph_executor_optimize_opt
            )

    def test_profiler(self):
        """Profile a traced fn that calls a traced helper both directly and via fork."""
        torch._C._set_graph_executor_optimize(False)

        def other_fn(x):
            return x * 2

        x = torch.rand(3, 4)
        traced_other_fn = torch.jit.trace(other_fn, x)

        def fn(x):
            # One direct call and one forked call of the traced helper.
            y = traced_other_fn(x)
            fut = torch.jit._fork(traced_other_fn, x)
            y = torch.jit._wait(fut)
            return y

        traced_fn = torch.jit.trace(fn, x)
        with torch.autograd.profiler.profile() as prof:
            traced_fn(x)

        # expecting to see other_fn TS function call
        # with cpu time >= mul cpu time and
        # a forked other_fn

        # Map thread id -> elapsed time (us); each event name may appear once per thread.
        mul_events = {}
        other_fn_events = {}
        for e in prof.function_events:
            if e.name == "aten::mul":
                self.assertNotIn(e.thread, mul_events)
                mul_events[e.thread] = e.time_range.elapsed_us()
            elif e.name == "other_fn":
                self.assertNotIn(e.thread, other_fn_events)
                other_fn_events[e.thread] = e.time_range.elapsed_us()

        self.assertEqual(len(mul_events), 2)
        self.assertEqual(len(other_fn_events), 2)

        # On each thread, the enclosing other_fn call must last at least as long as its mul.
        for thread, mul_time in mul_events.items():
            self.assertIn(thread, other_fn_events)
            self.assertGreaterEqual(other_fn_events[thread], mul_time)
407*da0073e9SAndroid Build Coastguard Worker
408*da0073e9SAndroid Build Coastguard Worker
409*da0073e9SAndroid Build Coastguard Workerclass TestJit(JitTestCase):
410*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("Requires a lot of RAM")
411*da0073e9SAndroid Build Coastguard Worker    def test_big(self):
412*da0073e9SAndroid Build Coastguard Worker        m = torch.jit.ScriptModule()
413*da0073e9SAndroid Build Coastguard Worker        gig = int(1024 * 1024 * 1024 / 4)
414*da0073e9SAndroid Build Coastguard Worker        # a small tensor in the first 4GB
415*da0073e9SAndroid Build Coastguard Worker        m.v0 = nn.Parameter(torch.full((2,), 1, dtype=torch.float))
416*da0073e9SAndroid Build Coastguard Worker        # a large tensor in the first 4GB that ends outside of it
417*da0073e9SAndroid Build Coastguard Worker        m.v1 = nn.Parameter(torch.full((5, gig), 2, dtype=torch.float))
418*da0073e9SAndroid Build Coastguard Worker        # a small tensor in >4GB space
419*da0073e9SAndroid Build Coastguard Worker        m.v2 = nn.Parameter(torch.full((2,), 3, dtype=torch.float))
420*da0073e9SAndroid Build Coastguard Worker        # s large tensor in the > 4GB space
421*da0073e9SAndroid Build Coastguard Worker        m.v3 = nn.Parameter(torch.full((5, gig), 4, dtype=torch.float))
422*da0073e9SAndroid Build Coastguard Worker
423*da0073e9SAndroid Build Coastguard Worker        m2 = self.getExportImportCopy(m)
424*da0073e9SAndroid Build Coastguard Worker
425*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
426*da0073e9SAndroid Build Coastguard Worker
427*da0073e9SAndroid Build Coastguard Worker    def test_inferred_as_tensor(self):
428*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Inferred the value for argument 'dim' to be of type 'Tensor' "
429*da0073e9SAndroid Build Coastguard Worker                                                  "because it was not annotated with an explicit type"):
430*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
431*da0073e9SAndroid Build Coastguard Worker            def dot(points, query, dim):
432*da0073e9SAndroid Build Coastguard Worker                return (points * query).sum(dim)
433*da0073e9SAndroid Build Coastguard Worker
434*da0073e9SAndroid Build Coastguard Worker    def test_constants_pkl(self):
435*da0073e9SAndroid Build Coastguard Worker        # This test asserts that the serialization archive includes a `constants.pkl`
436*da0073e9SAndroid Build Coastguard Worker        # file. This file is used by `torch.load` to determine whether a zip file
437*da0073e9SAndroid Build Coastguard Worker        # is a normal eager-mode serialization zip or a jit serialization zip. If
438*da0073e9SAndroid Build Coastguard Worker        # you are deleting `constants.pkl`, make sure to update `torch.serialization.load`
439*da0073e9SAndroid Build Coastguard Worker        # so it is still able to figure out which is which.
440*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
441*da0073e9SAndroid Build Coastguard Worker        def fn(x):
442*da0073e9SAndroid Build Coastguard Worker            return x
443*da0073e9SAndroid Build Coastguard Worker
444*da0073e9SAndroid Build Coastguard Worker        buf = io.BytesIO()
445*da0073e9SAndroid Build Coastguard Worker        torch.jit.save(fn, buf)
446*da0073e9SAndroid Build Coastguard Worker        buf.seek(0)
447*da0073e9SAndroid Build Coastguard Worker
448*da0073e9SAndroid Build Coastguard Worker        files = zipfile.ZipFile(buf).filelist
449*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(any('archive/constants.pkl' == f.filename for f in files))
450*da0073e9SAndroid Build Coastguard Worker
451*da0073e9SAndroid Build Coastguard Worker    def test_script_fn_pkl(self):
452*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(pickle.PickleError, "ScriptFunction cannot be pickled"):
453*da0073e9SAndroid Build Coastguard Worker
454*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
455*da0073e9SAndroid Build Coastguard Worker            def fn(x: torch.Tensor) -> torch.Tensor:
456*da0073e9SAndroid Build Coastguard Worker                return x
457*da0073e9SAndroid Build Coastguard Worker
458*da0073e9SAndroid Build Coastguard Worker            pkl_fn = pickle.dumps(fn, protocol=0)
459*da0073e9SAndroid Build Coastguard Worker
460*da0073e9SAndroid Build Coastguard Worker    def test_restore_device(self):
461*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
462*da0073e9SAndroid Build Coastguard Worker            def __init__(self, cpu_device_str):
463*da0073e9SAndroid Build Coastguard Worker                super().__init__()
464*da0073e9SAndroid Build Coastguard Worker                self.p0 = nn.Parameter(torch.tensor([0.3], dtype=torch.float,
465*da0073e9SAndroid Build Coastguard Worker                                                    device=cpu_device_str))
466*da0073e9SAndroid Build Coastguard Worker                self.b0 = torch.tensor([0.9], dtype=torch.float,
467*da0073e9SAndroid Build Coastguard Worker                                       device=cpu_device_str)
468*da0073e9SAndroid Build Coastguard Worker
469*da0073e9SAndroid Build Coastguard Worker        # main purpose is checking map_location works
470*da0073e9SAndroid Build Coastguard Worker        m = M("cpu")
471*da0073e9SAndroid Build Coastguard Worker        m2 = self.getExportImportCopy(m)
472*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
473*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
474*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m2.p0.is_cuda)
475*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(m2.b0.is_cuda)
476*da0073e9SAndroid Build Coastguard Worker
477*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA")
478*da0073e9SAndroid Build Coastguard Worker    def test_restore_device_cuda(self):
479*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.jit.ScriptModule):
480*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
481*da0073e9SAndroid Build Coastguard Worker                super().__init__()
482*da0073e9SAndroid Build Coastguard Worker                self.b0 = nn.Buffer(torch.randn(1, 3))
483*da0073e9SAndroid Build Coastguard Worker                self.p0 = nn.Parameter(torch.randn(2, 3))
484*da0073e9SAndroid Build Coastguard Worker
485*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
486*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
487*da0073e9SAndroid Build Coastguard Worker                return x + self.b0 + self.p0
488*da0073e9SAndroid Build Coastguard Worker
489*da0073e9SAndroid Build Coastguard Worker        m = MyModule()
490*da0073e9SAndroid Build Coastguard Worker        m.cuda(torch.cuda.device_count() - 1)
491*da0073e9SAndroid Build Coastguard Worker        cuda_device_str = 'cuda:' + str(torch.cuda.device_count() - 1)
492*da0073e9SAndroid Build Coastguard Worker
493*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m.p0.is_cuda)
494*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m.b0.is_cuda)
495*da0073e9SAndroid Build Coastguard Worker
496*da0073e9SAndroid Build Coastguard Worker        # restore to the saved devices
497*da0073e9SAndroid Build Coastguard Worker        m2 = self.getExportImportCopy(m)
498*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
499*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
500*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(str(m2.p0.device), cuda_device_str)
501*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(str(m2.b0.device), cuda_device_str)
502*da0073e9SAndroid Build Coastguard Worker
503*da0073e9SAndroid Build Coastguard Worker        # restore all to cpu using string
504*da0073e9SAndroid Build Coastguard Worker        cpu_device_str = 'cpu'
505*da0073e9SAndroid Build Coastguard Worker        m3 = self.getExportImportCopy(m, map_location=cpu_device_str)
506*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(str(m3.p0.device), cpu_device_str)
507*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(str(m3.b0.device), cpu_device_str)
508*da0073e9SAndroid Build Coastguard Worker
509*da0073e9SAndroid Build Coastguard Worker        # restore all to first gpu using device
510*da0073e9SAndroid Build Coastguard Worker        m4 = self.getExportImportCopy(
511*da0073e9SAndroid Build Coastguard Worker            m3, map_location=torch.device('cuda:0'))
512*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(str(m4.p0.device), 'cuda:0')
513*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(str(m4.b0.device), 'cuda:0')
514*da0073e9SAndroid Build Coastguard Worker
515*da0073e9SAndroid Build Coastguard Worker        # compute and compare the results
516*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(2, 3).cuda(torch.cuda.device_count() - 1)
517*da0073e9SAndroid Build Coastguard Worker        origin_result = m(input)
518*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(origin_result, m2(input))
519*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(origin_result, m3(input.cpu()))
520*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(origin_result, m4(input.cuda(0)))
521*da0073e9SAndroid Build Coastguard Worker
522*da0073e9SAndroid Build Coastguard Worker    def test_trace_retains_train(self):
523*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
524*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
525*da0073e9SAndroid Build Coastguard Worker                return x
526*da0073e9SAndroid Build Coastguard Worker        m = M()
527*da0073e9SAndroid Build Coastguard Worker        m.eval()
528*da0073e9SAndroid Build Coastguard Worker        tm = torch.jit.trace(m, (torch.rand(3)))
529*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tm.training, m.training)
530*da0073e9SAndroid Build Coastguard Worker
531*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA")
532*da0073e9SAndroid Build Coastguard Worker    def test_restore_shared_storage_on_cuda(self):
533*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.jit.ScriptModule):
534*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
535*da0073e9SAndroid Build Coastguard Worker                super().__init__()
536*da0073e9SAndroid Build Coastguard Worker                whole_tensor = torch.randn(4, 5, dtype=torch.float, device='cpu')
537*da0073e9SAndroid Build Coastguard Worker                self.p0 = nn.Parameter(whole_tensor.narrow(0, 0, 1))
538*da0073e9SAndroid Build Coastguard Worker                self.b0 = nn.Buffer(whole_tensor.narrow(0, 3, 1))
539*da0073e9SAndroid Build Coastguard Worker
540*da0073e9SAndroid Build Coastguard Worker        m = Foo()
541*da0073e9SAndroid Build Coastguard Worker        m2 = self.getExportImportCopy(m, map_location=torch.device('cuda:0'))
542*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
543*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
544*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m2.p0.is_cuda)
545*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m2.b0.is_cuda)
546*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m2.p0.is_shared())
547*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m2.b0.is_shared())
548*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m2.b0.storage().data_ptr(), m2.p0.storage().data_ptr())
549*da0073e9SAndroid Build Coastguard Worker
550*da0073e9SAndroid Build Coastguard Worker    def test_add_relu_fusion(self):
551*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
552*da0073e9SAndroid Build Coastguard Worker            def __init__(self, relu_op):
553*da0073e9SAndroid Build Coastguard Worker                super().__init__()
554*da0073e9SAndroid Build Coastguard Worker                self.relu_op = relu_op
555*da0073e9SAndroid Build Coastguard Worker
556*da0073e9SAndroid Build Coastguard Worker            def forward(self, a, b, c):
557*da0073e9SAndroid Build Coastguard Worker                tmp = torch.add(a, b)
558*da0073e9SAndroid Build Coastguard Worker                x = self.relu_op(tmp)
559*da0073e9SAndroid Build Coastguard Worker                d = torch.add(a, c)
560*da0073e9SAndroid Build Coastguard Worker                return x + d
561*da0073e9SAndroid Build Coastguard Worker        a = torch.rand((7, 11))
562*da0073e9SAndroid Build Coastguard Worker        a = a * -10
563*da0073e9SAndroid Build Coastguard Worker        a = a + 5
564*da0073e9SAndroid Build Coastguard Worker        b = torch.rand((7, 11))
565*da0073e9SAndroid Build Coastguard Worker        c = torch.rand((7, 11))
566*da0073e9SAndroid Build Coastguard Worker        m = torch.jit.script(M(torch.relu))
567*da0073e9SAndroid Build Coastguard Worker        orig_res = m(a, b, c)
568*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_fuse_add_relu(m.graph)
569*da0073e9SAndroid Build Coastguard Worker        buffer = io.BytesIO()
570*da0073e9SAndroid Build Coastguard Worker        torch.jit.save(m, buffer)
571*da0073e9SAndroid Build Coastguard Worker        buffer.seek(0)
572*da0073e9SAndroid Build Coastguard Worker        m = torch.jit.load(buffer)
573*da0073e9SAndroid Build Coastguard Worker        new_res = m(a, b, c)
574*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("aten::relu(") \
575*da0073e9SAndroid Build Coastguard Worker            .check("aten::_add_relu(") \
576*da0073e9SAndroid Build Coastguard Worker            .run(m.graph)
577*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(orig_res, new_res)
578*da0073e9SAndroid Build Coastguard Worker
579*da0073e9SAndroid Build Coastguard Worker        # add, relu_
580*da0073e9SAndroid Build Coastguard Worker        a = torch.rand((7, 11))
581*da0073e9SAndroid Build Coastguard Worker        a = a * -10
582*da0073e9SAndroid Build Coastguard Worker        a = a + 5
583*da0073e9SAndroid Build Coastguard Worker        b = torch.rand((7, 11))
584*da0073e9SAndroid Build Coastguard Worker        c = torch.rand((7, 11))
585*da0073e9SAndroid Build Coastguard Worker        m = torch.jit.script(M(torch.relu_))
586*da0073e9SAndroid Build Coastguard Worker        orig_res = m(a, b, c)
587*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_fuse_add_relu(m.graph)
588*da0073e9SAndroid Build Coastguard Worker        buffer = io.BytesIO()
589*da0073e9SAndroid Build Coastguard Worker        torch.jit.save(m, buffer)
590*da0073e9SAndroid Build Coastguard Worker        buffer.seek(0)
591*da0073e9SAndroid Build Coastguard Worker        m = torch.jit.load(buffer)
592*da0073e9SAndroid Build Coastguard Worker        new_res = m(a, b, c)
593*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("aten::relu_(") \
594*da0073e9SAndroid Build Coastguard Worker            .check("aten::_add_relu(") \
595*da0073e9SAndroid Build Coastguard Worker            .run(m.graph)
596*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(orig_res, new_res)
597*da0073e9SAndroid Build Coastguard Worker
598*da0073e9SAndroid Build Coastguard Worker        class Madd_(torch.nn.Module):
599*da0073e9SAndroid Build Coastguard Worker            def __init__(self, relu_op):
600*da0073e9SAndroid Build Coastguard Worker                super().__init__()
601*da0073e9SAndroid Build Coastguard Worker                self.relu_op = relu_op
602*da0073e9SAndroid Build Coastguard Worker
603*da0073e9SAndroid Build Coastguard Worker            def forward(self, a, b):
604*da0073e9SAndroid Build Coastguard Worker                x = a.add_(b)
605*da0073e9SAndroid Build Coastguard Worker                x = self.relu_op(x)
606*da0073e9SAndroid Build Coastguard Worker                return x
607*da0073e9SAndroid Build Coastguard Worker
608*da0073e9SAndroid Build Coastguard Worker        # add_, relu_
609*da0073e9SAndroid Build Coastguard Worker        a = torch.rand((7, 11))
610*da0073e9SAndroid Build Coastguard Worker        a = a * -10
611*da0073e9SAndroid Build Coastguard Worker        a = a + 5
612*da0073e9SAndroid Build Coastguard Worker        b = torch.rand((7, 11))
613*da0073e9SAndroid Build Coastguard Worker        # Because in place add_ will overwrite a
614*da0073e9SAndroid Build Coastguard Worker        a_copy = a.clone()
615*da0073e9SAndroid Build Coastguard Worker        m = torch.jit.script(Madd_(torch.relu_))
616*da0073e9SAndroid Build Coastguard Worker        orig_res = m(a, b)
617*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_fuse_add_relu(m.graph)
618*da0073e9SAndroid Build Coastguard Worker        buffer = io.BytesIO()
619*da0073e9SAndroid Build Coastguard Worker        torch.jit.save(m, buffer)
620*da0073e9SAndroid Build Coastguard Worker        buffer.seek(0)
621*da0073e9SAndroid Build Coastguard Worker        m = torch.jit.load(buffer)
622*da0073e9SAndroid Build Coastguard Worker        new_res = m(a_copy, b)
623*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("aten::add_(") \
624*da0073e9SAndroid Build Coastguard Worker            .check_not("aten::relu_(") \
625*da0073e9SAndroid Build Coastguard Worker            .check("aten::_add_relu_(") \
626*da0073e9SAndroid Build Coastguard Worker            .run(m.graph)
627*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(orig_res, new_res)
628*da0073e9SAndroid Build Coastguard Worker        # Since _add_relu_ does inplace mutation ensure
629*da0073e9SAndroid Build Coastguard Worker        # a_copy is modified
630*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(orig_res, a_copy)
631*da0073e9SAndroid Build Coastguard Worker
632*da0073e9SAndroid Build Coastguard Worker        class Madd_out(torch.nn.Module):
633*da0073e9SAndroid Build Coastguard Worker            def __init__(self, relu_op):
634*da0073e9SAndroid Build Coastguard Worker                super().__init__()
635*da0073e9SAndroid Build Coastguard Worker                self.relu_op = relu_op
636*da0073e9SAndroid Build Coastguard Worker
637*da0073e9SAndroid Build Coastguard Worker            def forward(self, a, b):
638*da0073e9SAndroid Build Coastguard Worker                x = torch.add(a, b, out=a)
639*da0073e9SAndroid Build Coastguard Worker                x = self.relu_op(x)
640*da0073e9SAndroid Build Coastguard Worker                return x
641*da0073e9SAndroid Build Coastguard Worker        a = torch.rand((7, 11))
642*da0073e9SAndroid Build Coastguard Worker        a = a * -10
643*da0073e9SAndroid Build Coastguard Worker        a = a + 5
644*da0073e9SAndroid Build Coastguard Worker        b = torch.rand((7, 11))
645*da0073e9SAndroid Build Coastguard Worker
646*da0073e9SAndroid Build Coastguard Worker        # add_out, relu_
647*da0073e9SAndroid Build Coastguard Worker        a = torch.rand((7, 11))
648*da0073e9SAndroid Build Coastguard Worker        a = a * -10
649*da0073e9SAndroid Build Coastguard Worker        a = a + 5
650*da0073e9SAndroid Build Coastguard Worker        b = torch.rand((7, 11))
651*da0073e9SAndroid Build Coastguard Worker        # Because in place add_ will overwrite a
652*da0073e9SAndroid Build Coastguard Worker        a_copy = a.clone()
653*da0073e9SAndroid Build Coastguard Worker        m = torch.jit.script(Madd_out(torch.relu_))
654*da0073e9SAndroid Build Coastguard Worker        orig_res = m(a, b)
655*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_fuse_add_relu(m.graph)
656*da0073e9SAndroid Build Coastguard Worker        buffer = io.BytesIO()
657*da0073e9SAndroid Build Coastguard Worker        torch.jit.save(m, buffer)
658*da0073e9SAndroid Build Coastguard Worker        buffer.seek(0)
659*da0073e9SAndroid Build Coastguard Worker        m = torch.jit.load(buffer)
660*da0073e9SAndroid Build Coastguard Worker        new_res = m(a_copy, b)
661*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("aten::add(") \
662*da0073e9SAndroid Build Coastguard Worker            .check_not("aten::relu_(") \
663*da0073e9SAndroid Build Coastguard Worker            .check("aten::_add_relu(") \
664*da0073e9SAndroid Build Coastguard Worker            .run(m.graph)
665*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(orig_res, new_res)
666*da0073e9SAndroid Build Coastguard Worker        # Since _add_relu_ with out=a does inplace mutation ensure
667*da0073e9SAndroid Build Coastguard Worker        # a_copy is modified
668*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(orig_res, a_copy)
669*da0073e9SAndroid Build Coastguard Worker
670*da0073e9SAndroid Build Coastguard Worker    def test_repeat_interleave_script(self):
671*da0073e9SAndroid Build Coastguard Worker        def fn(input: torch.Tensor, repeats: torch.Tensor) -> torch.Tensor:
672*da0073e9SAndroid Build Coastguard Worker            output = input.repeat_interleave(repeats)
673*da0073e9SAndroid Build Coastguard Worker            return output
674*da0073e9SAndroid Build Coastguard Worker        fn_scripted = torch.jit.script(fn)
675*da0073e9SAndroid Build Coastguard Worker
676*da0073e9SAndroid Build Coastguard Worker        input = torch.tensor([5, 7], dtype=torch.int64)
677*da0073e9SAndroid Build Coastguard Worker        repeats = torch.tensor([3, 6], dtype=torch.int64)
678*da0073e9SAndroid Build Coastguard Worker
679*da0073e9SAndroid Build Coastguard Worker        output = fn(input, repeats)
680*da0073e9SAndroid Build Coastguard Worker        output_scripted = fn_scripted(input, repeats)
681*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output_scripted, output)
682*da0073e9SAndroid Build Coastguard Worker
683*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Simple executor doesn't have shape information")
684*da0073e9SAndroid Build Coastguard Worker    def test_peephole_optimize_shape_ops(self):
685*da0073e9SAndroid Build Coastguard Worker        def test_input(func, input, result):
686*da0073e9SAndroid Build Coastguard Worker            # if result == 2 we will trigger a bailout and
687*da0073e9SAndroid Build Coastguard Worker            # the unprofiled graph should return the correct result
688*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(func(input, profile_and_replay=True), result)
689*da0073e9SAndroid Build Coastguard Worker            gre = func.graph_for(input)
690*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_not("prim::If").run(gre)
691*da0073e9SAndroid Build Coastguard Worker
692*da0073e9SAndroid Build Coastguard Worker        def test_dim():
693*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
694*da0073e9SAndroid Build Coastguard Worker            def func(x):
695*da0073e9SAndroid Build Coastguard Worker                if x.dim() == 1:
696*da0073e9SAndroid Build Coastguard Worker                    return 1
697*da0073e9SAndroid Build Coastguard Worker                else:
698*da0073e9SAndroid Build Coastguard Worker                    return 2
699*da0073e9SAndroid Build Coastguard Worker
700*da0073e9SAndroid Build Coastguard Worker            test_input(func, torch.tensor([0.5]), 1)
701*da0073e9SAndroid Build Coastguard Worker            test_input(func, torch.tensor([[0.5]]), 2)
702*da0073e9SAndroid Build Coastguard Worker        test_dim()
703*da0073e9SAndroid Build Coastguard Worker
704*da0073e9SAndroid Build Coastguard Worker        def test_size_index():
705*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
706*da0073e9SAndroid Build Coastguard Worker            def func(x):
707*da0073e9SAndroid Build Coastguard Worker                if x.size(0) == 1:
708*da0073e9SAndroid Build Coastguard Worker                    return 1
709*da0073e9SAndroid Build Coastguard Worker                else:
710*da0073e9SAndroid Build Coastguard Worker                    return 2
711*da0073e9SAndroid Build Coastguard Worker
712*da0073e9SAndroid Build Coastguard Worker            test_input(func, torch.rand([1, 2]), 1)
713*da0073e9SAndroid Build Coastguard Worker            test_input(func, torch.rand([1, 3]), 1)
714*da0073e9SAndroid Build Coastguard Worker
715*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
716*da0073e9SAndroid Build Coastguard Worker            def neg_index(x):
717*da0073e9SAndroid Build Coastguard Worker                if x.size(-2) == 1:
718*da0073e9SAndroid Build Coastguard Worker                    return 1
719*da0073e9SAndroid Build Coastguard Worker                else:
720*da0073e9SAndroid Build Coastguard Worker                    return 2
721*da0073e9SAndroid Build Coastguard Worker
722*da0073e9SAndroid Build Coastguard Worker            test_input(neg_index, torch.rand([1, 2]), 1)
723*da0073e9SAndroid Build Coastguard Worker            test_input(neg_index, torch.rand([1, 3]), 1)
724*da0073e9SAndroid Build Coastguard Worker
725*da0073e9SAndroid Build Coastguard Worker        if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
726*da0073e9SAndroid Build Coastguard Worker            test_size_index()
727*da0073e9SAndroid Build Coastguard Worker
728*da0073e9SAndroid Build Coastguard Worker        def test_dtype():
729*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
730*da0073e9SAndroid Build Coastguard Worker            def func(x):
731*da0073e9SAndroid Build Coastguard Worker                if x.dtype == torch.float32:
732*da0073e9SAndroid Build Coastguard Worker                    return 1
733*da0073e9SAndroid Build Coastguard Worker                else:
734*da0073e9SAndroid Build Coastguard Worker                    return 2
735*da0073e9SAndroid Build Coastguard Worker
736*da0073e9SAndroid Build Coastguard Worker            test_input(func, torch.tensor(0.5, dtype=torch.float32), 1)
737*da0073e9SAndroid Build Coastguard Worker            test_input(func, torch.tensor(0.5, dtype=torch.int64), 2)
738*da0073e9SAndroid Build Coastguard Worker        test_dtype()
739*da0073e9SAndroid Build Coastguard Worker
740*da0073e9SAndroid Build Coastguard Worker        def test_is_floating_poiint():
741*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
742*da0073e9SAndroid Build Coastguard Worker            def func(x):
743*da0073e9SAndroid Build Coastguard Worker                if x.is_floating_point():
744*da0073e9SAndroid Build Coastguard Worker                    return 1
745*da0073e9SAndroid Build Coastguard Worker                else:
746*da0073e9SAndroid Build Coastguard Worker                    return 2
747*da0073e9SAndroid Build Coastguard Worker
748*da0073e9SAndroid Build Coastguard Worker            test_input(func, torch.tensor(0.5, dtype=torch.float32), 1)
749*da0073e9SAndroid Build Coastguard Worker            test_input(func, torch.tensor(0.5, dtype=torch.int64), 2)
750*da0073e9SAndroid Build Coastguard Worker        test_is_floating_poiint()
751*da0073e9SAndroid Build Coastguard Worker
752*da0073e9SAndroid Build Coastguard Worker        def test_device():
753*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
754*da0073e9SAndroid Build Coastguard Worker            def func_1(x):
755*da0073e9SAndroid Build Coastguard Worker                if x.device == torch.device('cuda:0'):
756*da0073e9SAndroid Build Coastguard Worker                    a = 0
757*da0073e9SAndroid Build Coastguard Worker                else:
758*da0073e9SAndroid Build Coastguard Worker                    a = 1
759*da0073e9SAndroid Build Coastguard Worker                return a
760*da0073e9SAndroid Build Coastguard Worker
761*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
762*da0073e9SAndroid Build Coastguard Worker            def func_2(x):
763*da0073e9SAndroid Build Coastguard Worker                if x.is_cuda:
764*da0073e9SAndroid Build Coastguard Worker                    a = 0
765*da0073e9SAndroid Build Coastguard Worker                else:
766*da0073e9SAndroid Build Coastguard Worker                    a = 1
767*da0073e9SAndroid Build Coastguard Worker                return a
768*da0073e9SAndroid Build Coastguard Worker
769*da0073e9SAndroid Build Coastguard Worker            test_input(func_1, torch.tensor(0.5), 1)
770*da0073e9SAndroid Build Coastguard Worker            test_input(func_2, torch.tensor(0.5), 1)
771*da0073e9SAndroid Build Coastguard Worker
772*da0073e9SAndroid Build Coastguard Worker            if RUN_CUDA:
773*da0073e9SAndroid Build Coastguard Worker                test_input(func_1, torch.tensor(0.5, device="cuda:0"), 0)
774*da0073e9SAndroid Build Coastguard Worker                test_input(func_2, torch.tensor(0.5, device="cuda:0"), 0)
775*da0073e9SAndroid Build Coastguard Worker
776*da0073e9SAndroid Build Coastguard Worker        test_device()
777*da0073e9SAndroid Build Coastguard Worker
778*da0073e9SAndroid Build Coastguard Worker    def test_attrs(self):
779*da0073e9SAndroid Build Coastguard Worker        def foo(x):
780*da0073e9SAndroid Build Coastguard Worker            return (
781*da0073e9SAndroid Build Coastguard Worker                # x.dtype, TODO: dtype long -> instance conversion
782*da0073e9SAndroid Build Coastguard Worker                x.device,
783*da0073e9SAndroid Build Coastguard Worker                x.shape,
784*da0073e9SAndroid Build Coastguard Worker                x.is_cuda,
785*da0073e9SAndroid Build Coastguard Worker                x.is_mkldnn,
786*da0073e9SAndroid Build Coastguard Worker                x.is_quantized,
787*da0073e9SAndroid Build Coastguard Worker                x.requires_grad,
788*da0073e9SAndroid Build Coastguard Worker                x.T,
789*da0073e9SAndroid Build Coastguard Worker                x.mT,
790*da0073e9SAndroid Build Coastguard Worker                x.H,
791*da0073e9SAndroid Build Coastguard Worker                x.mH
792*da0073e9SAndroid Build Coastguard Worker                # x.layout TODO: layout long -> instance conversion
793*da0073e9SAndroid Build Coastguard Worker            )
794*da0073e9SAndroid Build Coastguard Worker
795*da0073e9SAndroid Build Coastguard Worker        scripted = torch.jit.script(foo)
796*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
797*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted(x), foo(x))
798*da0073e9SAndroid Build Coastguard Worker
799*da0073e9SAndroid Build Coastguard Worker    def test_layout(self):
800*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
801*da0073e9SAndroid Build Coastguard Worker        def check(x, y):
802*da0073e9SAndroid Build Coastguard Worker            return x.layout == y.layout
803*da0073e9SAndroid Build Coastguard Worker
804*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
805*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(3, 4)
806*da0073e9SAndroid Build Coastguard Worker
807*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(check(x, y))
808*da0073e9SAndroid Build Coastguard Worker
809*da0073e9SAndroid Build Coastguard Worker    def test_matrix_transpose(self):
810*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
811*da0073e9SAndroid Build Coastguard Worker        def check(x):
812*da0073e9SAndroid Build Coastguard Worker            return torch.equal(x.mT, x.transpose(-2, -1))
813*da0073e9SAndroid Build Coastguard Worker
814*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
815*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(check(x))
816*da0073e9SAndroid Build Coastguard Worker
817*da0073e9SAndroid Build Coastguard Worker    def test_transpose(self):
818*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
819*da0073e9SAndroid Build Coastguard Worker        def check(x):
820*da0073e9SAndroid Build Coastguard Worker            return torch.equal(x.T, x.t())
821*da0073e9SAndroid Build Coastguard Worker
822*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
823*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(check(x))
824*da0073e9SAndroid Build Coastguard Worker
825*da0073e9SAndroid Build Coastguard Worker    def test_matrix_conj_transpose(self):
826*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
827*da0073e9SAndroid Build Coastguard Worker        def check(x):
828*da0073e9SAndroid Build Coastguard Worker            return torch.equal(x.mH, x.transpose(-2, -1).conj())
829*da0073e9SAndroid Build Coastguard Worker
830*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
831*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(check(x))
832*da0073e9SAndroid Build Coastguard Worker
833*da0073e9SAndroid Build Coastguard Worker        x = make_tensor((3, 4), device="cpu", dtype=torch.complex64)
834*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(check(x))
835*da0073e9SAndroid Build Coastguard Worker
836*da0073e9SAndroid Build Coastguard Worker    def test_conj_transpose(self):
837*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
838*da0073e9SAndroid Build Coastguard Worker        def check(x):
839*da0073e9SAndroid Build Coastguard Worker            return torch.equal(x.H, x.t().conj())
840*da0073e9SAndroid Build Coastguard Worker
841*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
842*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(check(x))
843*da0073e9SAndroid Build Coastguard Worker
844*da0073e9SAndroid Build Coastguard Worker        x = make_tensor((3, 4), device="cpu", dtype=torch.complex64)
845*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(check(x))
846*da0073e9SAndroid Build Coastguard Worker
847*da0073e9SAndroid Build Coastguard Worker    def test_T_mT_H_mH(self):
848*da0073e9SAndroid Build Coastguard Worker        def T(x):
849*da0073e9SAndroid Build Coastguard Worker            return x.mT
850*da0073e9SAndroid Build Coastguard Worker
851*da0073e9SAndroid Build Coastguard Worker        def mT(x):
852*da0073e9SAndroid Build Coastguard Worker            return x.mT
853*da0073e9SAndroid Build Coastguard Worker
854*da0073e9SAndroid Build Coastguard Worker        def H(x):
855*da0073e9SAndroid Build Coastguard Worker            return x.H
856*da0073e9SAndroid Build Coastguard Worker
857*da0073e9SAndroid Build Coastguard Worker        def mH(x):
858*da0073e9SAndroid Build Coastguard Worker            return x.mH
859*da0073e9SAndroid Build Coastguard Worker
860*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
861*da0073e9SAndroid Build Coastguard Worker        y = make_tensor((3, 4), device="cpu", dtype=torch.complex64)
862*da0073e9SAndroid Build Coastguard Worker
863*da0073e9SAndroid Build Coastguard Worker        self.checkScript(T, (x, ))
864*da0073e9SAndroid Build Coastguard Worker        self.checkScript(mT, (x, ))
865*da0073e9SAndroid Build Coastguard Worker        self.checkScript(H, (x, ))
866*da0073e9SAndroid Build Coastguard Worker        self.checkScript(mH, (x, ))
867*da0073e9SAndroid Build Coastguard Worker        self.checkScript(T, (y, ))
868*da0073e9SAndroid Build Coastguard Worker        self.checkScript(mT, (y, ))
869*da0073e9SAndroid Build Coastguard Worker        self.checkScript(H, (y, ))
870*da0073e9SAndroid Build Coastguard Worker        self.checkScript(mH, (y, ))
871*da0073e9SAndroid Build Coastguard Worker
872*da0073e9SAndroid Build Coastguard Worker    def test_nn_conv(self):
873*da0073e9SAndroid Build Coastguard Worker        class Mod(nn.Module):
874*da0073e9SAndroid Build Coastguard Worker            def __init__(self, conv):
875*da0073e9SAndroid Build Coastguard Worker                super().__init__()
876*da0073e9SAndroid Build Coastguard Worker                self.conv = conv
877*da0073e9SAndroid Build Coastguard Worker
878*da0073e9SAndroid Build Coastguard Worker            def forward(self, input):
879*da0073e9SAndroid Build Coastguard Worker                return self.conv(input)
880*da0073e9SAndroid Build Coastguard Worker
881*da0073e9SAndroid Build Coastguard Worker        inputs = [
882*da0073e9SAndroid Build Coastguard Worker            # Conv
883*da0073e9SAndroid Build Coastguard Worker            (Mod(nn.Conv1d(16, 33, 3, stride=2)), torch.randn(20, 16, 5)),
884*da0073e9SAndroid Build Coastguard Worker            (Mod(nn.Conv2d(16, 33, 3, stride=2)), torch.randn(20, 16, 5, 10)),
885*da0073e9SAndroid Build Coastguard Worker            (Mod(nn.Conv3d(16, 33, 3, stride=2)), torch.randn(20, 16, 3, 5, 4)),
886*da0073e9SAndroid Build Coastguard Worker            # ConvTransposed
887*da0073e9SAndroid Build Coastguard Worker            (Mod(nn.ConvTranspose1d(16, 33, 3, stride=2)), torch.randn(20, 16, 5)),
888*da0073e9SAndroid Build Coastguard Worker            (Mod(nn.ConvTranspose2d(16, 33, 3, stride=2)), torch.randn(20, 16, 5, 10)),
889*da0073e9SAndroid Build Coastguard Worker            (Mod(nn.ConvTranspose3d(16, 33, 3, stride=2)), torch.randn(20, 16, 3, 5, 4)),
890*da0073e9SAndroid Build Coastguard Worker        ]
891*da0073e9SAndroid Build Coastguard Worker
892*da0073e9SAndroid Build Coastguard Worker        for m, inp in inputs:
893*da0073e9SAndroid Build Coastguard Worker            self.checkModule(m, (inp,))
894*da0073e9SAndroid Build Coastguard Worker
895*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, 'Not implemented for Simple or Legacy')
896*da0073e9SAndroid Build Coastguard Worker    def test_debug_flush_compilation_cache(self):
897*da0073e9SAndroid Build Coastguard Worker        def foo(x):
898*da0073e9SAndroid Build Coastguard Worker            return x + 2
899*da0073e9SAndroid Build Coastguard Worker
900*da0073e9SAndroid Build Coastguard Worker        class Mod(nn.Module):
901*da0073e9SAndroid Build Coastguard Worker            def forward(self, t):
902*da0073e9SAndroid Build Coastguard Worker                return t + 2
903*da0073e9SAndroid Build Coastguard Worker
904*da0073e9SAndroid Build Coastguard Worker        m = torch.jit.script(Mod())
905*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(1, 10)
906*da0073e9SAndroid Build Coastguard Worker
907*da0073e9SAndroid Build Coastguard Worker        with enable_profiling_mode_for_profiling_tests():
908*da0073e9SAndroid Build Coastguard Worker            jitted = self.checkScript(foo, (x,))
909*da0073e9SAndroid Build Coastguard Worker            # shouldn't throw
910*da0073e9SAndroid Build Coastguard Worker            states = jitted.get_debug_state()
911*da0073e9SAndroid Build Coastguard Worker
912*da0073e9SAndroid Build Coastguard Worker            # after flushing there shouldn't be
913*da0073e9SAndroid Build Coastguard Worker            # no opt plan
914*da0073e9SAndroid Build Coastguard Worker            jitted._debug_flush_compilation_cache()
915*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "INTERNAL ASSERT FAILED"):
916*da0073e9SAndroid Build Coastguard Worker                states = jitted.get_debug_state()
917*da0073e9SAndroid Build Coastguard Worker
918*da0073e9SAndroid Build Coastguard Worker            NUM_RUNS = 1
919*da0073e9SAndroid Build Coastguard Worker            with num_profiled_runs(NUM_RUNS):
920*da0073e9SAndroid Build Coastguard Worker                m(x)
921*da0073e9SAndroid Build Coastguard Worker                m(x)
922*da0073e9SAndroid Build Coastguard Worker                fwd = m._c._get_method("forward")
923*da0073e9SAndroid Build Coastguard Worker                states = m.get_debug_state()
924*da0073e9SAndroid Build Coastguard Worker
925*da0073e9SAndroid Build Coastguard Worker                # after flushing there shouldn't be
926*da0073e9SAndroid Build Coastguard Worker                # no opt plan
927*da0073e9SAndroid Build Coastguard Worker                fwd._debug_flush_compilation_cache()
928*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, "INTERNAL ASSERT FAILED"):
929*da0073e9SAndroid Build Coastguard Worker                    states = m.get_debug_state()
930*da0073e9SAndroid Build Coastguard Worker
931*da0073e9SAndroid Build Coastguard Worker    def test_numel(self):
932*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
933*da0073e9SAndroid Build Coastguard Worker        def get_numel_script(x):
934*da0073e9SAndroid Build Coastguard Worker            return x.numel()
935*da0073e9SAndroid Build Coastguard Worker
936*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
937*da0073e9SAndroid Build Coastguard Worker        numel = get_numel_script(x)
938*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(numel, x.numel())
939*da0073e9SAndroid Build Coastguard Worker
940*da0073e9SAndroid Build Coastguard Worker    def test_element_size(self):
941*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
942*da0073e9SAndroid Build Coastguard Worker        def get_element_size_script(x):
943*da0073e9SAndroid Build Coastguard Worker            return x.element_size()
944*da0073e9SAndroid Build Coastguard Worker
945*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
946*da0073e9SAndroid Build Coastguard Worker        element_size = get_element_size_script(x)
947*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(element_size, x.element_size())
948*da0073e9SAndroid Build Coastguard Worker
949*da0073e9SAndroid Build Coastguard Worker    def test_Sequential(self):
950*da0073e9SAndroid Build Coastguard Worker        class Seq(nn.Module):
951*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
952*da0073e9SAndroid Build Coastguard Worker                super().__init__()
953*da0073e9SAndroid Build Coastguard Worker                self.seq = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 30))
954*da0073e9SAndroid Build Coastguard Worker
955*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
956*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
957*da0073e9SAndroid Build Coastguard Worker                for l in self.seq:
958*da0073e9SAndroid Build Coastguard Worker                    x = l(x)
959*da0073e9SAndroid Build Coastguard Worker                return x
960*da0073e9SAndroid Build Coastguard Worker
961*da0073e9SAndroid Build Coastguard Worker        m = torch.jit.script(Seq())
962*da0073e9SAndroid Build Coastguard Worker        assert m.graph  # ensure jit was able to compile
963*da0073e9SAndroid Build Coastguard Worker
964*da0073e9SAndroid Build Coastguard Worker    def test_ModuleList(self):
965*da0073e9SAndroid Build Coastguard Worker        class Mod(nn.Module):
966*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
967*da0073e9SAndroid Build Coastguard Worker                super().__init__()
968*da0073e9SAndroid Build Coastguard Worker                self.model = nn.ModuleList([nn.Linear(10, 10) for _ in range(10)])
969*da0073e9SAndroid Build Coastguard Worker                self.model += (nn.Linear(10, 20),)
970*da0073e9SAndroid Build Coastguard Worker                self.model.append(nn.Linear(20, 30))
971*da0073e9SAndroid Build Coastguard Worker                self.model.extend([nn.Linear(30, 40), nn.Linear(40, 50)])
972*da0073e9SAndroid Build Coastguard Worker
973*da0073e9SAndroid Build Coastguard Worker            def forward(self, v):
974*da0073e9SAndroid Build Coastguard Worker                for m in self.model:
975*da0073e9SAndroid Build Coastguard Worker                    v = m(v)
976*da0073e9SAndroid Build Coastguard Worker                return v
977*da0073e9SAndroid Build Coastguard Worker
978*da0073e9SAndroid Build Coastguard Worker        m = torch.jit.script(Mod())
979*da0073e9SAndroid Build Coastguard Worker        assert m.graph  # ensure jit was able to compile
980*da0073e9SAndroid Build Coastguard Worker
981*da0073e9SAndroid Build Coastguard Worker    def test_disabled(self):
982*da0073e9SAndroid Build Coastguard Worker        torch.jit._state.disable()
983*da0073e9SAndroid Build Coastguard Worker        try:
984*da0073e9SAndroid Build Coastguard Worker            def f(x, y):
985*da0073e9SAndroid Build Coastguard Worker                return x + y
986*da0073e9SAndroid Build Coastguard Worker
987*da0073e9SAndroid Build Coastguard Worker            self.assertIs(torch.jit.trace(f, (torch.randn(2, 2), torch.randn(2, 2))), f)
988*da0073e9SAndroid Build Coastguard Worker            self.assertIs(torch.jit.script(f), f)
989*da0073e9SAndroid Build Coastguard Worker
990*da0073e9SAndroid Build Coastguard Worker            class MyModule(torch.jit.ScriptModule):
991*da0073e9SAndroid Build Coastguard Worker                @torch.jit.script_method
992*da0073e9SAndroid Build Coastguard Worker                def method(self, x):
993*da0073e9SAndroid Build Coastguard Worker                    return x
994*da0073e9SAndroid Build Coastguard Worker
995*da0073e9SAndroid Build Coastguard Worker            # XXX: Unfortunately ScriptModule won't simply become Module now,
996*da0073e9SAndroid Build Coastguard Worker            # because that requires disabling the JIT at startup time, which
997*da0073e9SAndroid Build Coastguard Worker            # we can't do in here.
998*da0073e9SAndroid Build Coastguard Worker            # We need to or those two conditions to make it work with all versions of Python
999*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(inspect.ismethod(MyModule.method) or inspect.isfunction(MyModule.method))
1000*da0073e9SAndroid Build Coastguard Worker        finally:
1001*da0073e9SAndroid Build Coastguard Worker            torch.jit._state.enable()
1002*da0073e9SAndroid Build Coastguard Worker
1003*da0073e9SAndroid Build Coastguard Worker    def test_train_eval(self):
1004*da0073e9SAndroid Build Coastguard Worker        class Sub(nn.Module):
1005*da0073e9SAndroid Build Coastguard Worker            def forward(self, input):
1006*da0073e9SAndroid Build Coastguard Worker                if self.training:
1007*da0073e9SAndroid Build Coastguard Worker                    return input
1008*da0073e9SAndroid Build Coastguard Worker                else:
1009*da0073e9SAndroid Build Coastguard Worker                    return -input
1010*da0073e9SAndroid Build Coastguard Worker
1011*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.jit.ScriptModule):
1012*da0073e9SAndroid Build Coastguard Worker            def __init__(self, module):
1013*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1014*da0073e9SAndroid Build Coastguard Worker                self.module = module
1015*da0073e9SAndroid Build Coastguard Worker
1016*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
1017*da0073e9SAndroid Build Coastguard Worker            def forward(self, input):
1018*da0073e9SAndroid Build Coastguard Worker                return self.module(input) + 1
1019*da0073e9SAndroid Build Coastguard Worker
1020*da0073e9SAndroid Build Coastguard Worker        m = MyModule(Sub())
1021*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(3, 4)
1022*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(input + 1, m(input))
1023*da0073e9SAndroid Build Coastguard Worker        m.eval()
1024*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(-input + 1, m(input))
1025*da0073e9SAndroid Build Coastguard Worker
1026*da0073e9SAndroid Build Coastguard Worker        # test batchnorm and dropout train/eval
1027*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(6, 10)
1028*da0073e9SAndroid Build Coastguard Worker        batchnorm = nn.BatchNorm1d(10)
1029*da0073e9SAndroid Build Coastguard Worker        dropout = nn.Dropout(p=0.2)
1030*da0073e9SAndroid Build Coastguard Worker
1031*da0073e9SAndroid Build Coastguard Worker        m_batchnorm = MyModule(batchnorm)
1032*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(batchnorm(input) + 1, m_batchnorm(input))
1033*da0073e9SAndroid Build Coastguard Worker        batchnorm.eval()
1034*da0073e9SAndroid Build Coastguard Worker        m_batchnorm.eval()
1035*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(batchnorm(input) + 1, m_batchnorm(input))
1036*da0073e9SAndroid Build Coastguard Worker
1037*da0073e9SAndroid Build Coastguard Worker        m_dropout = MyModule(dropout)
1038*da0073e9SAndroid Build Coastguard Worker        dropout.eval()
1039*da0073e9SAndroid Build Coastguard Worker        m_dropout.eval()
1040*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dropout(input) + 1, m_dropout(input))
1041*da0073e9SAndroid Build Coastguard Worker
1042*da0073e9SAndroid Build Coastguard Worker    def test_nn_lp_pool2d(self):
1043*da0073e9SAndroid Build Coastguard Worker        class Mod(torch.nn.Module):
1044*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1045*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1046*da0073e9SAndroid Build Coastguard Worker                self.l = torch.nn.LPPool2d(2, 3)
1047*da0073e9SAndroid Build Coastguard Worker                self.n = torch.nn.LPPool2d(2, (7, 1))
1048*da0073e9SAndroid Build Coastguard Worker
1049*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1050*da0073e9SAndroid Build Coastguard Worker                return (self.l(x),
1051*da0073e9SAndroid Build Coastguard Worker                        self.n(x),
1052*da0073e9SAndroid Build Coastguard Worker                        torch.nn.functional.lp_pool2d(x, float(2), 3),
1053*da0073e9SAndroid Build Coastguard Worker                        torch.nn.functional.lp_pool2d(x, 2, 3),
1054*da0073e9SAndroid Build Coastguard Worker                        torch.nn.functional.lp_pool2d(x, float(2), (7, 1)))
1055*da0073e9SAndroid Build Coastguard Worker
1056*da0073e9SAndroid Build Coastguard Worker        self.checkModule(Mod(), (torch.rand(1, 3, 7, 7),))
1057*da0073e9SAndroid Build Coastguard Worker
1058*da0073e9SAndroid Build Coastguard Worker    def test_nn_lp_pool1d(self):
1059*da0073e9SAndroid Build Coastguard Worker        class Mod(torch.nn.Module):
1060*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1061*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1062*da0073e9SAndroid Build Coastguard Worker                self.l = torch.nn.LPPool1d(2, 3)
1063*da0073e9SAndroid Build Coastguard Worker                self.n = torch.nn.LPPool1d(2, 7)
1064*da0073e9SAndroid Build Coastguard Worker
1065*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1066*da0073e9SAndroid Build Coastguard Worker                return (self.l(x),
1067*da0073e9SAndroid Build Coastguard Worker                        self.n(x),
1068*da0073e9SAndroid Build Coastguard Worker                        torch.nn.functional.lp_pool1d(x, float(2), 3),
1069*da0073e9SAndroid Build Coastguard Worker                        torch.nn.functional.lp_pool1d(x, 2, 3),
1070*da0073e9SAndroid Build Coastguard Worker                        torch.nn.functional.lp_pool1d(x, float(2), 7))
1071*da0073e9SAndroid Build Coastguard Worker
1072*da0073e9SAndroid Build Coastguard Worker        self.checkModule(Mod(), (torch.rand(1, 3, 7),))
1073*da0073e9SAndroid Build Coastguard Worker
1074*da0073e9SAndroid Build Coastguard Worker    def test_nn_padding_functional(self):
1075*da0073e9SAndroid Build Coastguard Worker        class Mod(nn.Module):
1076*da0073e9SAndroid Build Coastguard Worker            def __init__(self, *pad):
1077*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1078*da0073e9SAndroid Build Coastguard Worker                self.pad = pad
1079*da0073e9SAndroid Build Coastguard Worker
1080*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1081*da0073e9SAndroid Build Coastguard Worker                return F.pad(x, self.pad, mode='constant', value=3.5)
1082*da0073e9SAndroid Build Coastguard Worker
1083*da0073e9SAndroid Build Coastguard Worker        inputs = [
1084*da0073e9SAndroid Build Coastguard Worker            (Mod(1, 2), torch.randn(1, 3, 4)),  # 1D
1085*da0073e9SAndroid Build Coastguard Worker            (Mod(1, 2, 3, 4), torch.randn(1, 3, 4)),  # 2D
1086*da0073e9SAndroid Build Coastguard Worker            (Mod(1, 2, 3, 4, 5, 6), torch.randn(1, 3, 4)),  # 3D
1087*da0073e9SAndroid Build Coastguard Worker        ]
1088*da0073e9SAndroid Build Coastguard Worker
1089*da0073e9SAndroid Build Coastguard Worker        for m, inp in inputs:
1090*da0073e9SAndroid Build Coastguard Worker            self.checkModule(m, (inp,))
1091*da0073e9SAndroid Build Coastguard Worker
1092*da0073e9SAndroid Build Coastguard Worker    def test_nn_padding(self):
1093*da0073e9SAndroid Build Coastguard Worker        class Mod(nn.Module):
1094*da0073e9SAndroid Build Coastguard Worker            def __init__(self, padding):
1095*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1096*da0073e9SAndroid Build Coastguard Worker                self.padding = padding
1097*da0073e9SAndroid Build Coastguard Worker
1098*da0073e9SAndroid Build Coastguard Worker            def forward(self, input):
1099*da0073e9SAndroid Build Coastguard Worker                return self.padding(input)
1100*da0073e9SAndroid Build Coastguard Worker
1101*da0073e9SAndroid Build Coastguard Worker        inputs = [
1102*da0073e9SAndroid Build Coastguard Worker            (Mod(nn.ConstantPad1d(2, 3.5)), torch.randn(1, 2, 4)),
1103*da0073e9SAndroid Build Coastguard Worker            (Mod(nn.ConstantPad2d(2, 3.5)), torch.randn(1, 2, 2)),
1104*da0073e9SAndroid Build Coastguard Worker            (Mod(nn.ConstantPad3d(3, 3.5)), torch.randn(16, 3, 10, 20, 30)),
1105*da0073e9SAndroid Build Coastguard Worker            (Mod(nn.ReflectionPad1d(2)), torch.arange(8, dtype=torch.float).reshape(1, 2, 4)),
1106*da0073e9SAndroid Build Coastguard Worker            (Mod(nn.ReflectionPad2d(2)), torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)),
1107*da0073e9SAndroid Build Coastguard Worker            (Mod(nn.ReflectionPad3d(3)), torch.randn(16, 3, 8, 32, 48)),
1108*da0073e9SAndroid Build Coastguard Worker            (Mod(nn.ReplicationPad1d(2)), torch.arange(8, dtype=torch.float).reshape(1, 2, 4)),
1109*da0073e9SAndroid Build Coastguard Worker            (Mod(nn.ReplicationPad2d(2)), torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)),
1110*da0073e9SAndroid Build Coastguard Worker            (Mod(nn.ReplicationPad3d(3)), torch.randn(16, 3, 8, 32, 48)),
1111*da0073e9SAndroid Build Coastguard Worker            (Mod(nn.ZeroPad2d(2)), torch.randn(1, 1, 3, 3))
1112*da0073e9SAndroid Build Coastguard Worker        ]
1113*da0073e9SAndroid Build Coastguard Worker
1114*da0073e9SAndroid Build Coastguard Worker        for m, inp in inputs:
1115*da0073e9SAndroid Build Coastguard Worker            self.checkModule(m, (inp,))
1116*da0073e9SAndroid Build Coastguard Worker
1117*da0073e9SAndroid Build Coastguard Worker    def test_script_autograd_grad(self):
1118*da0073e9SAndroid Build Coastguard Worker        def test_simple_grad(x, y):
1119*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor, Tensor) -> List[Optional[Tensor]]
1120*da0073e9SAndroid Build Coastguard Worker            z = x + 2 * y + x * y
1121*da0073e9SAndroid Build Coastguard Worker            return torch.autograd.grad((z.sum(), ), (x, y))
1122*da0073e9SAndroid Build Coastguard Worker
1123*da0073e9SAndroid Build Coastguard Worker        def test_simple_grad_with_grad_outputs(x, y):
1124*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor, Tensor) -> List[Optional[Tensor]]
1125*da0073e9SAndroid Build Coastguard Worker            z = x + 2 * y + x * y
1126*da0073e9SAndroid Build Coastguard Worker            grad_outputs = torch.jit.annotate(List[Optional[torch.Tensor]], [torch.ones((2, 2)), ])
1127*da0073e9SAndroid Build Coastguard Worker            return torch.autograd.grad((z, ), (x, y), grad_outputs)
1128*da0073e9SAndroid Build Coastguard Worker
1129*da0073e9SAndroid Build Coastguard Worker        def test_one_output_not_requires_grad(x, y):
1130*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor, Tensor) -> List[Optional[Tensor]]
1131*da0073e9SAndroid Build Coastguard Worker            z = 2 * y + y
1132*da0073e9SAndroid Build Coastguard Worker            return torch.autograd.grad((z.sum(),), (x, y), allow_unused=True)
1133*da0073e9SAndroid Build Coastguard Worker
1134*da0073e9SAndroid Build Coastguard Worker        def test_retain_graph(x, y):
1135*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor, Tensor) -> None
1136*da0073e9SAndroid Build Coastguard Worker            z = x + 2 * y + x * y
1137*da0073e9SAndroid Build Coastguard Worker            torch.autograd.grad((z.sum(), ), (x, y), retain_graph=True)
1138*da0073e9SAndroid Build Coastguard Worker            torch.autograd.grad((z.sum(), ), (x, y))
1139*da0073e9SAndroid Build Coastguard Worker
1140*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2, requires_grad=True)
1141*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(2, 2, requires_grad=True)
1142*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_simple_grad, (x, y), inputs_requires_grad=True)
1143*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_simple_grad_with_grad_outputs, (x, y), inputs_requires_grad=True)
1144*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_one_output_not_requires_grad, (x, y), inputs_requires_grad=True)
1145*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_retain_graph, (x, y), inputs_requires_grad=True)
1146*da0073e9SAndroid Build Coastguard Worker
1147*da0073e9SAndroid Build Coastguard Worker    def test_script_backward(self):
1148*da0073e9SAndroid Build Coastguard Worker        def checkBackwardScript(fn, inputs):
1149*da0073e9SAndroid Build Coastguard Worker            scripted_fn = torch.jit.script(fn)
1150*da0073e9SAndroid Build Coastguard Worker            FileCheck().check("torch.autograd.backward").run(scripted_fn.code)
1151*da0073e9SAndroid Build Coastguard Worker            recording_inputs = do_input_map(lambda t: t.detach().requires_grad_(), inputs)
1152*da0073e9SAndroid Build Coastguard Worker
1153*da0073e9SAndroid Build Coastguard Worker            fn(*inputs)
1154*da0073e9SAndroid Build Coastguard Worker            scripted_fn(*recording_inputs)
1155*da0073e9SAndroid Build Coastguard Worker
1156*da0073e9SAndroid Build Coastguard Worker            for inp1, inp2 in zip(inputs, recording_inputs):
1157*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(inp1.grad, inp2.grad)
1158*da0073e9SAndroid Build Coastguard Worker
1159*da0073e9SAndroid Build Coastguard Worker        def test_tensor_backward(input):
1160*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor) -> None
1161*da0073e9SAndroid Build Coastguard Worker            output = torch.relu(input)
1162*da0073e9SAndroid Build Coastguard Worker            output = output.softmax(0)
1163*da0073e9SAndroid Build Coastguard Worker            sum_out = output.sum()
1164*da0073e9SAndroid Build Coastguard Worker            sum_out.backward()
1165*da0073e9SAndroid Build Coastguard Worker
1166*da0073e9SAndroid Build Coastguard Worker        def test_torch_autograd_backward(input):
1167*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor) -> None
1168*da0073e9SAndroid Build Coastguard Worker            output = torch.relu(input)
1169*da0073e9SAndroid Build Coastguard Worker            output = output.softmax(0)
1170*da0073e9SAndroid Build Coastguard Worker            torch.autograd.backward(output.sum())
1171*da0073e9SAndroid Build Coastguard Worker
1172*da0073e9SAndroid Build Coastguard Worker        def test_torch_autograd_backward_with_grad_tensors(input):
1173*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor) -> None
1174*da0073e9SAndroid Build Coastguard Worker            output = torch.relu(input)
1175*da0073e9SAndroid Build Coastguard Worker            output = output.softmax(0)
1176*da0073e9SAndroid Build Coastguard Worker            grad_outputs = torch.jit.annotate(List[Optional[torch.Tensor]], [torch.ones((2, 2)), ])
1177*da0073e9SAndroid Build Coastguard Worker            torch.autograd.backward((output,), grad_outputs)
1178*da0073e9SAndroid Build Coastguard Worker
1179*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn(2, 2, requires_grad=True)
1180*da0073e9SAndroid Build Coastguard Worker        checkBackwardScript(test_tensor_backward, (inp,))
1181*da0073e9SAndroid Build Coastguard Worker        checkBackwardScript(test_torch_autograd_backward, (inp,))
1182*da0073e9SAndroid Build Coastguard Worker        checkBackwardScript(test_torch_autograd_backward_with_grad_tensors, (inp,))
1183*da0073e9SAndroid Build Coastguard Worker
1184*da0073e9SAndroid Build Coastguard Worker    def test_script_backward_twice(self):
1185*da0073e9SAndroid Build Coastguard Worker        def checkBackwardTwiceScript(fn, inputs, retain_graph_=False):
1186*da0073e9SAndroid Build Coastguard Worker            class jit_profiling_executor_false:
1187*da0073e9SAndroid Build Coastguard Worker                def __enter__(self):
1188*da0073e9SAndroid Build Coastguard Worker                    torch._C._jit_set_profiling_executor(False)
1189*da0073e9SAndroid Build Coastguard Worker
1190*da0073e9SAndroid Build Coastguard Worker                def __exit__(self, *args):
1191*da0073e9SAndroid Build Coastguard Worker                    torch._C._jit_set_profiling_executor(GRAPH_EXECUTOR != ProfilingMode.LEGACY)
1192*da0073e9SAndroid Build Coastguard Worker
1193*da0073e9SAndroid Build Coastguard Worker            with jit_profiling_executor_false(), torch.jit.optimized_execution(True):
1194*da0073e9SAndroid Build Coastguard Worker                scripted_fn = torch.jit.script(fn, inputs)
1195*da0073e9SAndroid Build Coastguard Worker                FileCheck().check("prim::DifferentiableGraph").run(scripted_fn.graph_for(*inputs))
1196*da0073e9SAndroid Build Coastguard Worker
1197*da0073e9SAndroid Build Coastguard Worker                result = scripted_fn(*inputs)
1198*da0073e9SAndroid Build Coastguard Worker                result.sum().backward(retain_graph=retain_graph_)
1199*da0073e9SAndroid Build Coastguard Worker                if not retain_graph_:
1200*da0073e9SAndroid Build Coastguard Worker                    self.assertRaisesRegex(RuntimeError, 'Specify retain_graph=True',
1201*da0073e9SAndroid Build Coastguard Worker                                           lambda: result.sum().backward())
1202*da0073e9SAndroid Build Coastguard Worker                else:
1203*da0073e9SAndroid Build Coastguard Worker                    result.sum().backward()
1204*da0073e9SAndroid Build Coastguard Worker
1205*da0073e9SAndroid Build Coastguard Worker        def test_script_backward_twice_with_saved_values(input1, input2):
1206*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor, Tensor) -> Tensor
1207*da0073e9SAndroid Build Coastguard Worker            tmp1 = torch.mul(input1, input2)
1208*da0073e9SAndroid Build Coastguard Worker            tmp2 = torch.abs(tmp1)
1209*da0073e9SAndroid Build Coastguard Worker            if torch.equal(input1, input2):
1210*da0073e9SAndroid Build Coastguard Worker                tmp2 = torch.acos(tmp2)
1211*da0073e9SAndroid Build Coastguard Worker            else:
1212*da0073e9SAndroid Build Coastguard Worker                tmp2 = torch.atan(tmp2)
1213*da0073e9SAndroid Build Coastguard Worker            result = torch.add(tmp2, input2)
1214*da0073e9SAndroid Build Coastguard Worker            return result
1215*da0073e9SAndroid Build Coastguard Worker
1216*da0073e9SAndroid Build Coastguard Worker        inp1 = torch.randn(2, 2, requires_grad=True)
1217*da0073e9SAndroid Build Coastguard Worker        inp2 = torch.randn(2, 2, requires_grad=True)
1218*da0073e9SAndroid Build Coastguard Worker        checkBackwardTwiceScript(test_script_backward_twice_with_saved_values, (inp1, inp2), False)
1219*da0073e9SAndroid Build Coastguard Worker        checkBackwardTwiceScript(test_script_backward_twice_with_saved_values, (inp1, inp2), True)
1220*da0073e9SAndroid Build Coastguard Worker
1221*da0073e9SAndroid Build Coastguard Worker    def test_diff_subgraph_clones_constants(self):
1222*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
1223*da0073e9SAndroid Build Coastguard Worker        def f(x, y):
1224*da0073e9SAndroid Build Coastguard Worker            return x + x + y + x + y + x + y + x + y + x
1225*da0073e9SAndroid Build Coastguard Worker
1226*da0073e9SAndroid Build Coastguard Worker        def count_constants(graph):
1227*da0073e9SAndroid Build Coastguard Worker            return sum(node.kind() == 'prim::Constant' for node in graph.nodes())
1228*da0073e9SAndroid Build Coastguard Worker
1229*da0073e9SAndroid Build Coastguard Worker        graph = f.graph.copy()
1230*da0073e9SAndroid Build Coastguard Worker        self.run_pass('cse', graph)
1231*da0073e9SAndroid Build Coastguard Worker        self.run_pass('create_autodiff_subgraphs', graph)
1232*da0073e9SAndroid Build Coastguard Worker        nodes = list(graph.nodes())
1233*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count_constants(graph), 1)
1234*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count_constants(nodes[1].g('Subgraph')), 1)
1235*da0073e9SAndroid Build Coastguard Worker
1236*da0073e9SAndroid Build Coastguard Worker    # TODO: adapt this test to check that GraphExecutor treats them differently
1237*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("Need to be adjusted to Graph Executor")
1238*da0073e9SAndroid Build Coastguard Worker    def test_arg_configurations(self):
1239*da0073e9SAndroid Build Coastguard Worker        """Different arg configurations should trigger different traces"""
1240*da0073e9SAndroid Build Coastguard Worker        x = Variable(torch.FloatTensor(4, 4).uniform_())
1241*da0073e9SAndroid Build Coastguard Worker        x_double = Variable(x.data.double())
1242*da0073e9SAndroid Build Coastguard Worker        x_grad = Variable(x.data.clone(), requires_grad=True)
1243*da0073e9SAndroid Build Coastguard Worker        y = Variable(torch.randn(4))
1244*da0073e9SAndroid Build Coastguard Worker
1245*da0073e9SAndroid Build Coastguard Worker        configurations = [
1246*da0073e9SAndroid Build Coastguard Worker            (x,),
1247*da0073e9SAndroid Build Coastguard Worker            (x_double,),
1248*da0073e9SAndroid Build Coastguard Worker            (x_grad,),
1249*da0073e9SAndroid Build Coastguard Worker            (y,),
1250*da0073e9SAndroid Build Coastguard Worker            ([x, x],),
1251*da0073e9SAndroid Build Coastguard Worker            ([x, y],),
1252*da0073e9SAndroid Build Coastguard Worker        ]
1253*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
1254*da0073e9SAndroid Build Coastguard Worker            x_cuda = Variable(x.data.cuda())
1255*da0073e9SAndroid Build Coastguard Worker            configurations += [
1256*da0073e9SAndroid Build Coastguard Worker                (x_cuda,),
1257*da0073e9SAndroid Build Coastguard Worker                ([x, x_cuda],),
1258*da0073e9SAndroid Build Coastguard Worker                ([x_cuda, x],),
1259*da0073e9SAndroid Build Coastguard Worker                ([[x_cuda, x]],),
1260*da0073e9SAndroid Build Coastguard Worker            ]
1261*da0073e9SAndroid Build Coastguard Worker            if torch.cuda.device_count() > 1:
1262*da0073e9SAndroid Build Coastguard Worker                x_cuda_1 = Variable(x.data.cuda(1))
1263*da0073e9SAndroid Build Coastguard Worker                configurations += [
1264*da0073e9SAndroid Build Coastguard Worker                    (x_cuda_1,),
1265*da0073e9SAndroid Build Coastguard Worker                    ([x_cuda, x_cuda_1],),
1266*da0073e9SAndroid Build Coastguard Worker                ]
1267*da0073e9SAndroid Build Coastguard Worker
1268*da0073e9SAndroid Build Coastguard Worker        @torch.jit.compile(nderivs=0)
1269*da0073e9SAndroid Build Coastguard Worker        def fn(*args):
1270*da0073e9SAndroid Build Coastguard Worker            in_vars, _ = torch._C._jit_flatten(args)
1271*da0073e9SAndroid Build Coastguard Worker            return in_vars[0] + 1
1272*da0073e9SAndroid Build Coastguard Worker
1273*da0073e9SAndroid Build Coastguard Worker        for i, config in enumerate(configurations):
1274*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(fn.has_trace_for(*config))
1275*da0073e9SAndroid Build Coastguard Worker            fn(*config)
1276*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(fn.has_trace_for(*config))
1277*da0073e9SAndroid Build Coastguard Worker            for unk_config in configurations[i + 1:]:
1278*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(fn.has_trace_for(*unk_config))
1279*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn.hits, 0)
1280*da0073e9SAndroid Build Coastguard Worker
1281*da0073e9SAndroid Build Coastguard Worker    def test_torch_sum(self):
1282*da0073e9SAndroid Build Coastguard Worker        def fn(x):
1283*da0073e9SAndroid Build Coastguard Worker            return torch.sum(x)
1284*da0073e9SAndroid Build Coastguard Worker
1285*da0073e9SAndroid Build Coastguard Worker        def fn1(x, dim: int):
1286*da0073e9SAndroid Build Coastguard Worker            return torch.sum(x, dim)
1287*da0073e9SAndroid Build Coastguard Worker
1288*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 4)
1289*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (x, ))
1290*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn1, (x, 1, ))
1291*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn1, (x, 0, ))
1292*da0073e9SAndroid Build Coastguard Worker
1293*da0073e9SAndroid Build Coastguard Worker    def test_cse(self):
1294*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([0.4, 0.3], requires_grad=True)
1295*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor([0.7, 0.5], requires_grad=True)
1296*da0073e9SAndroid Build Coastguard Worker
1297*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
1298*da0073e9SAndroid Build Coastguard Worker            w = (x + y) * (x + y) * (x + y)
1299*da0073e9SAndroid Build Coastguard Worker            t = torch.tanh(w) + torch.tanh(w)
1300*da0073e9SAndroid Build Coastguard Worker            z = (x + y) * (x + y) * (x + y) + t
1301*da0073e9SAndroid Build Coastguard Worker            return z
1302*da0073e9SAndroid Build Coastguard Worker
1303*da0073e9SAndroid Build Coastguard Worker        g, _ = torch.jit._get_trace_graph(fn, (x, y))
1304*da0073e9SAndroid Build Coastguard Worker        self.run_pass('cse', g)
1305*da0073e9SAndroid Build Coastguard Worker        do_exactly = True
1306*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_count("add", 1).check_count("mul", 2, do_exactly) \
1307*da0073e9SAndroid Build Coastguard Worker            .check_count("tanh", 1, do_exactly).check_count("add", 2, do_exactly).check_next("return")  \
1308*da0073e9SAndroid Build Coastguard Worker            .run(str(g))
1309*da0073e9SAndroid Build Coastguard Worker
1310*da0073e9SAndroid Build Coastguard Worker        self.assertExportImport(g, (x, y))
1311*da0073e9SAndroid Build Coastguard Worker
1312*da0073e9SAndroid Build Coastguard Worker    def test_cse_not_introduce_aliasing(self):
1313*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
1314*da0073e9SAndroid Build Coastguard Worker        def tensor_alias_outputs(x):
1315*da0073e9SAndroid Build Coastguard Worker            return x + x, x + x
1316*da0073e9SAndroid Build Coastguard Worker
1317*da0073e9SAndroid Build Coastguard Worker        self.run_pass('cse', tensor_alias_outputs.graph)
1318*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_count("aten::add", 2).run(tensor_alias_outputs.graph)
1319*da0073e9SAndroid Build Coastguard Worker
1320*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
1321*da0073e9SAndroid Build Coastguard Worker        def ints_alias_outputs(x):
1322*da0073e9SAndroid Build Coastguard Worker            # type: (int) -> Tuple[int, int]
1323*da0073e9SAndroid Build Coastguard Worker            return x + x, x + x
1324*da0073e9SAndroid Build Coastguard Worker
1325*da0073e9SAndroid Build Coastguard Worker        # non-aliasing types can be CSEd
1326*da0073e9SAndroid Build Coastguard Worker        self.run_pass('cse', ints_alias_outputs.graph)
1327*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_count("aten::add", 1, exactly=True).run(ints_alias_outputs.graph)
1328*da0073e9SAndroid Build Coastguard Worker
1329*da0073e9SAndroid Build Coastguard Worker    def test_recursive_cse(self):
1330*da0073e9SAndroid Build Coastguard Worker        input_str = """
1331*da0073e9SAndroid Build Coastguard Workergraph(%x : Tensor,
1332*da0073e9SAndroid Build Coastguard Worker      %y : Tensor,
1333*da0073e9SAndroid Build Coastguard Worker      %20 : int):
1334*da0073e9SAndroid Build Coastguard Worker  %2 : int = prim::Constant[value=1]()
1335*da0073e9SAndroid Build Coastguard Worker  %3 : Tensor = aten::add(%x, %y, %2)
1336*da0073e9SAndroid Build Coastguard Worker  %4 : int = aten::add(%2, %20)
1337*da0073e9SAndroid Build Coastguard Worker  %5 : bool = aten::Bool(%4)
1338*da0073e9SAndroid Build Coastguard Worker  %z : int = prim::If(%5)
1339*da0073e9SAndroid Build Coastguard Worker    # CHECK: block
1340*da0073e9SAndroid Build Coastguard Worker    block0():
1341*da0073e9SAndroid Build Coastguard Worker      # CHECK-NOT: aten::add
1342*da0073e9SAndroid Build Coastguard Worker      %z.1 : int = aten::add(%2, %20)
1343*da0073e9SAndroid Build Coastguard Worker      -> (%z.1)
1344*da0073e9SAndroid Build Coastguard Worker    block1():
1345*da0073e9SAndroid Build Coastguard Worker      -> (%2)
1346*da0073e9SAndroid Build Coastguard Worker  return (%z)
1347*da0073e9SAndroid Build Coastguard Worker"""
1348*da0073e9SAndroid Build Coastguard Worker        graph = parse_ir(input_str)
1349*da0073e9SAndroid Build Coastguard Worker        self.run_pass('cse', graph)
1350*da0073e9SAndroid Build Coastguard Worker        FileCheck().run(input_str, graph)
1351*da0073e9SAndroid Build Coastguard Worker
1352*da0073e9SAndroid Build Coastguard Worker    def test_pattern_based_rewrite(self):
1353*da0073e9SAndroid Build Coastguard Worker        # mul(mul(mul(mul(x,y),z),x),y) --> mul(mul(mulmul(x,y,z), x), y) -->
1354*da0073e9SAndroid Build Coastguard Worker        # --> mulmul(mulmul(x,y,z), x, y)
1355*da0073e9SAndroid Build Coastguard Worker        input_str = """
1356*da0073e9SAndroid Build Coastguard Workergraph(%x, %y, %z):
1357*da0073e9SAndroid Build Coastguard Worker    # CHECK-NOT: aten::mul
1358*da0073e9SAndroid Build Coastguard Worker    # CHECK: my::fused_mulmul
1359*da0073e9SAndroid Build Coastguard Worker    %t = aten::mul(%x, %y)
1360*da0073e9SAndroid Build Coastguard Worker    %p = aten::mul(%t, %z)
1361*da0073e9SAndroid Build Coastguard Worker    # CHECK: my::fused_mulmul
1362*da0073e9SAndroid Build Coastguard Worker    %u = aten::mul(%p, %x)
1363*da0073e9SAndroid Build Coastguard Worker    %o = aten::mul(%u, %y)
1364*da0073e9SAndroid Build Coastguard Worker    return (%o)"""
1365*da0073e9SAndroid Build Coastguard Worker        graph = parse_ir(input_str)
1366*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_custom_pattern_based_rewrite_graph("""
1367*da0073e9SAndroid Build Coastguard Workergraph(%a, %b, %c):
1368*da0073e9SAndroid Build Coastguard Worker  %q = aten::mul(%a, %b)
1369*da0073e9SAndroid Build Coastguard Worker  %r = aten::mul(%q, %c)
1370*da0073e9SAndroid Build Coastguard Worker  return (%r)""", """
1371*da0073e9SAndroid Build Coastguard Workergraph(%a, %b, %c):
1372*da0073e9SAndroid Build Coastguard Worker  %r = my::fused_mulmul(%a, %b, %c)
1373*da0073e9SAndroid Build Coastguard Worker  return (%r)""", graph)
1374*da0073e9SAndroid Build Coastguard Worker        FileCheck().run(input_str, graph)
1375*da0073e9SAndroid Build Coastguard Worker
1376*da0073e9SAndroid Build Coastguard Worker        # Check that overlapping matches are handled correctly
1377*da0073e9SAndroid Build Coastguard Worker        # mul(mul(mul(x,y),z),x) --> mul(mulmul(x,y,z), x)
1378*da0073e9SAndroid Build Coastguard Worker        input_str = """
1379*da0073e9SAndroid Build Coastguard Workergraph(%x, %y, %z):
1380*da0073e9SAndroid Build Coastguard Worker    # CHECK-NOT: aten::mul
1381*da0073e9SAndroid Build Coastguard Worker    # CHECK: my::fused_mulmul
1382*da0073e9SAndroid Build Coastguard Worker    %t = aten::mul(%x, %y)
1383*da0073e9SAndroid Build Coastguard Worker    %p = aten::mul(%t, %z)
1384*da0073e9SAndroid Build Coastguard Worker    # CHECK-NEXT: aten::mul
1385*da0073e9SAndroid Build Coastguard Worker    %u = aten::mul(%p, %x)
1386*da0073e9SAndroid Build Coastguard Worker    return (%u)"""
1387*da0073e9SAndroid Build Coastguard Worker        graph = parse_ir(input_str)
1388*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_custom_pattern_based_rewrite_graph("""
1389*da0073e9SAndroid Build Coastguard Workergraph(%a, %b, %c):
1390*da0073e9SAndroid Build Coastguard Worker  %q = aten::mul(%a, %b)
1391*da0073e9SAndroid Build Coastguard Worker  %r = aten::mul(%q, %c)
1392*da0073e9SAndroid Build Coastguard Worker  return (%r)""", """
1393*da0073e9SAndroid Build Coastguard Workergraph(%a, %b, %c):
1394*da0073e9SAndroid Build Coastguard Worker  %r = my::fused_mulmul(%a, %b, %c)
1395*da0073e9SAndroid Build Coastguard Worker  return (%r)""", graph)
1396*da0073e9SAndroid Build Coastguard Worker        FileCheck().run(input_str, graph)
1397*da0073e9SAndroid Build Coastguard Worker
1398*da0073e9SAndroid Build Coastguard Worker        # Check add(mul(x,y),z) --> muladd(x,y,z) replacement
1399*da0073e9SAndroid Build Coastguard Worker        input_str = """
1400*da0073e9SAndroid Build Coastguard Workergraph(%x, %y, %z):
1401*da0073e9SAndroid Build Coastguard Worker    # CHECK-NOT: aten::mul
1402*da0073e9SAndroid Build Coastguard Worker    # CHECK-NOT: aten::add
1403*da0073e9SAndroid Build Coastguard Worker    %c = prim::Const[value=1]()
1404*da0073e9SAndroid Build Coastguard Worker    %t = aten::mul(%x, %y)
1405*da0073e9SAndroid Build Coastguard Worker    %p = aten::add(%t, %z, %c)
1406*da0073e9SAndroid Build Coastguard Worker    # CHECK: my::muladd
1407*da0073e9SAndroid Build Coastguard Worker    # CHECK-NEXT: return
1408*da0073e9SAndroid Build Coastguard Worker    return (%p)"""
1409*da0073e9SAndroid Build Coastguard Worker        graph = parse_ir(input_str)
1410*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_custom_pattern_based_rewrite_graph("""
1411*da0073e9SAndroid Build Coastguard Workergraph(%a, %b, %c, %d):
1412*da0073e9SAndroid Build Coastguard Worker  %q = aten::mul(%a, %b)
1413*da0073e9SAndroid Build Coastguard Worker  %r = aten::add(%q, %c, %d)
1414*da0073e9SAndroid Build Coastguard Worker  return (%r)""", """
1415*da0073e9SAndroid Build Coastguard Workergraph(%a, %b, %c, %d):
1416*da0073e9SAndroid Build Coastguard Worker  %r = my::muladd(%a, %b, %c, %d)
1417*da0073e9SAndroid Build Coastguard Worker  return (%r)""", graph)
1418*da0073e9SAndroid Build Coastguard Worker        FileCheck().run(input_str, graph)
1419*da0073e9SAndroid Build Coastguard Worker
1420*da0073e9SAndroid Build Coastguard Worker        # Check add(mul(x,y),z) --> sub(add(x,y),z) replacement
1421*da0073e9SAndroid Build Coastguard Worker        input_str = """
1422*da0073e9SAndroid Build Coastguard Workergraph(%x, %y, %z):
1423*da0073e9SAndroid Build Coastguard Worker    # CHECK-NOT: aten::mul
1424*da0073e9SAndroid Build Coastguard Worker    %c = prim::Const[value=1]()
1425*da0073e9SAndroid Build Coastguard Worker    # CHECK: aten::add
1426*da0073e9SAndroid Build Coastguard Worker    %t = aten::mul(%x, %y)
1427*da0073e9SAndroid Build Coastguard Worker    # CHECK-NEXT: aten::sub
1428*da0073e9SAndroid Build Coastguard Worker    %p = aten::add(%t, %z, %c)
1429*da0073e9SAndroid Build Coastguard Worker    # CHECK-NOT: aten::add
1430*da0073e9SAndroid Build Coastguard Worker    # CHECK-NEXT: return
1431*da0073e9SAndroid Build Coastguard Worker    return (%p)"""
1432*da0073e9SAndroid Build Coastguard Worker        graph = parse_ir(input_str)
1433*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_custom_pattern_based_rewrite_graph("""
1434*da0073e9SAndroid Build Coastguard Workergraph(%a, %b, %c, %d):
1435*da0073e9SAndroid Build Coastguard Worker  %q = aten::mul(%a, %b)
1436*da0073e9SAndroid Build Coastguard Worker  %r = aten::add(%q, %c, %d)
1437*da0073e9SAndroid Build Coastguard Worker  return (%r)""", """
1438*da0073e9SAndroid Build Coastguard Workergraph(%a, %b, %c, %d):
1439*da0073e9SAndroid Build Coastguard Worker  %q = aten::add(%a, %b, %d)
1440*da0073e9SAndroid Build Coastguard Worker  %r = aten::sub(%q, %c, %d)
1441*da0073e9SAndroid Build Coastguard Worker  return (%r)""", graph)
1442*da0073e9SAndroid Build Coastguard Worker        FileCheck().run(input_str, graph)
1443*da0073e9SAndroid Build Coastguard Worker
1444*da0073e9SAndroid Build Coastguard Worker        # Check mul(x,y) --> x replacement
1445*da0073e9SAndroid Build Coastguard Worker        input_str = """
1446*da0073e9SAndroid Build Coastguard Workergraph(%x, %y, %z):
1447*da0073e9SAndroid Build Coastguard Worker    %c = prim::Const[value=1]()
1448*da0073e9SAndroid Build Coastguard Worker    # CHECK-NOT: aten::mul
1449*da0073e9SAndroid Build Coastguard Worker    %t = aten::mul(%x, %y)
1450*da0073e9SAndroid Build Coastguard Worker    # CHECK: aten::add(%x, %z
1451*da0073e9SAndroid Build Coastguard Worker    %p = aten::add(%t, %z, %c)
1452*da0073e9SAndroid Build Coastguard Worker    # CHECK-NEXT: return
1453*da0073e9SAndroid Build Coastguard Worker    return (%p)"""
1454*da0073e9SAndroid Build Coastguard Worker        graph = parse_ir(input_str)
1455*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_custom_pattern_based_rewrite_graph("""
1456*da0073e9SAndroid Build Coastguard Workergraph(%Pa, %Pb):
1457*da0073e9SAndroid Build Coastguard Worker  %Pq = aten::mul(%Pa, %Pb)
1458*da0073e9SAndroid Build Coastguard Worker  return (%Pq)""", """
1459*da0073e9SAndroid Build Coastguard Workergraph(%Ra, %Rb):
1460*da0073e9SAndroid Build Coastguard Worker  return (%Ra)""", graph)
1461*da0073e9SAndroid Build Coastguard Worker        FileCheck().run(input_str, graph)
1462*da0073e9SAndroid Build Coastguard Worker
1463*da0073e9SAndroid Build Coastguard Worker    @_tmp_donotuse_dont_inline_everything
1464*da0073e9SAndroid Build Coastguard Worker    def test_pattern_based_module_rewrite(self):
1465*da0073e9SAndroid Build Coastguard Worker        # Check match::module behavior
1466*da0073e9SAndroid Build Coastguard Worker        class Test(torch.nn.Module):
1467*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1468*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1469*da0073e9SAndroid Build Coastguard Worker                self.conv = torch.nn.Conv2d(1, 20, 5, 1)
1470*da0073e9SAndroid Build Coastguard Worker                self.bn = torch.nn.BatchNorm2d(num_features=20)
1471*da0073e9SAndroid Build Coastguard Worker
1472*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1473*da0073e9SAndroid Build Coastguard Worker                x = self.conv(x)
1474*da0073e9SAndroid Build Coastguard Worker                x = self.bn(x)
1475*da0073e9SAndroid Build Coastguard Worker                return x
1476*da0073e9SAndroid Build Coastguard Worker        m = torch.jit.script(Test())
1477*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_custom_pattern_based_rewrite_graph("""
1478*da0073e9SAndroid Build Coastguard Worker        graph(%self, %x):
1479*da0073e9SAndroid Build Coastguard Worker                %conv = match::module[name="Conv2d"](%self)
1480*da0073e9SAndroid Build Coastguard Worker                %y = prim::CallMethod[name="forward"](%conv, %x)
1481*da0073e9SAndroid Build Coastguard Worker                %bn = match::module[name="BatchNorm2d"](%self)
1482*da0073e9SAndroid Build Coastguard Worker                %z = prim::CallMethod[name="forward"](%bn, %y)
1483*da0073e9SAndroid Build Coastguard Worker                return (%z)""", """
1484*da0073e9SAndroid Build Coastguard Worker        graph(%self, %x):
1485*da0073e9SAndroid Build Coastguard Worker          %z = my::matched_conv_bn(%self, %x)
1486*da0073e9SAndroid Build Coastguard Worker          return (%z)""", m._c._get_method("forward").graph)
1487*da0073e9SAndroid Build Coastguard Worker
1488*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("my::matched_conv_bn").run(m._c._get_method("forward").graph)
1489*da0073e9SAndroid Build Coastguard Worker
1490*da0073e9SAndroid Build Coastguard Worker    def test_pattern_based_rewrite_with_source_range_preserved(self):
1491*da0073e9SAndroid Build Coastguard Worker        class TestModule1(torch.nn.Module):
1492*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, y, z, w):
1493*da0073e9SAndroid Build Coastguard Worker                x = x + y
1494*da0073e9SAndroid Build Coastguard Worker                x = x * z
1495*da0073e9SAndroid Build Coastguard Worker                return w - x
1496*da0073e9SAndroid Build Coastguard Worker
1497*da0073e9SAndroid Build Coastguard Worker        input_pattern = """
1498*da0073e9SAndroid Build Coastguard Worker        graph(%x, %y, %z, %const):
1499*da0073e9SAndroid Build Coastguard Worker            %t = aten::add(%x, %y, %const)
1500*da0073e9SAndroid Build Coastguard Worker            %o = aten::mul(%t, %z)
1501*da0073e9SAndroid Build Coastguard Worker            return (%o)"""
1502*da0073e9SAndroid Build Coastguard Worker        replacement_pattern = """
1503*da0073e9SAndroid Build Coastguard Worker        graph(%x, %y, %z, %const):
1504*da0073e9SAndroid Build Coastguard Worker            %o = my::add_mul(%x, %y, %z, %const)
1505*da0073e9SAndroid Build Coastguard Worker            return (%o)"""
1506*da0073e9SAndroid Build Coastguard Worker        scripted_model = torch.jit.script(TestModule1())
1507*da0073e9SAndroid Build Coastguard Worker        graph = scripted_model.graph
1508*da0073e9SAndroid Build Coastguard Worker        value_mappings = [("o", "t")]
1509*da0073e9SAndroid Build Coastguard Worker        for node in graph.nodes():
1510*da0073e9SAndroid Build Coastguard Worker            if node.kind() == "aten::add":
1511*da0073e9SAndroid Build Coastguard Worker                source_range_1 = node.sourceRange()
1512*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_custom_pattern_based_rewrite_graph(
1513*da0073e9SAndroid Build Coastguard Worker            input_pattern, replacement_pattern, scripted_model.graph, value_name_pairs=value_mappings)
1514*da0073e9SAndroid Build Coastguard Worker        graph = scripted_model.graph
1515*da0073e9SAndroid Build Coastguard Worker        for node in graph.nodes():
1516*da0073e9SAndroid Build Coastguard Worker            if node.kind() == "my::add_mul":
1517*da0073e9SAndroid Build Coastguard Worker                source_range_2 = node.sourceRange()
1518*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(source_range_1 == source_range_2)
1519*da0073e9SAndroid Build Coastguard Worker
1520*da0073e9SAndroid Build Coastguard Worker        class TestModule2(torch.nn.Module):
1521*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, y, z, w):
1522*da0073e9SAndroid Build Coastguard Worker                x = x + y
1523*da0073e9SAndroid Build Coastguard Worker                x = x + z
1524*da0073e9SAndroid Build Coastguard Worker                x = x * z
1525*da0073e9SAndroid Build Coastguard Worker                x = x * w
1526*da0073e9SAndroid Build Coastguard Worker                return x - 2
1527*da0073e9SAndroid Build Coastguard Worker
1528*da0073e9SAndroid Build Coastguard Worker        # Check source range preservation for two node transforms add -> my_add
1529*da0073e9SAndroid Build Coastguard Worker        input_pattern = """
1530*da0073e9SAndroid Build Coastguard Worker        graph(%x, %y, %const):
1531*da0073e9SAndroid Build Coastguard Worker            %o = aten::add(%x, %y, %const)
1532*da0073e9SAndroid Build Coastguard Worker            return (%o)"""
1533*da0073e9SAndroid Build Coastguard Worker        replacement_pattern = """
1534*da0073e9SAndroid Build Coastguard Worker        graph(%x, %y, %const):
1535*da0073e9SAndroid Build Coastguard Worker            %o = my::add(%x, %y, %const)
1536*da0073e9SAndroid Build Coastguard Worker            return (%o)"""
1537*da0073e9SAndroid Build Coastguard Worker        scripted_model = copy.deepcopy(torch.jit.script(TestModule2()))
1538*da0073e9SAndroid Build Coastguard Worker        graph_copy = scripted_model.graph.copy()
1539*da0073e9SAndroid Build Coastguard Worker        value_mappings = [("o", "o")]
1540*da0073e9SAndroid Build Coastguard Worker        source_range_add_1 = None
1541*da0073e9SAndroid Build Coastguard Worker        for node in graph_copy.nodes():
1542*da0073e9SAndroid Build Coastguard Worker            if source_range_add_1 is None and node.kind() == "aten::add":
1543*da0073e9SAndroid Build Coastguard Worker                source_range_add_1 = node.sourceRange()
1544*da0073e9SAndroid Build Coastguard Worker            if source_range_add_1 is not None and node.kind() == "aten::add":
1545*da0073e9SAndroid Build Coastguard Worker                source_range_add_2 = node.sourceRange()
1546*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_custom_pattern_based_rewrite_graph(
1547*da0073e9SAndroid Build Coastguard Worker            input_pattern, replacement_pattern, graph_copy, value_name_pairs=value_mappings)
1548*da0073e9SAndroid Build Coastguard Worker        source_range_my_add_1 = None
1549*da0073e9SAndroid Build Coastguard Worker        for node in graph_copy.nodes():
1550*da0073e9SAndroid Build Coastguard Worker            if source_range_my_add_1 is None and node.kind() == "my::add":
1551*da0073e9SAndroid Build Coastguard Worker                source_range_my_add_1 = node.sourceRange()
1552*da0073e9SAndroid Build Coastguard Worker            if source_range_my_add_1 is not None and node.kind() == "my::add":
1553*da0073e9SAndroid Build Coastguard Worker                source_range_my_add_2 = node.sourceRange()
1554*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(source_range_add_1 == source_range_my_add_1)
1555*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(source_range_add_2 == source_range_my_add_2)
1556*da0073e9SAndroid Build Coastguard Worker
1557*da0073e9SAndroid Build Coastguard Worker        # Check source range preservation for add-add -> double_add transform
1558*da0073e9SAndroid Build Coastguard Worker        # fuse nodes
1559*da0073e9SAndroid Build Coastguard Worker        input_pattern = """
1560*da0073e9SAndroid Build Coastguard Worker        graph(%x, %y, %z, %const):
1561*da0073e9SAndroid Build Coastguard Worker            %t = aten::add(%x, %y, %const)
1562*da0073e9SAndroid Build Coastguard Worker            %o = aten::add(%t, %z, %const)
1563*da0073e9SAndroid Build Coastguard Worker            return (%o)"""
1564*da0073e9SAndroid Build Coastguard Worker        replacement_pattern = """
1565*da0073e9SAndroid Build Coastguard Worker        graph(%x, %y, %z, %const):
1566*da0073e9SAndroid Build Coastguard Worker            %o = my::double_add(%x, %y, %z, %const)
1567*da0073e9SAndroid Build Coastguard Worker            return (%o)"""
1568*da0073e9SAndroid Build Coastguard Worker        scripted_model = torch.jit.script(TestModule2())
1569*da0073e9SAndroid Build Coastguard Worker        graph_copy = scripted_model.graph.copy()
1570*da0073e9SAndroid Build Coastguard Worker        value_mappings = [("o", "t")]
1571*da0073e9SAndroid Build Coastguard Worker        source_range_1 = None
1572*da0073e9SAndroid Build Coastguard Worker        source_range_2 = None
1573*da0073e9SAndroid Build Coastguard Worker        for node in graph_copy.nodes():
1574*da0073e9SAndroid Build Coastguard Worker            if node.kind() == "aten::add":
1575*da0073e9SAndroid Build Coastguard Worker                source_range_1 = node.sourceRange()
1576*da0073e9SAndroid Build Coastguard Worker                break
1577*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_custom_pattern_based_rewrite_graph(
1578*da0073e9SAndroid Build Coastguard Worker            input_pattern, replacement_pattern, graph_copy, value_name_pairs=value_mappings)
1579*da0073e9SAndroid Build Coastguard Worker        for node in graph_copy.nodes():
1580*da0073e9SAndroid Build Coastguard Worker            if node.kind() == "my::double_add":
1581*da0073e9SAndroid Build Coastguard Worker                source_range_2 = node.sourceRange()
1582*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(source_range_1 == source_range_2)
1583*da0073e9SAndroid Build Coastguard Worker
1584*da0073e9SAndroid Build Coastguard Worker        # Check source range preservation for mul -> add + add transform
1585*da0073e9SAndroid Build Coastguard Worker        # split node
1586*da0073e9SAndroid Build Coastguard Worker        input_pattern = """
1587*da0073e9SAndroid Build Coastguard Worker        graph(%x, %y):
1588*da0073e9SAndroid Build Coastguard Worker            %t = aten::mul(%x, %y)
1589*da0073e9SAndroid Build Coastguard Worker            return (%t)"""
1590*da0073e9SAndroid Build Coastguard Worker        replacement_pattern = """
1591*da0073e9SAndroid Build Coastguard Worker        graph(%x, %y):
1592*da0073e9SAndroid Build Coastguard Worker            %t = my::add(%x, %y)
1593*da0073e9SAndroid Build Coastguard Worker            %o = my::add(%t, %y)
1594*da0073e9SAndroid Build Coastguard Worker            return (%o)"""
1595*da0073e9SAndroid Build Coastguard Worker        scripted_model = torch.jit.script(TestModule2())
1596*da0073e9SAndroid Build Coastguard Worker        graph_copy = scripted_model.graph.copy()
1597*da0073e9SAndroid Build Coastguard Worker        value_mappings = [("t", "t"), ("o", "t")]
1598*da0073e9SAndroid Build Coastguard Worker        source_range_mul_1 = None
1599*da0073e9SAndroid Build Coastguard Worker        for node in graph_copy.nodes():
1600*da0073e9SAndroid Build Coastguard Worker            if source_range_mul_1 is None and node.kind() == "aten::mul":
1601*da0073e9SAndroid Build Coastguard Worker                source_range_mul_1 = node.sourceRange()
1602*da0073e9SAndroid Build Coastguard Worker            if source_range_mul_1 is not None and node.kind() == "aten::mul":
1603*da0073e9SAndroid Build Coastguard Worker                source_range_mul_2 = node.sourceRange()
1604*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_custom_pattern_based_rewrite_graph(
1605*da0073e9SAndroid Build Coastguard Worker            input_pattern, replacement_pattern, graph_copy, value_name_pairs=value_mappings)
1606*da0073e9SAndroid Build Coastguard Worker        source_range_add_1 = None
1607*da0073e9SAndroid Build Coastguard Worker        for node in graph_copy.nodes():
1608*da0073e9SAndroid Build Coastguard Worker            if source_range_add_1 is None and node.kind() == "my::add":
1609*da0073e9SAndroid Build Coastguard Worker                source_range_add_1 = node.sourceRange()
1610*da0073e9SAndroid Build Coastguard Worker            if source_range_add_1 is not None and node.kind() == "my::add":
1611*da0073e9SAndroid Build Coastguard Worker                source_range_add_2 = node.sourceRange()
1612*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(source_range_mul_1 == source_range_add_1)
1613*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(source_range_mul_2 == source_range_add_2)
1614*da0073e9SAndroid Build Coastguard Worker
1615*da0073e9SAndroid Build Coastguard Worker        # Check lack of source range preservation for mul-mul-> double_mul transform
1616*da0073e9SAndroid Build Coastguard Worker        input_pattern = """
1617*da0073e9SAndroid Build Coastguard Worker        graph(%x, %y, %z):
1618*da0073e9SAndroid Build Coastguard Worker            %t = aten::mul(%x, %y)
1619*da0073e9SAndroid Build Coastguard Worker            %o = aten::mul(%t, %z)
1620*da0073e9SAndroid Build Coastguard Worker            return (%o)"""
1621*da0073e9SAndroid Build Coastguard Worker        replacement_pattern = """
1622*da0073e9SAndroid Build Coastguard Worker        graph(%x, %y, %z):
1623*da0073e9SAndroid Build Coastguard Worker            %o = my::double_mul(%x, %y, %z)
1624*da0073e9SAndroid Build Coastguard Worker            return (%o)"""
1625*da0073e9SAndroid Build Coastguard Worker        scripted_model = torch.jit.script(TestModule2())
1626*da0073e9SAndroid Build Coastguard Worker        graph_copy = scripted_model.graph.copy()
1627*da0073e9SAndroid Build Coastguard Worker        for node in graph_copy.nodes():
1628*da0073e9SAndroid Build Coastguard Worker            if node.kind() == "aten::mul":
1629*da0073e9SAndroid Build Coastguard Worker                source_range_1 = node.sourceRange()
1630*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_custom_pattern_based_rewrite_graph(input_pattern, replacement_pattern, graph_copy)
1631*da0073e9SAndroid Build Coastguard Worker        for node in graph_copy.nodes():
1632*da0073e9SAndroid Build Coastguard Worker            if node.kind() == "my::double_mul":
1633*da0073e9SAndroid Build Coastguard Worker                source_range_2 = node.sourceRange()
1634*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(source_range_1 == source_range_2)
1635*da0073e9SAndroid Build Coastguard Worker
1636*da0073e9SAndroid Build Coastguard Worker    def test_expand_quantlint(self):
1637*da0073e9SAndroid Build Coastguard Worker        pass
1638*da0073e9SAndroid Build Coastguard Worker
1639*da0073e9SAndroid Build Coastguard Worker    def test_expand_fold_quant_inputs(self):
1640*da0073e9SAndroid Build Coastguard Worker        pass
1641*da0073e9SAndroid Build Coastguard Worker
1642*da0073e9SAndroid Build Coastguard Worker    def test_shape_analysis_broadcast(self):
1643*da0073e9SAndroid Build Coastguard Worker        def broadcast(a, b):
1644*da0073e9SAndroid Build Coastguard Worker            return a + b
1645*da0073e9SAndroid Build Coastguard Worker
1646*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 1, 5, requires_grad=True)
1647*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(4, 1, 8, 5, requires_grad=True)
1648*da0073e9SAndroid Build Coastguard Worker
1649*da0073e9SAndroid Build Coastguard Worker        graph = torch.jit.script(broadcast).graph
1650*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_complete_shape_analysis(graph, (x, y), False)
1651*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("Float(4, 3, 8, 5, strides=[120, 40, 5, 1], device=cpu)").run(str(graph))
1652*da0073e9SAndroid Build Coastguard Worker
1653*da0073e9SAndroid Build Coastguard Worker    def test_shape_analysis_unsqueeze_in_loop(self):
1654*da0073e9SAndroid Build Coastguard Worker        input_str = """graph(%x.1 : Tensor):
1655*da0073e9SAndroid Build Coastguard Worker          %4 : bool = prim::Constant[value=1]()
1656*da0073e9SAndroid Build Coastguard Worker          %1 : int = prim::Constant[value=2]()
1657*da0073e9SAndroid Build Coastguard Worker          %7 : int = prim::Constant[value=0]()
1658*da0073e9SAndroid Build Coastguard Worker          # CHECK: FloatTensor(requires_grad=0, device=cpu) = prim::Loop
1659*da0073e9SAndroid Build Coastguard Worker          %x : Tensor = prim::Loop(%1, %4, %x.1)
1660*da0073e9SAndroid Build Coastguard Worker            # CHECK: : FloatTensor(requires_grad=0, device=cpu)):
1661*da0073e9SAndroid Build Coastguard Worker            block0(%i : int, %x.6 : Tensor):
1662*da0073e9SAndroid Build Coastguard Worker              # CHECK: FloatTensor(requires_grad=0, device=cpu) = aten::unsqueeze
1663*da0073e9SAndroid Build Coastguard Worker              %x.3 : Tensor = aten::unsqueeze(%x.6, %7)
1664*da0073e9SAndroid Build Coastguard Worker              -> (%4, %x.3)
1665*da0073e9SAndroid Build Coastguard Worker          return (%x)"""
1666*da0073e9SAndroid Build Coastguard Worker        graph = parse_ir(input_str)
1667*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_complete_shape_analysis(graph, (torch.zeros(2, 2, dtype=torch.float32),), False)
1668*da0073e9SAndroid Build Coastguard Worker        FileCheck().run(input_str, graph)
1669*da0073e9SAndroid Build Coastguard Worker
1670*da0073e9SAndroid Build Coastguard Worker    def test_script_tensor_type(self):
1671*da0073e9SAndroid Build Coastguard Worker        def foo(x, t: torch.dtype):
1672*da0073e9SAndroid Build Coastguard Worker            return x.type(t)
1673*da0073e9SAndroid Build Coastguard Worker        scr = torch.jit.script(foo)
1674*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
1675*da0073e9SAndroid Build Coastguard Worker        for t in [torch.int8, torch.float64, torch.float32,
1676*da0073e9SAndroid Build Coastguard Worker                  torch.bfloat16, torch.complex64, torch.complex128, torch.bool]:
1677*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(scr(x, t), foo(x, t))
1678*da0073e9SAndroid Build Coastguard Worker
1679*da0073e9SAndroid Build Coastguard Worker    def test_script_bool_literal_conversion(self):
1680*da0073e9SAndroid Build Coastguard Worker        def foo(x):
1681*da0073e9SAndroid Build Coastguard Worker            return torch.mul(x, True)
1682*da0073e9SAndroid Build Coastguard Worker        scr = torch.jit.script(foo)
1683*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
1684*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scr(x), foo(x))
1685*da0073e9SAndroid Build Coastguard Worker
1686*da0073e9SAndroid Build Coastguard Worker    def test_shape_analysis_masked_select(self):
1687*da0073e9SAndroid Build Coastguard Worker        input_str = """graph(%0 : Float(),
1688*da0073e9SAndroid Build Coastguard Worker          %1 : Bool()):
1689*da0073e9SAndroid Build Coastguard Worker          # CHECK: Float(*, requires_grad=0, device=cpu) = aten::masked_select
1690*da0073e9SAndroid Build Coastguard Worker          %2 : Tensor = aten::masked_select(%0, %1) # test/test_jit.py:15261:0
1691*da0073e9SAndroid Build Coastguard Worker          return (%2)"""
1692*da0073e9SAndroid Build Coastguard Worker        graph = parse_ir(input_str)
1693*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(1, dtype=torch.float32)[0]
1694*da0073e9SAndroid Build Coastguard Worker        mask = x.ge(0.5)
1695*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_complete_shape_analysis(graph, (x, mask), False)
1696*da0073e9SAndroid Build Coastguard Worker        FileCheck().run(input_str, graph)
1697*da0073e9SAndroid Build Coastguard Worker
1698*da0073e9SAndroid Build Coastguard Worker    # TODO: update verify to work with GraphExecutors
1699*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("verify needs to be updated to work with GraphExecutors")
1700*da0073e9SAndroid Build Coastguard Worker    def test_verify(self):
1701*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([0.4], requires_grad=True)
1702*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor([0.7], requires_grad=True)
1703*da0073e9SAndroid Build Coastguard Worker
1704*da0073e9SAndroid Build Coastguard Worker        @torch.jit.compile
1705*da0073e9SAndroid Build Coastguard Worker        def f(x, y):
1706*da0073e9SAndroid Build Coastguard Worker            z = torch.sigmoid(x * (x + y))
1707*da0073e9SAndroid Build Coastguard Worker            w = torch.abs(x * x * x + y) + Variable(torch.ones(1))
1708*da0073e9SAndroid Build Coastguard Worker            return z, w
1709*da0073e9SAndroid Build Coastguard Worker
1710*da0073e9SAndroid Build Coastguard Worker        torch.jit.verify(f, (x, y), loss_fn=lambda z, w: z * w, devices=[])
1711*da0073e9SAndroid Build Coastguard Worker
1712*da0073e9SAndroid Build Coastguard Worker    # TODO: adapt to a GraphExecutor test
1713*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("Need to instrument GraphExecutors a bit more")
1714*da0073e9SAndroid Build Coastguard Worker    def test_flags(self):
1715*da0073e9SAndroid Build Coastguard Worker        x, y = torch.randn(2, 2)
1716*da0073e9SAndroid Build Coastguard Worker        y = Variable(torch.randn(2, 2))
1717*da0073e9SAndroid Build Coastguard Worker
1718*da0073e9SAndroid Build Coastguard Worker        @torch.jit.compile
1719*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
1720*da0073e9SAndroid Build Coastguard Worker            return (x * x + y * y + x * y).sum()
1721*da0073e9SAndroid Build Coastguard Worker
1722*da0073e9SAndroid Build Coastguard Worker        grads = {}
1723*da0073e9SAndroid Build Coastguard Worker        for rx, ry in product((True, False), repeat=2):
1724*da0073e9SAndroid Build Coastguard Worker            x.requires_grad = rx
1725*da0073e9SAndroid Build Coastguard Worker            y.requires_grad = ry
1726*da0073e9SAndroid Build Coastguard Worker
1727*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(fn.has_trace_for(x, y))
1728*da0073e9SAndroid Build Coastguard Worker            out = fn(x, y)
1729*da0073e9SAndroid Build Coastguard Worker
1730*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(fn.has_trace_for(x, y))
1731*da0073e9SAndroid Build Coastguard Worker            for v, name, compute in [(x, 'x', rx), (y, 'y', ry)]:
1732*da0073e9SAndroid Build Coastguard Worker                if not compute:
1733*da0073e9SAndroid Build Coastguard Worker                    continue
1734*da0073e9SAndroid Build Coastguard Worker                grad_v, = torch.autograd.grad(out, v, retain_graph=True)
1735*da0073e9SAndroid Build Coastguard Worker                expected_grad = grads.setdefault(name, grad_v)
1736*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(grad_v, expected_grad)
1737*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(fn.has_trace_for(x, y), rx or ry)
1738*da0073e9SAndroid Build Coastguard Worker
1739*da0073e9SAndroid Build Coastguard Worker    def test_python_ir(self):
1740*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([0.4], requires_grad=True)
1741*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor([0.7], requires_grad=True)
1742*da0073e9SAndroid Build Coastguard Worker
1743*da0073e9SAndroid Build Coastguard Worker        def doit(x, y):
1744*da0073e9SAndroid Build Coastguard Worker            return torch.sigmoid(torch.tanh(x * (x + y)))
1745*da0073e9SAndroid Build Coastguard Worker
1746*da0073e9SAndroid Build Coastguard Worker        g, _ = torch.jit._get_trace_graph(doit, (x, y))
1747*da0073e9SAndroid Build Coastguard Worker        self.run_pass('dce', g)
1748*da0073e9SAndroid Build Coastguard Worker        self.run_pass('canonicalize', g)
1749*da0073e9SAndroid Build Coastguard Worker        g2 = torch._C.Graph()
1750*da0073e9SAndroid Build Coastguard Worker        g_to_g2 = {}
1751*da0073e9SAndroid Build Coastguard Worker        for node in g.inputs():
1752*da0073e9SAndroid Build Coastguard Worker            g_to_g2[node] = g2.addInput()
1753*da0073e9SAndroid Build Coastguard Worker        for node in g.nodes():
1754*da0073e9SAndroid Build Coastguard Worker            n_ = g2.createClone(node, lambda x: g_to_g2[x])
1755*da0073e9SAndroid Build Coastguard Worker            g2.appendNode(n_)
1756*da0073e9SAndroid Build Coastguard Worker            for o, no in zip(node.outputs(), n_.outputs()):
1757*da0073e9SAndroid Build Coastguard Worker                g_to_g2[o] = no
1758*da0073e9SAndroid Build Coastguard Worker
1759*da0073e9SAndroid Build Coastguard Worker        for node in g.outputs():
1760*da0073e9SAndroid Build Coastguard Worker            g2.registerOutput(g_to_g2[node])
1761*da0073e9SAndroid Build Coastguard Worker
1762*da0073e9SAndroid Build Coastguard Worker        t_node = g2.create("prim::TensorTest").t_("a", torch.ones([2, 2]))
1763*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t_node.attributeNames(), ["a"])
1764*da0073e9SAndroid Build Coastguard Worker        g2.appendNode(t_node)
1765*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.equal(torch.ones(2, 2), t_node.t("a")))
1766*da0073e9SAndroid Build Coastguard Worker        for node in g.nodes():
1767*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(g2.findNode(node.kind()) is not None)
1768*da0073e9SAndroid Build Coastguard Worker
1769*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_SANDCASTLE, "gtest runs these in sandcastle")
1770*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(RUN_CUDA, "covered by test_cpp_cuda")
1771*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not torch._C._jit_has_cpp_tests(), "Tests were not built, use BUILD_TEST=1")
1772*da0073e9SAndroid Build Coastguard Worker    def test_cpp(self):
1773*da0073e9SAndroid Build Coastguard Worker        from cpp.jit import tests_setup
1774*da0073e9SAndroid Build Coastguard Worker        tests_setup.setup()
1775*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_run_cpp_tests()
1776*da0073e9SAndroid Build Coastguard Worker        tests_setup.shutdown()
1777*da0073e9SAndroid Build Coastguard Worker
1778*da0073e9SAndroid Build Coastguard Worker    def test_batchnorm(self):
1779*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(2, 2, 2, 2)
1780*da0073e9SAndroid Build Coastguard Worker        g, outputs, inputs = torch.jit._get_trace_graph(nn.BatchNorm2d(2), x,
1781*da0073e9SAndroid Build Coastguard Worker                                                        _force_outplace=True, return_inputs=True)
1782*da0073e9SAndroid Build Coastguard Worker        m = self.createFunctionFromGraph(g)
1783*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(outputs, m(*inputs))
1784*da0073e9SAndroid Build Coastguard Worker
1785*da0073e9SAndroid Build Coastguard Worker    def test_dropout(self):
1786*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(2, 2)
1787*da0073e9SAndroid Build Coastguard Worker        with torch.random.fork_rng(devices=[]):
1788*da0073e9SAndroid Build Coastguard Worker            g, outputs, inputs = torch.jit._get_trace_graph(nn.Dropout(0.6), x, return_inputs=True)
1789*da0073e9SAndroid Build Coastguard Worker        with torch.random.fork_rng(devices=[]):
1790*da0073e9SAndroid Build Coastguard Worker            m = self.createFunctionFromGraph(g)
1791*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(outputs, m(*inputs))
1792*da0073e9SAndroid Build Coastguard Worker
1793*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "test requires CUDA")
1794*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "skip if profiling isn't enabled")
1795*da0073e9SAndroid Build Coastguard Worker    def test_native_dropout_corner_case(self):
1796*da0073e9SAndroid Build Coastguard Worker        with disable_autodiff_subgraph_inlining():
1797*da0073e9SAndroid Build Coastguard Worker            def t(x, p: float, t: bool):
1798*da0073e9SAndroid Build Coastguard Worker                o = torch.dropout(x, p, t)
1799*da0073e9SAndroid Build Coastguard Worker                return o
1800*da0073e9SAndroid Build Coastguard Worker
1801*da0073e9SAndroid Build Coastguard Worker            jit_t = torch.jit.script(t)
1802*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(5).requires_grad_()
1803*da0073e9SAndroid Build Coastguard Worker            FileCheck().check("prim::DifferentiableGraph").run(jit_t.graph_for(x, 1.0, True, profile_and_replay=True))
1804*da0073e9SAndroid Build Coastguard Worker
1805*da0073e9SAndroid Build Coastguard Worker            for train in [True, False]:
1806*da0073e9SAndroid Build Coastguard Worker                for p in [0.0, 1.0]:
1807*da0073e9SAndroid Build Coastguard Worker                    for device in ["cuda", "cpu"]:
1808*da0073e9SAndroid Build Coastguard Worker                        x = torch.randn(5).to(device=device).requires_grad_()
1809*da0073e9SAndroid Build Coastguard Worker                        x_ref = x.detach().requires_grad_()
1810*da0073e9SAndroid Build Coastguard Worker                        o = jit_t(x, p, train)
1811*da0073e9SAndroid Build Coastguard Worker                        o_ref = t(x_ref, p, train)
1812*da0073e9SAndroid Build Coastguard Worker                        o.sum().backward()
1813*da0073e9SAndroid Build Coastguard Worker                        o_ref.sum().backward()
1814*da0073e9SAndroid Build Coastguard Worker                        assert o.equal(o_ref)
1815*da0073e9SAndroid Build Coastguard Worker                        assert x.grad.equal(x_ref.grad)
1816*da0073e9SAndroid Build Coastguard Worker
1817*da0073e9SAndroid Build Coastguard Worker    @slowTest
1818*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, 'Testing differentiable graph')
1819*da0073e9SAndroid Build Coastguard Worker    def test_dropout_module_requires_grad(self):
1820*da0073e9SAndroid Build Coastguard Worker        with enable_profiling_mode_for_profiling_tests():
1821*da0073e9SAndroid Build Coastguard Worker            class MyModule(torch.nn.Module):
1822*da0073e9SAndroid Build Coastguard Worker                def __init__(self, M):
1823*da0073e9SAndroid Build Coastguard Worker                    super().__init__()
1824*da0073e9SAndroid Build Coastguard Worker                    self.dropout = torch.nn.Dropout(0.5)
1825*da0073e9SAndroid Build Coastguard Worker                    self.linear = torch.nn.Linear(M, M)
1826*da0073e9SAndroid Build Coastguard Worker
1827*da0073e9SAndroid Build Coastguard Worker                def forward(self, input):
1828*da0073e9SAndroid Build Coastguard Worker                    input = self.dropout(input)
1829*da0073e9SAndroid Build Coastguard Worker                    output = self.linear(input)
1830*da0073e9SAndroid Build Coastguard Worker                    return output
1831*da0073e9SAndroid Build Coastguard Worker
1832*da0073e9SAndroid Build Coastguard Worker            def profile(func, X):
1833*da0073e9SAndroid Build Coastguard Worker                with torch.autograd.profiler.profile() as prof:
1834*da0073e9SAndroid Build Coastguard Worker                    func(X)
1835*da0073e9SAndroid Build Coastguard Worker                return [e.name for e in prof.function_events]
1836*da0073e9SAndroid Build Coastguard Worker
1837*da0073e9SAndroid Build Coastguard Worker            M = 1000
1838*da0073e9SAndroid Build Coastguard Worker            scripted = torch.jit.script(MyModule(M))
1839*da0073e9SAndroid Build Coastguard Worker            # To reduce confusion about expected behaviors:
1840*da0073e9SAndroid Build Coastguard Worker            #   requires_grad controls whether dropout is symbolically differentiated.
1841*da0073e9SAndroid Build Coastguard Worker            #   training controls whether bernoulli_ is called inside symbolic differentiation of dropout.
1842*da0073e9SAndroid Build Coastguard Worker            # * When requires_grad == training, the expected behaviors are obvious.
1843*da0073e9SAndroid Build Coastguard Worker            # * When requires_grad=True and training=False, bernoulli_ might still show up in the graph.
1844*da0073e9SAndroid Build Coastguard Worker            #   But it's in a branch that's not called. That's why we have separate checks for autograd
1845*da0073e9SAndroid Build Coastguard Worker            #   profiler to make sure it's not run.
1846*da0073e9SAndroid Build Coastguard Worker            # * When requires_grad=False and training=True, bernoulli_ must be run since it's the expected
1847*da0073e9SAndroid Build Coastguard Worker            #   behavior for the dropout layer in training mode. It's independent of whether graph requires
1848*da0073e9SAndroid Build Coastguard Worker            #   gradient. In fact bernoulli_ comes from autograd instead of autodiff in this case.
1849*da0073e9SAndroid Build Coastguard Worker            for training in (True, False):
1850*da0073e9SAndroid Build Coastguard Worker                if training:
1851*da0073e9SAndroid Build Coastguard Worker                    scripted.train()
1852*da0073e9SAndroid Build Coastguard Worker                else:
1853*da0073e9SAndroid Build Coastguard Worker                    scripted.eval()
1854*da0073e9SAndroid Build Coastguard Worker                for requires_grad in (True, False):
1855*da0073e9SAndroid Build Coastguard Worker                    X = torch.randn(M, M, requires_grad=requires_grad)
1856*da0073e9SAndroid Build Coastguard Worker                    if requires_grad:
1857*da0073e9SAndroid Build Coastguard Worker                        FileCheck().check("aten::native_dropout").run(scripted.graph_for(X, profile_and_replay=True))
1858*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(training, 'aten::bernoulli_' in profile(scripted, X))
1859*da0073e9SAndroid Build Coastguard Worker
1860*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.SIMPLE, 'Testing differentiable graph')
1861*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Torchdynamo cannot correctly handle profiler.profile calls")
1862*da0073e9SAndroid Build Coastguard Worker    def test_dropout_func_requires_grad(self):
1863*da0073e9SAndroid Build Coastguard Worker        def dropout_training(input):
1864*da0073e9SAndroid Build Coastguard Worker            return F.dropout(input, 0.5, training=True)
1865*da0073e9SAndroid Build Coastguard Worker
1866*da0073e9SAndroid Build Coastguard Worker        def dropout_eval(input):
1867*da0073e9SAndroid Build Coastguard Worker            return F.dropout(input, 0.5, training=False)
1868*da0073e9SAndroid Build Coastguard Worker
1869*da0073e9SAndroid Build Coastguard Worker        def profile(func, X):
1870*da0073e9SAndroid Build Coastguard Worker            with torch.autograd.profiler.profile() as prof:
1871*da0073e9SAndroid Build Coastguard Worker                func(X)
1872*da0073e9SAndroid Build Coastguard Worker            return [e.name for e in prof.function_events]
1873*da0073e9SAndroid Build Coastguard Worker
1874*da0073e9SAndroid Build Coastguard Worker        M = 1000
1875*da0073e9SAndroid Build Coastguard Worker        scripted_training = torch.jit.script(dropout_training)
1876*da0073e9SAndroid Build Coastguard Worker        scripted_eval = torch.jit.script(dropout_eval)
1877*da0073e9SAndroid Build Coastguard Worker        # See comments in test_dropout_module_requires_grad.
1878*da0073e9SAndroid Build Coastguard Worker        with disable_autodiff_subgraph_inlining():
1879*da0073e9SAndroid Build Coastguard Worker            for requires_grad in (True, False):
1880*da0073e9SAndroid Build Coastguard Worker                X = torch.randn(M, M, requires_grad=requires_grad)
1881*da0073e9SAndroid Build Coastguard Worker                if requires_grad:
1882*da0073e9SAndroid Build Coastguard Worker                    FileCheck().check("aten::native_dropout").run(scripted_training.graph_for(X, profile_and_replay=True))
1883*da0073e9SAndroid Build Coastguard Worker                self.assertIn('aten::bernoulli_', profile(scripted_training, X))
1884*da0073e9SAndroid Build Coastguard Worker                self.assertNotIn('aten::bernoulli_', profile(scripted_eval, X))
1885*da0073e9SAndroid Build Coastguard Worker
1886*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "test_dropout_cuda require CUDA")
1887*da0073e9SAndroid Build Coastguard Worker    def test_dropout_cuda(self):
1888*da0073e9SAndroid Build Coastguard Worker        # Dropout AD is dispatched to _fused_dropout in CUDA case,
1889*da0073e9SAndroid Build Coastguard Worker        # which is not included in TestJitGeneratedFunctional
1890*da0073e9SAndroid Build Coastguard Worker        def _zero_rate(t):
1891*da0073e9SAndroid Build Coastguard Worker            return torch.true_divide((t == 0).sum(), t.numel())
1892*da0073e9SAndroid Build Coastguard Worker
1893*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(1000, 1000).cuda().requires_grad_()
1894*da0073e9SAndroid Build Coastguard Worker
1895*da0073e9SAndroid Build Coastguard Worker        with enable_profiling_mode_for_profiling_tests():
1896*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
1897*da0073e9SAndroid Build Coastguard Worker            def func(x):
1898*da0073e9SAndroid Build Coastguard Worker                return torch.nn.functional.dropout(x)
1899*da0073e9SAndroid Build Coastguard Worker
1900*da0073e9SAndroid Build Coastguard Worker            with freeze_rng_state():
1901*da0073e9SAndroid Build Coastguard Worker                out_ref = torch.nn.functional.dropout(x)
1902*da0073e9SAndroid Build Coastguard Worker                grad_ref = torch.autograd.grad(out_ref.sum(), x)
1903*da0073e9SAndroid Build Coastguard Worker
1904*da0073e9SAndroid Build Coastguard Worker            with freeze_rng_state():
1905*da0073e9SAndroid Build Coastguard Worker                out = func(x)
1906*da0073e9SAndroid Build Coastguard Worker                grad = torch.autograd.grad(out.sum(), x)
1907*da0073e9SAndroid Build Coastguard Worker
1908*da0073e9SAndroid Build Coastguard Worker            # TODO(#40882): previously we assert exact matches between eager and JIT result:
1909*da0073e9SAndroid Build Coastguard Worker            #  self.assertEqual(out, out_ref)
1910*da0073e9SAndroid Build Coastguard Worker            #  self.assertEqual(grad, grad_ref)
1911*da0073e9SAndroid Build Coastguard Worker            # This test was disabled during legacy -> profiling executor transition.
1912*da0073e9SAndroid Build Coastguard Worker            # Currently JIT fused results doesn't match eager result exactly due to some changes merged in between.
1913*da0073e9SAndroid Build Coastguard Worker            # We temporarily only check statstical difference but it should be reverted once the issue is fixed.
1914*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(_zero_rate(out), _zero_rate(out_ref), rtol=1e-3, atol=1e-4)
1915*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(_zero_rate(grad[0]), _zero_rate(grad_ref[0]), rtol=1e-3, atol=1e-4)
1916*da0073e9SAndroid Build Coastguard Worker
1917*da0073e9SAndroid Build Coastguard Worker    def test_torch_ops_overloaded(self):
1918*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "failed to match any schema"):
1919*da0073e9SAndroid Build Coastguard Worker            torch.ops.aten.add("a", 1)
1920*da0073e9SAndroid Build Coastguard Worker        self.assertEqual("ab", torch.ops.aten.add("a", "b"))
1921*da0073e9SAndroid Build Coastguard Worker        a, b = torch.rand(3, 4), torch.rand(3, 4)
1922*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a + b, torch.ops.aten.add(a, b))
1923*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a + 1, torch.ops.aten.add(a, 1))
1924*da0073e9SAndroid Build Coastguard Worker
1925*da0073e9SAndroid Build Coastguard Worker    def test_torch_ops_kwonly(self):
1926*da0073e9SAndroid Build Coastguard Worker        a, b = torch.rand(3, 4), torch.rand(3, 4)
1927*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "positional argument"):
1928*da0073e9SAndroid Build Coastguard Worker            torch.ops.aten.add(a, b, 2)
1929*da0073e9SAndroid Build Coastguard Worker        # h/t Chillee for this ambiguous case
1930*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.prod(1), torch.ops.aten.prod(a, 1))
1931*da0073e9SAndroid Build Coastguard Worker
1932*da0073e9SAndroid Build Coastguard Worker    def test_torch_complex(self):
1933*da0073e9SAndroid Build Coastguard Worker        def fn(real, img):
1934*da0073e9SAndroid Build Coastguard Worker            return torch.complex(real, img)
1935*da0073e9SAndroid Build Coastguard Worker
1936*da0073e9SAndroid Build Coastguard Worker        def fn_out(real, img, out):
1937*da0073e9SAndroid Build Coastguard Worker            return torch.complex(real, img, out=out)
1938*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (torch.rand(3, 4), torch.rand(3, 4), ))
1939*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (torch.ones(5, 1, 4), torch.ones(5, 1, 4), ))
1940*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (torch.zeros(1, 6), torch.ones(6, 1), ))
1941*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (torch.zeros(1, 6), torch.zeros(6, 1), ))
1942*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (torch.empty(3, 4), torch.empty(3, 4), ))
1943*da0073e9SAndroid Build Coastguard Worker
1944*da0073e9SAndroid Build Coastguard Worker        real = torch.tensor([1, 2], dtype=torch.float32)
1945*da0073e9SAndroid Build Coastguard Worker        img = torch.tensor([3, 4], dtype=torch.float32)
1946*da0073e9SAndroid Build Coastguard Worker        out = torch.empty([3, 4], dtype=torch.complex64)
1947*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn_out, (real, img, out, ))
1948*da0073e9SAndroid Build Coastguard Worker
1949*da0073e9SAndroid Build Coastguard Worker        real = torch.tensor([5, 2], dtype=torch.float64)
1950*da0073e9SAndroid Build Coastguard Worker        img = torch.tensor([3, 4], dtype=torch.float64)
1951*da0073e9SAndroid Build Coastguard Worker        out = torch.empty([5, 2], dtype=torch.complex128)
1952*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn_out, (real, img, out, ))
1953*da0073e9SAndroid Build Coastguard Worker
1954*da0073e9SAndroid Build Coastguard Worker        real = torch.ones([1, 2])
1955*da0073e9SAndroid Build Coastguard Worker        img = torch.ones([1, 2])
1956*da0073e9SAndroid Build Coastguard Worker        out = torch.empty([1, 2], dtype=torch.complex64)
1957*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn_out, (real, img, out, ))
1958*da0073e9SAndroid Build Coastguard Worker
1959*da0073e9SAndroid Build Coastguard Worker        real = torch.ones([3, 8, 7])
1960*da0073e9SAndroid Build Coastguard Worker        img = torch.ones([3, 8, 7])
1961*da0073e9SAndroid Build Coastguard Worker        out = torch.empty([3, 8, 7], dtype=torch.complex64)
1962*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn_out, (real, img, out, ))
1963*da0073e9SAndroid Build Coastguard Worker
1964*da0073e9SAndroid Build Coastguard Worker        real = torch.empty([3, 2, 6])
1965*da0073e9SAndroid Build Coastguard Worker        img = torch.empty([3, 2, 6])
1966*da0073e9SAndroid Build Coastguard Worker        out = torch.empty([3, 2, 6], dtype=torch.complex64)
1967*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn_out, (real, img, out, ))
1968*da0073e9SAndroid Build Coastguard Worker
1969*da0073e9SAndroid Build Coastguard Worker        real = torch.zeros([1, 3])
1970*da0073e9SAndroid Build Coastguard Worker        img = torch.empty([3, 1])
1971*da0073e9SAndroid Build Coastguard Worker        out = torch.empty([3, 3], dtype=torch.complex64)
1972*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn_out, (real, img, out, ))
1973*da0073e9SAndroid Build Coastguard Worker
1974*da0073e9SAndroid Build Coastguard Worker        real = torch.ones([2, 5])
1975*da0073e9SAndroid Build Coastguard Worker        img = torch.empty([2, 1])
1976*da0073e9SAndroid Build Coastguard Worker        out = torch.empty([2, 5], dtype=torch.complex64)
1977*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn_out, (real, img, out, ))
1978*da0073e9SAndroid Build Coastguard Worker
1979*da0073e9SAndroid Build Coastguard Worker        real = torch.ones([2, 5])
1980*da0073e9SAndroid Build Coastguard Worker        img = torch.zeros([2, 1])
1981*da0073e9SAndroid Build Coastguard Worker        out = torch.empty([2, 5], dtype=torch.complex64)
1982*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn_out, (real, img, out, ))
1983*da0073e9SAndroid Build Coastguard Worker
1984*da0073e9SAndroid Build Coastguard Worker    def test_einsum(self):
1985*da0073e9SAndroid Build Coastguard Worker        def check(fn, jitted, *args):
1986*da0073e9SAndroid Build Coastguard Worker            self.assertGraphContains(jitted.graph, kind='aten::einsum')
1987*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(fn(*args), jitted(*args))
1988*da0073e9SAndroid Build Coastguard Worker
1989*da0073e9SAndroid Build Coastguard Worker        def equation_format(x, y):
1990*da0073e9SAndroid Build Coastguard Worker            return torch.einsum('i,j->ij', (x, y))
1991*da0073e9SAndroid Build Coastguard Worker
1992*da0073e9SAndroid Build Coastguard Worker        def equation_format_varargs(x, y):
1993*da0073e9SAndroid Build Coastguard Worker            return torch.einsum('i,j->ij', x, y)
1994*da0073e9SAndroid Build Coastguard Worker
1995*da0073e9SAndroid Build Coastguard Worker        def sublist_format(x, y):
1996*da0073e9SAndroid Build Coastguard Worker            return torch.einsum(x, [0], y, [1], [0, 1])
1997*da0073e9SAndroid Build Coastguard Worker
1998*da0073e9SAndroid Build Coastguard Worker        x = make_tensor((5,), dtype=torch.float32, device="cpu")
1999*da0073e9SAndroid Build Coastguard Worker        y = make_tensor((10,), dtype=torch.float32, device="cpu")
2000*da0073e9SAndroid Build Coastguard Worker
2001*da0073e9SAndroid Build Coastguard Worker        for fn in [equation_format, equation_format_varargs, sublist_format]:
2002*da0073e9SAndroid Build Coastguard Worker            check(fn, torch.jit.script(fn), x, y)
2003*da0073e9SAndroid Build Coastguard Worker            check(fn, torch.jit.trace(fn, (x, y)), x, y)
2004*da0073e9SAndroid Build Coastguard Worker
    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
    def test_python_ivalue(self):
        # Test that a pure Python object can be held as an IValue and that the
        # IValue <-> PyObject conversions round-trip correctly.
        # Case 1: a numpy array comes back equal to the original.
        py_array = np.arange(15)
        ret_py_obj = torch._C._ivalue_debug_python_object(py_array)
        self.assertEqual(py_array, ret_py_obj)

        # Case 2: a Python function object (F.relu) comes back equal to itself.
        ret_py_obj = torch._C._ivalue_debug_python_object(F.relu)
        self.assertEqual(F.relu, ret_py_obj)

        # Case 3: memory management. The IValue must incref/decref correctly;
        # otherwise conversions leave dangling references or leak memory.
        def test_func_scope_helper(inp):
            # Run the PyObject -> IValue -> PyObject conversion inside its own
            # scope, so all of its local references are dropped on return.
            inp_refcount = sys.getrefcount(inp)
            ivalue_holder = torch._C._ivalue_debug_python_object(inp)
            # Expect exactly one more reference than `inp` had before the call.
            # Presumably the debug helper returns the identical object, so the
            # extra reference is the `ivalue_holder` binding. TODO confirm.
            self.assertEqual(inp_refcount + 1, sys.getrefcount(ivalue_holder))
            # NOTE(review): the caller below discards this return value.
            return ivalue_holder + 1

        # NOTE(review): 2200 is presumably chosen to lie outside CPython's
        # cached small-int range, so its refcount is not inflated by unrelated
        # users of the same object. Confirm.
        test_input = 2200
        before_count = sys.getrefcount(test_input)
        test_func_scope_helper(test_input)
        after_count = sys.getrefcount(test_input)

        # After test_func_scope_helper returns, the refcount of test_input
        # must be back to its original value; otherwise the conversion left
        # a dangling pointer or leaked a reference.
        self.assertEqual(before_count, after_count)
2038*da0073e9SAndroid Build Coastguard Worker
2039*da0073e9SAndroid Build Coastguard Worker    def test_decompose_addmm(self):
2040*da0073e9SAndroid Build Coastguard Worker        def does_decompose():
2041*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
2042*da0073e9SAndroid Build Coastguard Worker            def addmm(mat, mat1, mat2):
2043*da0073e9SAndroid Build Coastguard Worker                a = mat.addmm(mat1, mat2)
2044*da0073e9SAndroid Build Coastguard Worker                b = mat.addmm(mat1, mat2, alpha=1.0, beta=1.0)
2045*da0073e9SAndroid Build Coastguard Worker                return a + b
2046*da0073e9SAndroid Build Coastguard Worker
2047*da0073e9SAndroid Build Coastguard Worker            mat = torch.randn(2, 2)
2048*da0073e9SAndroid Build Coastguard Worker            mat1 = torch.randn(2, 4)
2049*da0073e9SAndroid Build Coastguard Worker            mat2 = torch.randn(4, 2)
2050*da0073e9SAndroid Build Coastguard Worker
2051*da0073e9SAndroid Build Coastguard Worker            out_ref = addmm(mat, mat1, mat2)
2052*da0073e9SAndroid Build Coastguard Worker            self.run_pass('decompose_ops', addmm.graph)
2053*da0073e9SAndroid Build Coastguard Worker            out_test = addmm(mat, mat1, mat2)
2054*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out_ref, out_test)
2055*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_not("addmm").run(str(addmm.graph))
2056*da0073e9SAndroid Build Coastguard Worker
2057*da0073e9SAndroid Build Coastguard Worker        def doesnt_decompose():
2058*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
2059*da0073e9SAndroid Build Coastguard Worker            def addmm(mat, mat1, mat2, alpha, beta):
2060*da0073e9SAndroid Build Coastguard Worker                a = mat.addmm(mat1, mat2, alpha=4.20, beta=2.0)
2061*da0073e9SAndroid Build Coastguard Worker                b = mat.addmm(mat1, mat2, alpha=int(alpha), beta=int(beta))
2062*da0073e9SAndroid Build Coastguard Worker
2063*da0073e9SAndroid Build Coastguard Worker                return a + b
2064*da0073e9SAndroid Build Coastguard Worker
2065*da0073e9SAndroid Build Coastguard Worker            orig = str(addmm.graph)
2066*da0073e9SAndroid Build Coastguard Worker            self.run_pass('decompose_ops', addmm.graph)
2067*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(orig == str(addmm.graph))
2068*da0073e9SAndroid Build Coastguard Worker
2069*da0073e9SAndroid Build Coastguard Worker        does_decompose()
2070*da0073e9SAndroid Build Coastguard Worker        doesnt_decompose()
2071*da0073e9SAndroid Build Coastguard Worker
2072*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
2073*da0073e9SAndroid Build Coastguard Worker    def test_sparse_tensors(self):
2074*da0073e9SAndroid Build Coastguard Worker        @torch.jit.ignore
2075*da0073e9SAndroid Build Coastguard Worker        def get_sparse():
2076*da0073e9SAndroid Build Coastguard Worker            return torch.sparse_coo_tensor((2, 3), dtype=torch.float32)
2077*da0073e9SAndroid Build Coastguard Worker
2078*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2079*da0073e9SAndroid Build Coastguard Worker        def test_is_sparse(input):
2080*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor) -> bool
2081*da0073e9SAndroid Build Coastguard Worker            return input.is_sparse
2082*da0073e9SAndroid Build Coastguard Worker
2083*da0073e9SAndroid Build Coastguard Worker        script_out_is_sparse = test_is_sparse(get_sparse())
2084*da0073e9SAndroid Build Coastguard Worker        script_out_is_dense = test_is_sparse(torch.randn(2, 3))
2085*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(script_out_is_sparse, True)
2086*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(script_out_is_dense, False)
2087*da0073e9SAndroid Build Coastguard Worker
2088*da0073e9SAndroid Build Coastguard Worker        def test_basic_sparse(input):
2089*da0073e9SAndroid Build Coastguard Worker            output = get_sparse()
2090*da0073e9SAndroid Build Coastguard Worker            return output, input
2091*da0073e9SAndroid Build Coastguard Worker
2092*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_basic_sparse, (get_sparse(),))
2093*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_basic_sparse, (torch.tensor([1]),))
2094*da0073e9SAndroid Build Coastguard Worker
2095*da0073e9SAndroid Build Coastguard Worker        def test_sparse_sum(input):
2096*da0073e9SAndroid Build Coastguard Worker            return torch.sparse.sum(input)
2097*da0073e9SAndroid Build Coastguard Worker
2098*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_sparse_sum, (get_sparse(),))
2099*da0073e9SAndroid Build Coastguard Worker
2100*da0073e9SAndroid Build Coastguard Worker        def test_sparse_mm(input1, input2):
2101*da0073e9SAndroid Build Coastguard Worker            return torch.sparse.mm(input1, input2)
2102*da0073e9SAndroid Build Coastguard Worker
2103*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_sparse_mm, (get_sparse(), torch.randn(3, 4)))
2104*da0073e9SAndroid Build Coastguard Worker
2105*da0073e9SAndroid Build Coastguard Worker        def test_sparse_addmm(input, input1, input2):
2106*da0073e9SAndroid Build Coastguard Worker            return torch.sparse.addmm(input, input1, input2)
2107*da0073e9SAndroid Build Coastguard Worker
2108*da0073e9SAndroid Build Coastguard Worker        def test_sparse_addmm_alpha_beta(input, input1, input2):
2109*da0073e9SAndroid Build Coastguard Worker            return torch.sparse.addmm(input, input1, input2, alpha=1.3, beta=1.5)
2110*da0073e9SAndroid Build Coastguard Worker
2111*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_sparse_addmm, (torch.randn(2, 4), get_sparse(), torch.randn(3, 4)))
2112*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_sparse_addmm_alpha_beta, (torch.randn(2, 4), get_sparse(), torch.randn(3, 4)))
2113*da0073e9SAndroid Build Coastguard Worker
2114*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
2115*da0073e9SAndroid Build Coastguard Worker    def test_sparse_csr_tensors(self):
2116*da0073e9SAndroid Build Coastguard Worker        @torch.jit.ignore
2117*da0073e9SAndroid Build Coastguard Worker        def get_sparse_csr():
2118*da0073e9SAndroid Build Coastguard Worker            return torch.randn(3, 3).to_sparse_csr()
2119*da0073e9SAndroid Build Coastguard Worker
2120*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2121*da0073e9SAndroid Build Coastguard Worker        def test_is_sparse_csr(input):
2122*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor) -> bool
2123*da0073e9SAndroid Build Coastguard Worker            return input.is_sparse_csr
2124*da0073e9SAndroid Build Coastguard Worker
2125*da0073e9SAndroid Build Coastguard Worker        script_out_is_sparse_csr = test_is_sparse_csr(get_sparse_csr())
2126*da0073e9SAndroid Build Coastguard Worker        script_out_is_dense_csr = test_is_sparse_csr(torch.randn(3, 3))
2127*da0073e9SAndroid Build Coastguard Worker
2128*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(script_out_is_sparse_csr, True)
2129*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(script_out_is_dense_csr, False)
2130*da0073e9SAndroid Build Coastguard Worker
2131*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
2132*da0073e9SAndroid Build Coastguard Worker    def test_device_not_equal(self):
2133*da0073e9SAndroid Build Coastguard Worker
2134*da0073e9SAndroid Build Coastguard Worker        def compare_device(x: torch.device):
2135*da0073e9SAndroid Build Coastguard Worker            return x != torch.device("cuda:0")
2136*da0073e9SAndroid Build Coastguard Worker
2137*da0073e9SAndroid Build Coastguard Worker        def compare_two_device(x: torch.device, y: torch.device):
2138*da0073e9SAndroid Build Coastguard Worker            return x != y
2139*da0073e9SAndroid Build Coastguard Worker
2140*da0073e9SAndroid Build Coastguard Worker        self.checkScript(compare_device, (torch.device("cuda:0"),))
2141*da0073e9SAndroid Build Coastguard Worker        self.checkScript(compare_two_device, (torch.device("cuda:0"), torch.device("cuda:1"), ))
2142*da0073e9SAndroid Build Coastguard Worker
2143*da0073e9SAndroid Build Coastguard Worker    def test_constant_prop_simple(self):
2144*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2145*da0073e9SAndroid Build Coastguard Worker        def constant_prop(input_int):
2146*da0073e9SAndroid Build Coastguard Worker            # type: (int) -> int
2147*da0073e9SAndroid Build Coastguard Worker            a = 2 * 3
2148*da0073e9SAndroid Build Coastguard Worker            b = a + 2
2149*da0073e9SAndroid Build Coastguard Worker            return b - input_int
2150*da0073e9SAndroid Build Coastguard Worker
2151*da0073e9SAndroid Build Coastguard Worker        out_ref = constant_prop(2)
2152*da0073e9SAndroid Build Coastguard Worker        self.run_pass('constant_propagation', constant_prop.graph)
2153*da0073e9SAndroid Build Coastguard Worker        out_test = constant_prop(2)
2154*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ref, out_test)
2155*da0073e9SAndroid Build Coastguard Worker        graph_str = str(constant_prop.graph)
2156*da0073e9SAndroid Build Coastguard Worker        self.assertTrue("aten::add" not in graph_str and "aten::mul" not in graph_str)
2157*da0073e9SAndroid Build Coastguard Worker        const = constant_prop.graph.findNode("prim::Constant").output().toIValue()
2158*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(const, 8)
2159*da0073e9SAndroid Build Coastguard Worker
2160*da0073e9SAndroid Build Coastguard Worker    def test_constant_prop_nested(self):
2161*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2162*da0073e9SAndroid Build Coastguard Worker        def constant_prop(a):
2163*da0073e9SAndroid Build Coastguard Worker            b = 2 + 1
2164*da0073e9SAndroid Build Coastguard Worker            if bool(a < 2):
2165*da0073e9SAndroid Build Coastguard Worker                c = b + 2
2166*da0073e9SAndroid Build Coastguard Worker            else:
2167*da0073e9SAndroid Build Coastguard Worker                c = b - 2
2168*da0073e9SAndroid Build Coastguard Worker            return c
2169*da0073e9SAndroid Build Coastguard Worker        out_ref = constant_prop(torch.tensor(2))
2170*da0073e9SAndroid Build Coastguard Worker        self.run_pass('constant_propagation', constant_prop.graph)
2171*da0073e9SAndroid Build Coastguard Worker        out_test = constant_prop(torch.tensor(2))
2172*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ref, out_test)
2173*da0073e9SAndroid Build Coastguard Worker        if_node = constant_prop.graph.findNode("prim::If")
2174*da0073e9SAndroid Build Coastguard Worker        for block in if_node.blocks():
2175*da0073e9SAndroid Build Coastguard Worker            for node in block.nodes():
2176*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(node.kind() == "prim::Constant")
2177*da0073e9SAndroid Build Coastguard Worker
2178*da0073e9SAndroid Build Coastguard Worker    def test_constant_prop_print(self):
2179*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2180*da0073e9SAndroid Build Coastguard Worker        def constant_prop(input_tensor):
2181*da0073e9SAndroid Build Coastguard Worker            a = 2 * 3
2182*da0073e9SAndroid Build Coastguard Worker            print(a)
2183*da0073e9SAndroid Build Coastguard Worker            b = a + 2
2184*da0073e9SAndroid Build Coastguard Worker            return b + input_tensor
2185*da0073e9SAndroid Build Coastguard Worker
2186*da0073e9SAndroid Build Coastguard Worker        self.run_pass('constant_propagation', constant_prop.graph)
2187*da0073e9SAndroid Build Coastguard Worker        graph = constant_prop.graph
2188*da0073e9SAndroid Build Coastguard Worker        print_node = graph.findNode("prim::Print")
2189*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(print_node.input().toIValue() == 6)
2190*da0073e9SAndroid Build Coastguard Worker
2191*da0073e9SAndroid Build Coastguard Worker    def test_constant_prop_rand(self):
2192*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2193*da0073e9SAndroid Build Coastguard Worker        def constant_prop():
2194*da0073e9SAndroid Build Coastguard Worker            a = torch.randn([3])
2195*da0073e9SAndroid Build Coastguard Worker            b = a + 2
2196*da0073e9SAndroid Build Coastguard Worker            return b
2197*da0073e9SAndroid Build Coastguard Worker
2198*da0073e9SAndroid Build Coastguard Worker        self.run_pass('constant_propagation', constant_prop.graph)
2199*da0073e9SAndroid Build Coastguard Worker        self.assertTrue("aten::randn" in str(constant_prop.graph))
2200*da0073e9SAndroid Build Coastguard Worker
2201*da0073e9SAndroid Build Coastguard Worker    def test_constant_prop_none(self):
2202*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2203*da0073e9SAndroid Build Coastguard Worker        def typed_none():
2204*da0073e9SAndroid Build Coastguard Worker            # type: () -> Optional[int]
2205*da0073e9SAndroid Build Coastguard Worker            return None
2206*da0073e9SAndroid Build Coastguard Worker
2207*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2208*da0073e9SAndroid Build Coastguard Worker        def constant_prop():
2209*da0073e9SAndroid Build Coastguard Worker            a = typed_none()
2210*da0073e9SAndroid Build Coastguard Worker            b = typed_none()
2211*da0073e9SAndroid Build Coastguard Worker            if (a is None and b is None):
2212*da0073e9SAndroid Build Coastguard Worker                a = 2
2213*da0073e9SAndroid Build Coastguard Worker            else:
2214*da0073e9SAndroid Build Coastguard Worker                a = 1
2215*da0073e9SAndroid Build Coastguard Worker            return a
2216*da0073e9SAndroid Build Coastguard Worker
2217*da0073e9SAndroid Build Coastguard Worker        self.run_pass('constant_propagation', constant_prop.graph)
2218*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("prim::Constant").run(constant_prop.graph)
2219*da0073e9SAndroid Build Coastguard Worker
2220*da0073e9SAndroid Build Coastguard Worker    def test_constant_prop_if_inline(self):
2221*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2222*da0073e9SAndroid Build Coastguard Worker        def constant_prop():
2223*da0073e9SAndroid Build Coastguard Worker            cond = True
2224*da0073e9SAndroid Build Coastguard Worker            a = 1
2225*da0073e9SAndroid Build Coastguard Worker            if cond:
2226*da0073e9SAndroid Build Coastguard Worker                a = 1 * 2
2227*da0073e9SAndroid Build Coastguard Worker            else:
2228*da0073e9SAndroid Build Coastguard Worker                a = 1 // 0
2229*da0073e9SAndroid Build Coastguard Worker            return a
2230*da0073e9SAndroid Build Coastguard Worker
2231*da0073e9SAndroid Build Coastguard Worker        # testing that 1 // 0 error is not thrownn
2232*da0073e9SAndroid Build Coastguard Worker        self.run_pass('constant_propagation', constant_prop.graph)
2233*da0073e9SAndroid Build Coastguard Worker
2234*da0073e9SAndroid Build Coastguard Worker    def test_constant_prop_exception(self):
2235*da0073e9SAndroid Build Coastguard Worker        # checking y = a[4] does not error in constant propagation
2236*da0073e9SAndroid Build Coastguard Worker        def bad_index(x):
2237*da0073e9SAndroid Build Coastguard Worker            # type: (bool)
2238*da0073e9SAndroid Build Coastguard Worker            y = 0
2239*da0073e9SAndroid Build Coastguard Worker            if x:
2240*da0073e9SAndroid Build Coastguard Worker                a = [1, 2, 3]
2241*da0073e9SAndroid Build Coastguard Worker                y = a[4]
2242*da0073e9SAndroid Build Coastguard Worker            return y
2243*da0073e9SAndroid Build Coastguard Worker
2244*da0073e9SAndroid Build Coastguard Worker        self.checkScript(bad_index, (False,))
2245*da0073e9SAndroid Build Coastguard Worker
2246*da0073e9SAndroid Build Coastguard Worker    def test_constant_prop_aliasing_type(self):
2247*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2248*da0073e9SAndroid Build Coastguard Worker        def foo():
2249*da0073e9SAndroid Build Coastguard Worker            return len([1]), len(torch.tensor([2]))
2250*da0073e9SAndroid Build Coastguard Worker
2251*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_dag("aten::tensor").check_dag("aten::len").run(foo.graph)
2252*da0073e9SAndroid Build Coastguard Worker
2253*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2254*da0073e9SAndroid Build Coastguard Worker        def fn():
2255*da0073e9SAndroid Build Coastguard Worker            if 1 == 1:
2256*da0073e9SAndroid Build Coastguard Worker                return 1
2257*da0073e9SAndroid Build Coastguard Worker            else:
2258*da0073e9SAndroid Build Coastguard Worker                return 2
2259*da0073e9SAndroid Build Coastguard Worker
2260*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("prim::If").run(fn.graph)
2261*da0073e9SAndroid Build Coastguard Worker
2262*da0073e9SAndroid Build Coastguard Worker    def test_unchecked_cast(self):
2263*da0073e9SAndroid Build Coastguard Worker        def test(cond):
2264*da0073e9SAndroid Build Coastguard Worker            # type: (bool)
2265*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor([10])
2266*da0073e9SAndroid Build Coastguard Worker            if cond:
2267*da0073e9SAndroid Build Coastguard Worker                b = None
2268*da0073e9SAndroid Build Coastguard Worker            else:
2269*da0073e9SAndroid Build Coastguard Worker                b = a
2270*da0073e9SAndroid Build Coastguard Worker            if b is not None:
2271*da0073e9SAndroid Build Coastguard Worker                b[0] = 5
2272*da0073e9SAndroid Build Coastguard Worker            return a.int()
2273*da0073e9SAndroid Build Coastguard Worker
2274*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test, (True,))
2275*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test, (False,))
2276*da0073e9SAndroid Build Coastguard Worker
2277*da0073e9SAndroid Build Coastguard Worker    def test_constant_prop_if_constant(self):
2278*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2279*da0073e9SAndroid Build Coastguard Worker        def constant_prop(a, b):
2280*da0073e9SAndroid Build Coastguard Worker            c0 = 1
2281*da0073e9SAndroid Build Coastguard Worker            c1 = 1
2282*da0073e9SAndroid Build Coastguard Worker            c2 = 1
2283*da0073e9SAndroid Build Coastguard Worker            if bool(a):  # -> c0, c1
2284*da0073e9SAndroid Build Coastguard Worker                if bool(b):  # -> c0
2285*da0073e9SAndroid Build Coastguard Worker                    if 1 == 1:  # -> c0
2286*da0073e9SAndroid Build Coastguard Worker                        c0 = c0 + 1
2287*da0073e9SAndroid Build Coastguard Worker                        if 1 == 2:
2288*da0073e9SAndroid Build Coastguard Worker                            c1 = c1 + 1
2289*da0073e9SAndroid Build Coastguard Worker                            c2 = c2 + 1
2290*da0073e9SAndroid Build Coastguard Worker            else:  # -> c0, c1
2291*da0073e9SAndroid Build Coastguard Worker                c1 = c1 + 1
2292*da0073e9SAndroid Build Coastguard Worker
2293*da0073e9SAndroid Build Coastguard Worker            if 1 == 1:  # inlined
2294*da0073e9SAndroid Build Coastguard Worker                c0 = c0 + 1  # dynamic
2295*da0073e9SAndroid Build Coastguard Worker                c2 = c2 + 4  # set to 5
2296*da0073e9SAndroid Build Coastguard Worker            return a + c0 + c1 + c2
2297*da0073e9SAndroid Build Coastguard Worker
2298*da0073e9SAndroid Build Coastguard Worker        graph = constant_prop.graph
2299*da0073e9SAndroid Build Coastguard Worker        self.run_pass('constant_propagation', graph)
2300*da0073e9SAndroid Build Coastguard Worker        ifs = graph.findAllNodes("prim::If", recurse=False)
2301*da0073e9SAndroid Build Coastguard Worker        snd_if_inlined = len(ifs) == 1
2302*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(snd_if_inlined)
2303*da0073e9SAndroid Build Coastguard Worker        first_if = ifs[0]
2304*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(first_if.outputsSize() == 2)
2305*da0073e9SAndroid Build Coastguard Worker        second_if = first_if.findNode("prim::If", recurse=False)
2306*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(second_if.outputsSize() == 1)
2307*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(second_if.findNode("prim::If") is None)
2308*da0073e9SAndroid Build Coastguard Worker
2309*da0073e9SAndroid Build Coastguard Worker    def test_constant_prop_loop_constant(self):
2310*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2311*da0073e9SAndroid Build Coastguard Worker        def constant_prop(cond, iter):
2312*da0073e9SAndroid Build Coastguard Worker            # type: (bool, int) -> int
2313*da0073e9SAndroid Build Coastguard Worker            b = 0
2314*da0073e9SAndroid Build Coastguard Worker            while True:
2315*da0073e9SAndroid Build Coastguard Worker                print("stays")
2316*da0073e9SAndroid Build Coastguard Worker            for _ in range(2):
2317*da0073e9SAndroid Build Coastguard Worker                print("stays")
2318*da0073e9SAndroid Build Coastguard Worker            for _ in range(iter):
2319*da0073e9SAndroid Build Coastguard Worker                print("stays")
2320*da0073e9SAndroid Build Coastguard Worker            while cond:
2321*da0073e9SAndroid Build Coastguard Worker                print("stays")
2322*da0073e9SAndroid Build Coastguard Worker            while False:
2323*da0073e9SAndroid Build Coastguard Worker                print("removed")
2324*da0073e9SAndroid Build Coastguard Worker            for _i in range(0):
2325*da0073e9SAndroid Build Coastguard Worker                print("removed")
2326*da0073e9SAndroid Build Coastguard Worker            for _i in range(-4):
2327*da0073e9SAndroid Build Coastguard Worker                print("removed")
2328*da0073e9SAndroid Build Coastguard Worker            return b
2329*da0073e9SAndroid Build Coastguard Worker
2330*da0073e9SAndroid Build Coastguard Worker        self.run_pass('constant_propagation', constant_prop.graph)
2331*da0073e9SAndroid Build Coastguard Worker        graph = canonical(constant_prop.graph)
2332*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(graph.count("removed") == 0)
2333*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(graph.count("stays") == 1)  # constant gets pooled
2334*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(graph.count("prim::Print") == 4)
2335*da0073e9SAndroid Build Coastguard Worker
2336*da0073e9SAndroid Build Coastguard Worker    def test_constant_prop_remove_output(self):
2337*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2338*da0073e9SAndroid Build Coastguard Worker        def constant_prop(iter):
2339*da0073e9SAndroid Build Coastguard Worker            # type: (int) -> None
2340*da0073e9SAndroid Build Coastguard Worker            a = 1
2341*da0073e9SAndroid Build Coastguard Worker            b = 1
2342*da0073e9SAndroid Build Coastguard Worker            c = 1
2343*da0073e9SAndroid Build Coastguard Worker            for i in range(iter):
2344*da0073e9SAndroid Build Coastguard Worker                if 1 == 2:
2345*da0073e9SAndroid Build Coastguard Worker                    a = 10
2346*da0073e9SAndroid Build Coastguard Worker                if i == 5:
2347*da0073e9SAndroid Build Coastguard Worker                    b = 2
2348*da0073e9SAndroid Build Coastguard Worker                    c = 3
2349*da0073e9SAndroid Build Coastguard Worker            print(a, b, c)
2350*da0073e9SAndroid Build Coastguard Worker
2351*da0073e9SAndroid Build Coastguard Worker        graph = constant_prop.graph
2352*da0073e9SAndroid Build Coastguard Worker        self.run_pass('constant_propagation', graph)
2353*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(graph.findNode("prim::Loop").outputsSize() == 2)
2354*da0073e9SAndroid Build Coastguard Worker
2355*da0073e9SAndroid Build Coastguard Worker    # TODO(gmagogsfm): Refactor this test to reduce complexity.
2356*da0073e9SAndroid Build Coastguard Worker    def test_constant_insertion(self):
2357*da0073e9SAndroid Build Coastguard Worker        funcs_template = dedent('''
2358*da0073e9SAndroid Build Coastguard Worker        def func():
2359*da0073e9SAndroid Build Coastguard Worker            return {constant_constructor}
2360*da0073e9SAndroid Build Coastguard Worker        ''')
2361*da0073e9SAndroid Build Coastguard Worker
2362*da0073e9SAndroid Build Coastguard Worker        # constants: primitives: int, double, bool, str, lists of primitives,
2363*da0073e9SAndroid Build Coastguard Worker        # and tuples
2364*da0073e9SAndroid Build Coastguard Worker        def check_constant(constant_constructor):
2365*da0073e9SAndroid Build Coastguard Worker            scope = {}
2366*da0073e9SAndroid Build Coastguard Worker            funcs_str = funcs_template.format(constant_constructor=constant_constructor)
2367*da0073e9SAndroid Build Coastguard Worker            execWrapper(funcs_str, globals(), scope)
2368*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit(funcs_str)
2369*da0073e9SAndroid Build Coastguard Worker            f_script = cu.func
2370*da0073e9SAndroid Build Coastguard Worker            self.run_pass('constant_propagation', f_script.graph)
2371*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_count("prim::Constant", 1, exactly=True).run(f_script.graph)
2372*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(scope['func'](), f_script())
2373*da0073e9SAndroid Build Coastguard Worker            imported = self.getExportImportCopy(f_script)
2374*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(imported(), f_script())
2375*da0073e9SAndroid Build Coastguard Worker
2376*da0073e9SAndroid Build Coastguard Worker        constants = ["None", "-.5", "0", "1", "True", "False", "''", "'a'", "'b'", "torch.tensor(1)",
2377*da0073e9SAndroid Build Coastguard Worker                     "[True, False]", "[0., .5]", "[torch.tensor(4), torch.tensor(2)]", "[0, 1]", "['0', '1']",
2378*da0073e9SAndroid Build Coastguard Worker                     "[True, None]", "[.5, None, .2]"]
2379*da0073e9SAndroid Build Coastguard Worker
2380*da0073e9SAndroid Build Coastguard Worker        for type in ["Tensor", "str", "int", "float", "bool"]:
2381*da0073e9SAndroid Build Coastguard Worker            constants.append("torch.jit.annotate(List[ " + type + "], [])")
2382*da0073e9SAndroid Build Coastguard Worker
2383*da0073e9SAndroid Build Coastguard Worker        for constant in constants:
2384*da0073e9SAndroid Build Coastguard Worker            check_constant(constant)
2385*da0073e9SAndroid Build Coastguard Worker
2386*da0073e9SAndroid Build Coastguard Worker        for key_type in ["str", "int", "float"]:
2387*da0073e9SAndroid Build Coastguard Worker            for value_type in ["Tensor", "bool", "str", "int", "float"]:
2388*da0073e9SAndroid Build Coastguard Worker                check_constant("torch.jit.annotate(Dict[ " + key_type + ", " + value_type + "], {})")
2389*da0073e9SAndroid Build Coastguard Worker                check_constant("torch.jit.annotate(Dict[ " + key_type + ", Optional[" + value_type + "]], {})")
2390*da0073e9SAndroid Build Coastguard Worker
2391*da0073e9SAndroid Build Coastguard Worker        for i in range(len(constants)):
2392*da0073e9SAndroid Build Coastguard Worker            for j in range(i + 1, len(constants)):
2393*da0073e9SAndroid Build Coastguard Worker                tup_constant = constants[i] + ", " + constants[j]
2394*da0073e9SAndroid Build Coastguard Worker                check_constant(tup_constant)
2395*da0073e9SAndroid Build Coastguard Worker
2396*da0073e9SAndroid Build Coastguard Worker        dict_constants = []
2397*da0073e9SAndroid Build Coastguard Worker        for i in range(len(constants)):
2398*da0073e9SAndroid Build Coastguard Worker            # check_constant constructs the second dict with another Tensor
2399*da0073e9SAndroid Build Coastguard Worker            # which fails the comparison
2400*da0073e9SAndroid Build Coastguard Worker            if not isinstance(eval(constants[i]), (str, int, float)):
2401*da0073e9SAndroid Build Coastguard Worker                continue
2402*da0073e9SAndroid Build Coastguard Worker            for j in range(len(constants)):
2403*da0073e9SAndroid Build Coastguard Worker                dict_constant = "{ " + constants[i] + ": " + constants[j] + "}"
2404*da0073e9SAndroid Build Coastguard Worker                check_constant(dict_constant)
2405*da0073e9SAndroid Build Coastguard Worker                dict_constants.append(dict_constant)
2406*da0073e9SAndroid Build Coastguard Worker        constants = constants + dict_constants
2407*da0073e9SAndroid Build Coastguard Worker
2408*da0073e9SAndroid Build Coastguard Worker        # testing node hashing
2409*da0073e9SAndroid Build Coastguard Worker        funcs_template = dedent('''
2410*da0073e9SAndroid Build Coastguard Worker        def func():
2411*da0073e9SAndroid Build Coastguard Worker            print({constant_constructor})
2412*da0073e9SAndroid Build Coastguard Worker        ''')
2413*da0073e9SAndroid Build Coastguard Worker        single_elem_tuples = ("(" + x + ",)" for x in constants)
2414*da0073e9SAndroid Build Coastguard Worker        input_arg = ", ".join(single_elem_tuples)
2415*da0073e9SAndroid Build Coastguard Worker        scope = {}
2416*da0073e9SAndroid Build Coastguard Worker        funcs_str = funcs_template.format(constant_constructor=input_arg)
2417*da0073e9SAndroid Build Coastguard Worker        execWrapper(funcs_str, globals(), scope)
2418*da0073e9SAndroid Build Coastguard Worker        cu = torch.jit.CompilationUnit(funcs_str)
2419*da0073e9SAndroid Build Coastguard Worker        f_script = cu.func
2420*da0073e9SAndroid Build Coastguard Worker        self.run_pass('constant_propagation', f_script.graph)
2421*da0073e9SAndroid Build Coastguard Worker        # prim::None return adds one constant
2422*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(constants) + 1, str(f_script.graph).count("prim::Constant"))
2423*da0073e9SAndroid Build Coastguard Worker        self.run_pass('cse', f_script.graph)
2424*da0073e9SAndroid Build Coastguard Worker        # node hashing correctly working, no CSE occurs
2425*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(constants) + 1, str(f_script.graph).count("prim::Constant"))
2426*da0073e9SAndroid Build Coastguard Worker
2427*da0073e9SAndroid Build Coastguard Worker        funcs_template = dedent('''
2428*da0073e9SAndroid Build Coastguard Worker        def func():
2429*da0073e9SAndroid Build Coastguard Worker            a = {constant_constructor}
2430*da0073e9SAndroid Build Coastguard Worker            print(a)
2431*da0073e9SAndroid Build Coastguard Worker            b = {constant_constructor}
2432*da0073e9SAndroid Build Coastguard Worker            print(b)
2433*da0073e9SAndroid Build Coastguard Worker        ''')
2434*da0073e9SAndroid Build Coastguard Worker
2435*da0073e9SAndroid Build Coastguard Worker        # generate dicts with built-in types (excluding torch.Tensor)
2436*da0073e9SAndroid Build Coastguard Worker        xprod = itertools.product(constants, constants)
2437*da0073e9SAndroid Build Coastguard Worker
2438*da0073e9SAndroid Build Coastguard Worker        # test that equal tuples and dicts correctly work with node hashing
2439*da0073e9SAndroid Build Coastguard Worker        for tup in ("(" + x + ",)" for x in constants):
2440*da0073e9SAndroid Build Coastguard Worker            funcs_str = funcs_template.format(constant_constructor=tup)
2441*da0073e9SAndroid Build Coastguard Worker            scope = {}
2442*da0073e9SAndroid Build Coastguard Worker            execWrapper(funcs_str, globals(), scope)
2443*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit(funcs_str)
2444*da0073e9SAndroid Build Coastguard Worker            f_script = cu.func
2445*da0073e9SAndroid Build Coastguard Worker            self.run_pass('constant_propagation_immutable_types', f_script.graph)
2446*da0073e9SAndroid Build Coastguard Worker            num_constants = str(f_script.graph).count("prim::Constant")
2447*da0073e9SAndroid Build Coastguard Worker            self.run_pass('cse', f_script.graph)
2448*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_count("prim::Constant", num_constants, exactly=True).run(f_script.graph)
2449*da0073e9SAndroid Build Coastguard Worker
2450*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
2451*da0073e9SAndroid Build Coastguard Worker    def test_cuda_export_restore(self):
2452*da0073e9SAndroid Build Coastguard Worker        class Sub(torch.jit.ScriptModule):
2453*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2454*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2455*da0073e9SAndroid Build Coastguard Worker                self.weight = nn.Parameter(torch.randn(3, 4))
2456*da0073e9SAndroid Build Coastguard Worker
2457*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
2458*da0073e9SAndroid Build Coastguard Worker            def forward(self, thing):
2459*da0073e9SAndroid Build Coastguard Worker                return self.weight + thing
2460*da0073e9SAndroid Build Coastguard Worker
2461*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
2462*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2463*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2464*da0073e9SAndroid Build Coastguard Worker                self.mod = Sub()
2465*da0073e9SAndroid Build Coastguard Worker
2466*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
2467*da0073e9SAndroid Build Coastguard Worker            def forward(self, v):
2468*da0073e9SAndroid Build Coastguard Worker                return self.mod(v)
2469*da0073e9SAndroid Build Coastguard Worker        m = M()
2470*da0073e9SAndroid Build Coastguard Worker        m.cuda()
2471*da0073e9SAndroid Build Coastguard Worker        m2 = self.getExportImportCopy(m)
2472*da0073e9SAndroid Build Coastguard Worker        m2.cuda()
2473*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(3, 4).cuda()
2474*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m(input), m2(input))
2475*da0073e9SAndroid Build Coastguard Worker
2476*da0073e9SAndroid Build Coastguard Worker    @slowTest
2477*da0073e9SAndroid Build Coastguard Worker    def test_export_batchnorm(self):
2478*da0073e9SAndroid Build Coastguard Worker        for mode in ['eval', 'train']:
2479*da0073e9SAndroid Build Coastguard Worker            for clazz in [
2480*da0073e9SAndroid Build Coastguard Worker                    torch.nn.BatchNorm1d(100),
2481*da0073e9SAndroid Build Coastguard Worker                    torch.nn.BatchNorm1d(100, affine=False),
2482*da0073e9SAndroid Build Coastguard Worker                    torch.nn.BatchNorm2d(100),
2483*da0073e9SAndroid Build Coastguard Worker                    torch.nn.BatchNorm2d(100, affine=False)]:
2484*da0073e9SAndroid Build Coastguard Worker                getattr(clazz, mode)()
2485*da0073e9SAndroid Build Coastguard Worker                input = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \
2486*da0073e9SAndroid Build Coastguard Worker                    torch.randn(20, 100, 35, 45)
2487*da0073e9SAndroid Build Coastguard Worker                traced = torch.jit.trace(clazz, (input,))
2488*da0073e9SAndroid Build Coastguard Worker                imported = self.getExportImportCopy(traced)
2489*da0073e9SAndroid Build Coastguard Worker                x = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \
2490*da0073e9SAndroid Build Coastguard Worker                    torch.randn(20, 100, 35, 45)
2491*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(traced(x), imported(x))
2492*da0073e9SAndroid Build Coastguard Worker
2493*da0073e9SAndroid Build Coastguard Worker    def test_export_rnn(self):
2494*da0073e9SAndroid Build Coastguard Worker        for clazz in [nn.RNN(10, 20, 2), nn.GRU(10, 20, 2)]:
2495*da0073e9SAndroid Build Coastguard Worker            class RNNTest(torch.nn.Module):
2496*da0073e9SAndroid Build Coastguard Worker                def __init__(self) -> None:
2497*da0073e9SAndroid Build Coastguard Worker                    super().__init__()
2498*da0073e9SAndroid Build Coastguard Worker                    self.rnn = clazz
2499*da0073e9SAndroid Build Coastguard Worker
2500*da0073e9SAndroid Build Coastguard Worker                def forward(self, x, lengths, h0):
2501*da0073e9SAndroid Build Coastguard Worker                    packed = torch.nn.utils.rnn.pack_padded_sequence(x, lengths)
2502*da0073e9SAndroid Build Coastguard Worker                    out, h = self.rnn(packed, h0)
2503*da0073e9SAndroid Build Coastguard Worker                    padded_outs, _ = torch.nn.utils.rnn.pad_packed_sequence(out)
2504*da0073e9SAndroid Build Coastguard Worker                    return padded_outs
2505*da0073e9SAndroid Build Coastguard Worker
2506*da0073e9SAndroid Build Coastguard Worker            test = RNNTest()
2507*da0073e9SAndroid Build Coastguard Worker
2508*da0073e9SAndroid Build Coastguard Worker            traced = torch.jit.trace(test, (torch.randn(5, 3, 10), torch.LongTensor([3, 2, 1]), torch.randn(2, 3, 20)))
2509*da0073e9SAndroid Build Coastguard Worker            imported = self.getExportImportCopy(traced)
2510*da0073e9SAndroid Build Coastguard Worker            # NB: We make sure to pass in a batch with a different max sequence
2511*da0073e9SAndroid Build Coastguard Worker            # length to ensure that the argument stashing for pad_packed works
2512*da0073e9SAndroid Build Coastguard Worker            # properly.
2513*da0073e9SAndroid Build Coastguard Worker            x, lengths, h0 = torch.randn(7, 4, 10), torch.LongTensor([7, 3, 2, 1]), torch.randn(2, 4, 20)
2514*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(traced(x, lengths, h0), imported(x, lengths, h0))
2515*da0073e9SAndroid Build Coastguard Worker
2516*da0073e9SAndroid Build Coastguard Worker    def test_export_lstm(self):
2517*da0073e9SAndroid Build Coastguard Worker        class LSTMTest(torch.nn.Module):
2518*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2519*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2520*da0073e9SAndroid Build Coastguard Worker                self.rnn = nn.LSTM(10, 20, 2)
2521*da0073e9SAndroid Build Coastguard Worker
2522*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, lengths, hiddens):
2523*da0073e9SAndroid Build Coastguard Worker                h0, c0 = hiddens
2524*da0073e9SAndroid Build Coastguard Worker                packed = torch.nn.utils.rnn.pack_padded_sequence(x, lengths)
2525*da0073e9SAndroid Build Coastguard Worker                out, (h, c) = self.rnn(packed, (h0, c0))
2526*da0073e9SAndroid Build Coastguard Worker                padded_outs, _ = torch.nn.utils.rnn.pad_packed_sequence(out)
2527*da0073e9SAndroid Build Coastguard Worker                return padded_outs
2528*da0073e9SAndroid Build Coastguard Worker
2529*da0073e9SAndroid Build Coastguard Worker        test = LSTMTest()
2530*da0073e9SAndroid Build Coastguard Worker
2531*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(test, (torch.randn(5, 3, 10),
2532*da0073e9SAndroid Build Coastguard Worker                                        torch.LongTensor([3, 2, 1]),
2533*da0073e9SAndroid Build Coastguard Worker                                        (torch.randn(2, 3, 20), torch.randn(2, 3, 20))))
2534*da0073e9SAndroid Build Coastguard Worker        imported = self.getExportImportCopy(traced)
2535*da0073e9SAndroid Build Coastguard Worker        x, lengths, h0, c0 = \
2536*da0073e9SAndroid Build Coastguard Worker            torch.randn(7, 3, 10), torch.LongTensor([7, 5, 2]), torch.randn(2, 3, 20), torch.randn(2, 3, 20)
2537*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced(x, lengths, (h0, c0)), imported(x, lengths, (h0, c0)))
2538*da0073e9SAndroid Build Coastguard Worker
2539*da0073e9SAndroid Build Coastguard Worker    def test_unique_state_dict(self):
2540*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
2541*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2542*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2543*da0073e9SAndroid Build Coastguard Worker                shared_param = torch.nn.Parameter(torch.ones(1))
2544*da0073e9SAndroid Build Coastguard Worker                self.register_parameter('w1', shared_param)
2545*da0073e9SAndroid Build Coastguard Worker                self.register_parameter('w2', shared_param)
2546*da0073e9SAndroid Build Coastguard Worker
2547*da0073e9SAndroid Build Coastguard Worker            def forward(self, input):
2548*da0073e9SAndroid Build Coastguard Worker                return input + self.w1 + self.w2
2549*da0073e9SAndroid Build Coastguard Worker
2550*da0073e9SAndroid Build Coastguard Worker        model = MyModule()
2551*da0073e9SAndroid Build Coastguard Worker        unittest.TestCase.assertEqual(
2552*da0073e9SAndroid Build Coastguard Worker            self, len(torch.jit._unique_state_dict(model, keep_vars=False)), 1)
2553*da0073e9SAndroid Build Coastguard Worker        unittest.TestCase.assertEqual(
2554*da0073e9SAndroid Build Coastguard Worker            self, len(torch.jit._unique_state_dict(model, keep_vars=True)), 1)
2555*da0073e9SAndroid Build Coastguard Worker
2556*da0073e9SAndroid Build Coastguard Worker    def test_export_dropout(self):
2557*da0073e9SAndroid Build Coastguard Worker        test = torch.nn.Dropout()
2558*da0073e9SAndroid Build Coastguard Worker        test.eval()
2559*da0073e9SAndroid Build Coastguard Worker
2560*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(test, (torch.rand(3, 4),), check_trace=False)
2561*da0073e9SAndroid Build Coastguard Worker        imported = self.getExportImportCopy(traced)
2562*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 4)
2563*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced(x), imported(x))
2564*da0073e9SAndroid Build Coastguard Worker
    def test_pretty_printer(self):
        """Check the pretty-printed ``.code`` of several small scripted functions.

        Each function's ``.code`` is compared with ``assertExpected`` against
        a recorded expectation, keyed by the subname passed alongside it.
        Edits to the scripted bodies below are therefore likely to require
        re-recording those expectations.
        """
        # if/else where both branches assign ``c``.
        @torch.jit.script
        def if_test(a, b):
            # FIXME: use 0 instead of a.
            # c = 0
            c = a
            if bool(a < b):
                c = b
            else:
                c = a
            return c

        # if without an else branch.
        @torch.jit.script
        def if_one(a, b):
            c = b
            if bool(a < b):
                c = a
            return c

        # while loop with augmented assignments to values updated each iteration.
        @torch.jit.script
        def while_test(a, i):
            while bool(i < 3):
                a *= a
                i += 1
            return a

        # if/else nested inside a while loop.
        @torch.jit.script
        def while_if_test(a, b):
            c = 0
            while bool(a < 10):
                a = a + 1
                b = b + 1
                if bool(a > b):
                    c = 2
                else:
                    c = 3
            return a + 1 + c

        # A value defined before the loop, used inside it and returned after it.
        @torch.jit.script
        def loop_use_test(y):
            x = y + 1
            z = x + 5
            while bool(y < 8):
                y += 1
                z = x
            return x, z

        # Marked @torch.jit.ignore; python_op_name_test below calls it from script.
        @torch.jit.ignore
        def python_fn(x):
            return x + 10

        @torch.jit.script
        def python_op_name_test(y):
            return python_fn(y)

        # Annotated empty List[int]. Indexing it would fail if the function
        # were run; here it is only compiled and its code printed.
        @torch.jit.script
        def empty_int_list_test(y):
            x = torch.jit.annotate(List[int], [])
            return x[0]

        # NOTE: despite its name this returns a non-empty float list; the name
        # matches the assertExpected subname used below.
        @torch.jit.script
        def empty_float_list_test(y):
            return [1.0, 2.0, 3.0]

        # "\016" is an octal escape for the non-printable control character
        # U+000E; presumably this checks how such characters are escaped in
        # the printed string literal.
        @torch.jit.script
        def print_weird_test(y):
            print("hi\016")

        self.assertExpected(if_test.code, "if_test")
        self.assertExpected(if_one.code, "if_one")
        self.assertExpected(while_test.code, "while_test")
        self.assertExpected(while_if_test.code, "while_if_test")
        self.assertExpected(loop_use_test.code, "loop_use_test")
        self.assertExpected(python_op_name_test.code, "python_op_name_test")
        self.assertExpected(empty_int_list_test.code, "empty_int_list_test")
        self.assertExpected(empty_float_list_test.code, "empty_float_list_test")
        self.assertExpected(print_weird_test.code, "print_weird_test")
2642*da0073e9SAndroid Build Coastguard Worker
2643*da0073e9SAndroid Build Coastguard Worker    def test_cu_escaped_number(self):
2644*da0073e9SAndroid Build Coastguard Worker        cu = torch.jit.CompilationUnit('''
2645*da0073e9SAndroid Build Coastguard Worker            def foo(a):
2646*da0073e9SAndroid Build Coastguard Worker                print("hi\016")
2647*da0073e9SAndroid Build Coastguard Worker        ''')
2648*da0073e9SAndroid Build Coastguard Worker        self.assertExpected(cu.foo.code)
2649*da0073e9SAndroid Build Coastguard Worker
    def test_import_method(self):
        """Save a ScriptModule to an in-memory buffer, load it back, and check
        the loaded ``forward`` code against a recorded expectation."""
        # The whole save/load round trip runs with JIT emit hooks disabled.
        # NOTE(review): presumably so the test harness's own hooks do not
        # interfere with this explicit round trip -- confirm.
        with torch._jit_internal._disable_emit_hooks():
            class Foo(torch.jit.ScriptModule):
                @torch.jit.script_method
                def forward(self, x, y):
                    return 2 * x + y

            foo = Foo()
            buffer = io.BytesIO()
            torch.jit.save(foo, buffer)

            # Rewind so load reads the serialized module from the beginning.
            buffer.seek(0)
            foo_loaded = torch.jit.load(buffer)
            self.assertExpected(foo_loaded.forward.code)
2664*da0073e9SAndroid Build Coastguard Worker
2665*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("temporarily disable the test for fwd compatibility")
    def test_non_ascii_string(self):
        """Round-trip a ScriptModule whose attribute and ``forward`` body contain
        non-ASCII strings, then check the loaded ``forward`` code against a
        recorded expectation.

        Currently skipped by its decorator ("temporarily disable the test for
        fwd compatibility").
        """
        class Foo(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                # "\u0e55\u0e57" are the Thai digits five and seven (non-ASCII).
                self.a = "Over \u0e55\u0e57 57"

            @torch.jit.script_method
            def forward(self, x, y):
                # "\xA1" is a non-ASCII Latin-1 character.
                return self.a + "hi\xA1"

        foo = Foo()
        buffer = io.BytesIO()
        torch.jit.save(foo, buffer)

        # Rewind so load reads the serialized module from the beginning.
        buffer.seek(0)
        foo_loaded = torch.jit.load(buffer)
        self.assertExpected(foo_loaded.forward.code)
2683*da0073e9SAndroid Build Coastguard Worker
2684*da0073e9SAndroid Build Coastguard Worker    def test_function_default_values(self):
2685*da0073e9SAndroid Build Coastguard Worker        outer_var = torch.tensor(20)
2686*da0073e9SAndroid Build Coastguard Worker        outer_var2 = torch.tensor(30)
2687*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(0.5)
2688*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor(10)
2689*da0073e9SAndroid Build Coastguard Worker
2690*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2691*da0073e9SAndroid Build Coastguard Worker        def simple_fn(x, a=a, b=b, c=outer_var + outer_var2):
2692*da0073e9SAndroid Build Coastguard Worker            return x + a + b + c
2693*da0073e9SAndroid Build Coastguard Worker
2694*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2695*da0073e9SAndroid Build Coastguard Worker            simple_fn(torch.ones(1)),
2696*da0073e9SAndroid Build Coastguard Worker            torch.ones(1) + 0.5 + 10 + (20 + 30))
2697*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2698*da0073e9SAndroid Build Coastguard Worker            simple_fn(torch.ones(1), torch.tensor(1), torch.tensor(3), torch.tensor(4)),
2699*da0073e9SAndroid Build Coastguard Worker            torch.ones(1) + 1 + 3 + 4)
2700*da0073e9SAndroid Build Coastguard Worker
2701*da0073e9SAndroid Build Coastguard Worker        outer_c = torch.tensor(9)
2702*da0073e9SAndroid Build Coastguard Worker        outer_flag = torch.tensor(False)
2703*da0073e9SAndroid Build Coastguard Worker
2704*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2705*da0073e9SAndroid Build Coastguard Worker        def bool_fn(x, a=outer_c, flag=outer_flag):
2706*da0073e9SAndroid Build Coastguard Worker            if bool(flag):
2707*da0073e9SAndroid Build Coastguard Worker                result = x
2708*da0073e9SAndroid Build Coastguard Worker            else:
2709*da0073e9SAndroid Build Coastguard Worker                result = x + a
2710*da0073e9SAndroid Build Coastguard Worker            return result
2711*da0073e9SAndroid Build Coastguard Worker
2712*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bool_fn(torch.ones(1)), torch.ones(1) + 9)
2713*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2714*da0073e9SAndroid Build Coastguard Worker            bool_fn(torch.ones(1), torch.tensor(1), torch.tensor(True)),
2715*da0073e9SAndroid Build Coastguard Worker            torch.ones(1))
2716*da0073e9SAndroid Build Coastguard Worker
2717*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2718*da0073e9SAndroid Build Coastguard Worker        def none_fn(x=None):
2719*da0073e9SAndroid Build Coastguard Worker            # type: (Optional[int]) -> Optional[int]
2720*da0073e9SAndroid Build Coastguard Worker            return x
2721*da0073e9SAndroid Build Coastguard Worker
2722*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(none_fn(), None)
2723*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(none_fn(1), 1)
2724*da0073e9SAndroid Build Coastguard Worker
2725*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2726*da0073e9SAndroid Build Coastguard Worker        def hints(x, a=0.5, b=10):
2727*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor, float, int) -> Tensor
2728*da0073e9SAndroid Build Coastguard Worker            return x + a + b
2729*da0073e9SAndroid Build Coastguard Worker
2730*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(hints(torch.ones(1)), torch.ones(1) + 0.5 + 10)
2731*da0073e9SAndroid Build Coastguard Worker
2732*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Expected a default value"):
2733*da0073e9SAndroid Build Coastguard Worker
2734*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
2735*da0073e9SAndroid Build Coastguard Worker            def hints_bad_types(x, a=10, b=0.5):  # noqa: T484
2736*da0073e9SAndroid Build Coastguard Worker                # type: (Tensor, float, int) -> Tensor
2737*da0073e9SAndroid Build Coastguard Worker                return x + a + b
2738*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Expected a default value"):
2739*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
2740*da0073e9SAndroid Build Coastguard Worker            def bad_no_optional(x=None):
2741*da0073e9SAndroid Build Coastguard Worker                # type: (Dict[str, int]) -> Dict[str, int]
2742*da0073e9SAndroid Build Coastguard Worker                return x
2743*da0073e9SAndroid Build Coastguard Worker
2744*da0073e9SAndroid Build Coastguard Worker
2745*da0073e9SAndroid Build Coastguard Worker    def test_module_default_values(self):
2746*da0073e9SAndroid Build Coastguard Worker        four = torch.tensor(4)
2747*da0073e9SAndroid Build Coastguard Worker
2748*da0073e9SAndroid Build Coastguard Worker        class Test(torch.jit.ScriptModule):
2749*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
2750*da0073e9SAndroid Build Coastguard Worker            def forward(self, input, other=four):
2751*da0073e9SAndroid Build Coastguard Worker                return input + other
2752*da0073e9SAndroid Build Coastguard Worker
2753*da0073e9SAndroid Build Coastguard Worker        t = Test()
2754*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t(torch.ones(1)), torch.ones(1) + 4)
2755*da0073e9SAndroid Build Coastguard Worker
2756*da0073e9SAndroid Build Coastguard Worker    def test_mutable_default_values(self):
2757*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(Exception, "Mutable default parameters"):
2758*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
2759*da0073e9SAndroid Build Coastguard Worker            def foo(x=(1, [])):
2760*da0073e9SAndroid Build Coastguard Worker                # type: (Tuple[int, List[Tensor]])
2761*da0073e9SAndroid Build Coastguard Worker                return x
2762*da0073e9SAndroid Build Coastguard Worker
2763*da0073e9SAndroid Build Coastguard Worker        class Test(torch.nn.Module):
2764*da0073e9SAndroid Build Coastguard Worker            def forward(self, input=[]):  # noqa: B006
2765*da0073e9SAndroid Build Coastguard Worker                return input
2766*da0073e9SAndroid Build Coastguard Worker
2767*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(Exception, "Mutable default parameters"):
2768*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(Test())
2769*da0073e9SAndroid Build Coastguard Worker
2770*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
2771*da0073e9SAndroid Build Coastguard Worker    def test_warnings(self):
2772*da0073e9SAndroid Build Coastguard Worker        import warnings
2773*da0073e9SAndroid Build Coastguard Worker
2774*da0073e9SAndroid Build Coastguard Worker        def fn(x):
2775*da0073e9SAndroid Build Coastguard Worker            if bool(x < 2):
2776*da0073e9SAndroid Build Coastguard Worker                warnings.warn("x is less than 2")
2777*da0073e9SAndroid Build Coastguard Worker            return x
2778*da0073e9SAndroid Build Coastguard Worker
2779*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
2780*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2781*da0073e9SAndroid Build Coastguard Worker                if bool(x < 2):
2782*da0073e9SAndroid Build Coastguard Worker                    warnings.warn("x is less than 2")
2783*da0073e9SAndroid Build Coastguard Worker                return x
2784*da0073e9SAndroid Build Coastguard Worker
2785*da0073e9SAndroid Build Coastguard Worker
2786*da0073e9SAndroid Build Coastguard Worker        scripted_mod = torch.jit.script(M())
2787*da0073e9SAndroid Build Coastguard Worker        scripted_fn = torch.jit.script(fn)
2788*da0073e9SAndroid Build Coastguard Worker
2789*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as warns:
2790*da0073e9SAndroid Build Coastguard Worker            fn(torch.ones(1))
2791*da0073e9SAndroid Build Coastguard Worker
2792*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as script_warns:
2793*da0073e9SAndroid Build Coastguard Worker            scripted_fn(torch.ones(1))
2794*da0073e9SAndroid Build Coastguard Worker
2795*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as script_mod_warns:
2796*da0073e9SAndroid Build Coastguard Worker            scripted_mod(torch.ones(1))
2797*da0073e9SAndroid Build Coastguard Worker
2798*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(str(warns[0]), str(script_warns[0]))
2799*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(script_mod_warns), 1)
2800*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(str(warns[0].message), str(script_mod_warns[0].message))
2801*da0073e9SAndroid Build Coastguard Worker
2802*da0073e9SAndroid Build Coastguard Worker    def test_no_erroneous_warnings(self):
2803*da0073e9SAndroid Build Coastguard Worker        import warnings
2804*da0073e9SAndroid Build Coastguard Worker
2805*da0073e9SAndroid Build Coastguard Worker        def fn(x):
2806*da0073e9SAndroid Build Coastguard Worker            if bool(x > 0):
2807*da0073e9SAndroid Build Coastguard Worker                warnings.warn('This should NOT be printed')
2808*da0073e9SAndroid Build Coastguard Worker                x += 1
2809*da0073e9SAndroid Build Coastguard Worker            return x
2810*da0073e9SAndroid Build Coastguard Worker
2811*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as warns:
2812*da0073e9SAndroid Build Coastguard Worker            fn_script = torch.jit.script(fn)
2813*da0073e9SAndroid Build Coastguard Worker            fn_script(torch.tensor(0))
2814*da0073e9SAndroid Build Coastguard Worker        warns = [str(w.message) for w in warns]
2815*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(warns), 0)
2816*da0073e9SAndroid Build Coastguard Worker
2817*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(True, "TODO: re-enable with https://github.com/pytorch/pytorch/pull/29339")
2818*da0073e9SAndroid Build Coastguard Worker    def test_torch_load_error(self):
2819*da0073e9SAndroid Build Coastguard Worker        class J(torch.jit.ScriptModule):
2820*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
2821*da0073e9SAndroid Build Coastguard Worker            def forward(self, input):
2822*da0073e9SAndroid Build Coastguard Worker                return input + 100
2823*da0073e9SAndroid Build Coastguard Worker
2824*da0073e9SAndroid Build Coastguard Worker        j = J()
2825*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName() as fname:
2826*da0073e9SAndroid Build Coastguard Worker            j.save(fname)
2827*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "is a zip"):
2828*da0073e9SAndroid Build Coastguard Worker                torch.load(fname)
2829*da0073e9SAndroid Build Coastguard Worker
2830*da0073e9SAndroid Build Coastguard Worker    def test_torch_load_zipfile_check(self):
2831*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2832*da0073e9SAndroid Build Coastguard Worker        def fn(x):
2833*da0073e9SAndroid Build Coastguard Worker            return x + 10
2834*da0073e9SAndroid Build Coastguard Worker
2835*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName() as fname:
2836*da0073e9SAndroid Build Coastguard Worker            fn.save(fname)
2837*da0073e9SAndroid Build Coastguard Worker            with open(fname, 'rb') as f:
2838*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.serialization._is_zipfile(f))
2839*da0073e9SAndroid Build Coastguard Worker
2840*da0073e9SAndroid Build Coastguard Worker    def test_python_bindings(self):
2841*da0073e9SAndroid Build Coastguard Worker        lstm_cell = torch.jit.script(LSTMCellS)
2842*da0073e9SAndroid Build Coastguard Worker
2843*da0073e9SAndroid Build Coastguard Worker        def lstm(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
2844*da0073e9SAndroid Build Coastguard Worker            for i in range(x.size(0)):
2845*da0073e9SAndroid Build Coastguard Worker                hx, cx = lstm_cell(x[i], hx, cx, w_ih, w_hh, b_ih, b_hh)
2846*da0073e9SAndroid Build Coastguard Worker            return hx
2847*da0073e9SAndroid Build Coastguard Worker
2848*da0073e9SAndroid Build Coastguard Worker        slstm = torch.jit.script(lstm)
2849*da0073e9SAndroid Build Coastguard Worker
2850*da0073e9SAndroid Build Coastguard Worker        inputs = get_lstm_inputs('cpu', training=True, seq_length=10)
2851*da0073e9SAndroid Build Coastguard Worker        slstm(*inputs).sum().backward()
2852*da0073e9SAndroid Build Coastguard Worker        global fw_graph
2853*da0073e9SAndroid Build Coastguard Worker        fw_graph = slstm.graph_for(*inputs)
2854*da0073e9SAndroid Build Coastguard Worker        nodes = list(fw_graph.nodes())
2855*da0073e9SAndroid Build Coastguard Worker        tested_blocks = False
2856*da0073e9SAndroid Build Coastguard Worker        for node in nodes:
2857*da0073e9SAndroid Build Coastguard Worker            for output in node.outputs():
2858*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(hasattr(output, 'type'))
2859*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(output.type() is not None)
2860*da0073e9SAndroid Build Coastguard Worker            for input in node.inputs():
2861*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(hasattr(input, 'type'))
2862*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(input.type() is not None)
2863*da0073e9SAndroid Build Coastguard Worker            for block in node.blocks():
2864*da0073e9SAndroid Build Coastguard Worker                tested_blocks = True
2865*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(hasattr(block, 'inputs'))
2866*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(hasattr(block, 'outputs'))
2867*da0073e9SAndroid Build Coastguard Worker                for output in block.outputs():
2868*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(hasattr(output, 'type'))
2869*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(output.type() is not None)
2870*da0073e9SAndroid Build Coastguard Worker                for input in block.inputs():
2871*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(hasattr(input, 'type'))
2872*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(input.type() is not None)
2873*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(hasattr(block, 'returnNode'))
2874*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(type(block.returnNode()) == torch._C.Node)
2875*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(hasattr(block, 'paramNode'))
2876*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(type(block.paramNode()) == torch._C.Node)
2877*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(tested_blocks)
2878*da0073e9SAndroid Build Coastguard Worker
2879*da0073e9SAndroid Build Coastguard Worker    def test_export_opnames(self):
2880*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.jit.ScriptModule):
2881*da0073e9SAndroid Build Coastguard Worker            def one(self, x, y):
2882*da0073e9SAndroid Build Coastguard Worker                # type: (Tensor, Tensor) -> Tensor
2883*da0073e9SAndroid Build Coastguard Worker                return x + y
2884*da0073e9SAndroid Build Coastguard Worker
2885*da0073e9SAndroid Build Coastguard Worker            def two(self, x):
2886*da0073e9SAndroid Build Coastguard Worker                # type: (Tensor) -> Tensor
2887*da0073e9SAndroid Build Coastguard Worker                return 2 * x
2888*da0073e9SAndroid Build Coastguard Worker
2889*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
2890*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2891*da0073e9SAndroid Build Coastguard Worker                # type: (Tensor) -> Tensor
2892*da0073e9SAndroid Build Coastguard Worker                return self.one(self.two(x), x)
2893*da0073e9SAndroid Build Coastguard Worker
2894*da0073e9SAndroid Build Coastguard Worker        class Bar(torch.jit.ScriptModule):
2895*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2896*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2897*da0073e9SAndroid Build Coastguard Worker                self.sub = Foo()
2898*da0073e9SAndroid Build Coastguard Worker
2899*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
2900*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2901*da0073e9SAndroid Build Coastguard Worker                # type: (Tensor) -> Tensor
2902*da0073e9SAndroid Build Coastguard Worker                return self.sub.forward(x)
2903*da0073e9SAndroid Build Coastguard Worker
2904*da0073e9SAndroid Build Coastguard Worker        bar = Bar()
2905*da0073e9SAndroid Build Coastguard Worker        ops = torch.jit.export_opnames(bar)
2906*da0073e9SAndroid Build Coastguard Worker        expected = ['aten::add.Tensor', 'aten::mul.Scalar']
2907*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(set(expected).issubset(set(ops)))
2908*da0073e9SAndroid Build Coastguard Worker
2909*da0073e9SAndroid Build Coastguard Worker    def test_pytorch_jit_env_off(self):
2910*da0073e9SAndroid Build Coastguard Worker        import subprocess
2911*da0073e9SAndroid Build Coastguard Worker        env = os.environ.copy()
2912*da0073e9SAndroid Build Coastguard Worker        env['PYTORCH_JIT'] = '0'
2913*da0073e9SAndroid Build Coastguard Worker        try:
2914*da0073e9SAndroid Build Coastguard Worker            subprocess.check_output([sys.executable, '-c', 'import torch'], env=env)
2915*da0073e9SAndroid Build Coastguard Worker        except subprocess.CalledProcessError as e:
2916*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError("Could not 'import torch' with PYTORCH_JIT=0") from e
2917*da0073e9SAndroid Build Coastguard Worker
2918*da0073e9SAndroid Build Coastguard Worker    def test_print_op_module(self):
2919*da0073e9SAndroid Build Coastguard Worker        # Issue #19351: python2 and python3 go through different paths.
2920*da0073e9SAndroid Build Coastguard Worker        # python2 returns '<module 'torch.ops' (built-in)>'
2921*da0073e9SAndroid Build Coastguard Worker        # python3 uses __file__ and return
2922*da0073e9SAndroid Build Coastguard Worker        # '<module 'torch.ops' from '/scratch/ailzhang/pytorch/torch/_ops.py'>'
2923*da0073e9SAndroid Build Coastguard Worker        s = str(torch.ops)
2924*da0073e9SAndroid Build Coastguard Worker        self.assertRegex(s, r'ops')
2925*da0073e9SAndroid Build Coastguard Worker
2926*da0073e9SAndroid Build Coastguard Worker    def test_print_classes_module(self):
2927*da0073e9SAndroid Build Coastguard Worker        s = str(torch.classes)
2928*da0073e9SAndroid Build Coastguard Worker        self.assertRegex(s, r'classes')
2929*da0073e9SAndroid Build Coastguard Worker
2930*da0073e9SAndroid Build Coastguard Worker    def test_print_torch_ops_modules(self):
2931*da0073e9SAndroid Build Coastguard Worker        s = str(torch._ops.ops.quantized)
2932*da0073e9SAndroid Build Coastguard Worker        self.assertRegex(s, r'torch.ops')
2933*da0073e9SAndroid Build Coastguard Worker        s = str(torch._ops.ops.atan)
2934*da0073e9SAndroid Build Coastguard Worker        self.assertRegex(s, r'torch.ops')
2935*da0073e9SAndroid Build Coastguard Worker
2936*da0073e9SAndroid Build Coastguard Worker    def test_hide_source_ranges_context_manager(self):
2937*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2938*da0073e9SAndroid Build Coastguard Worker        def foo(x):
2939*da0073e9SAndroid Build Coastguard Worker            return torch.add(x, x)
2940*da0073e9SAndroid Build Coastguard Worker
2941*da0073e9SAndroid Build Coastguard Worker        graph = foo.graph
2942*da0073e9SAndroid Build Coastguard Worker        source_range_regex = "# .*\\.py"
2943*da0073e9SAndroid Build Coastguard Worker        self.assertRegex(graph.__repr__(), source_range_regex)
2944*da0073e9SAndroid Build Coastguard Worker        with torch.jit._hide_source_ranges():
2945*da0073e9SAndroid Build Coastguard Worker            self.assertNotRegex(graph.__repr__(), source_range_regex)
2946*da0073e9SAndroid Build Coastguard Worker            self.assertRegex(graph.str(print_source_ranges=True), source_range_regex)
2947*da0073e9SAndroid Build Coastguard Worker        self.assertRegex(graph.__repr__(), source_range_regex)
2948*da0073e9SAndroid Build Coastguard Worker
2949*da0073e9SAndroid Build Coastguard Worker
class TestFrontend(JitTestCase):
    """Tests for TorchScript frontend errors and for tracing with keyword
    example inputs."""

    def test_instancing_error(self):
        """Instantiating an ``@torch.jit.ignore``-d class inside scripted code
        raises a ``FrontendError`` that mentions the problem and ``def forward``."""
        # `unscriptable` is deliberately ill-typed (str + int).
        @torch.jit.ignore
        class MyScriptClass:
            def unscriptable(self):
                return "a" + 200


        class TestModule(torch.nn.Module):
            def forward(self, x):
                return MyScriptClass()

        with self.assertRaises(torch.jit.frontend.FrontendError) as cm:
            torch.jit.script(TestModule())

        checker = FileCheck()
        checker.check("Cannot instantiate class")
        checker.check("def forward")
        checker.run(str(cm.exception))

    def test_dictionary_as_example_inputs_for_jit_trace(self):
        """Tracing with a dict of ``example_kwarg_inputs`` works for modules
        (``trace`` and ``trace_module``) and plain functions. A traced module
        called without a required keyword argument reports which one is missing."""
        class TestModule_v1(torch.nn.Module):
            def forward(self, key2=None, key3=None, key4=None, key5=None, key1=None, key6=None):
                return key1 + key2 + key3

        class TestModule_v2(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        def test_func(x, y):
            return x + y
        model_1 = TestModule_v1()
        model_2 = TestModule_v2()
        value1 = torch.ones(1)
        value2 = torch.ones(1)
        value3 = torch.ones(1)
        # The key order differs from forward()'s parameter order, and key4-key6
        # are not supplied.
        example_input_dict = {'key1': value1, 'key2': value2, 'key3': value3}
        example_input_dict_func = {'x': value1, 'y': value2}
        traced_model_1 = torch.jit.trace(model_1, example_kwarg_inputs=example_input_dict, strict=False)
        traced_model_1_m = torch.jit.trace_module(
            model_1, {'forward': example_input_dict}, example_inputs_is_kwarg=True, strict=False)
        traced_model_2 = torch.jit.trace(model_2, example_kwarg_inputs={'x': torch.rand([2]), 'y': torch.rand([2])})
        traced_func = torch.jit.trace(test_func, example_kwarg_inputs=example_input_dict_func, strict=False)
        res_1 = traced_model_1(**example_input_dict)
        res_1_m = traced_model_1_m(**example_input_dict)
        self.assertEqual(res_1, 3 * torch.ones(1))
        self.assertEqual(res_1_m, 3 * torch.ones(1))
        res_func = traced_func(**example_input_dict_func)
        self.assertEqual(res_func, 2 * torch.ones(1))
        # Passing an unknown keyword ('z') instead of a required one must name
        # the missing argument. The trailing dot is escaped to match literally.
        with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'x'\."):
            traced_model_2(z=torch.rand([2]), y=torch.rand([2]))
        with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'y'\."):
            traced_model_2(x=torch.rand([2]), z=torch.rand([2]))
3004*da0073e9SAndroid Build Coastguard Worker
3005*da0073e9SAndroid Build Coastguard Worker
3006*da0073e9SAndroid Build Coastguard Workerclass TestScript(JitTestCase):
3007*da0073e9SAndroid Build Coastguard Worker
    # Tests that calling torch.jit.script repeatedly on a function is allowed.
3009*da0073e9SAndroid Build Coastguard Worker    def test_repeated_script_on_function(self):
3010*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
3011*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
3012*da0073e9SAndroid Build Coastguard Worker        def fn(x):
3013*da0073e9SAndroid Build Coastguard Worker            return x
3014*da0073e9SAndroid Build Coastguard Worker
3015*da0073e9SAndroid Build Coastguard Worker        torch.jit.script(torch.jit.script(fn))
3016*da0073e9SAndroid Build Coastguard Worker
3017*da0073e9SAndroid Build Coastguard Worker    def test_pretty_print_function(self):
3018*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
3019*da0073e9SAndroid Build Coastguard Worker        def foo(x):
3020*da0073e9SAndroid Build Coastguard Worker            return torch.nn.functional.interpolate(x)
3021*da0073e9SAndroid Build Coastguard Worker
3022*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("interpolate").run(foo.code)
3023*da0073e9SAndroid Build Coastguard Worker
3024*da0073e9SAndroid Build Coastguard Worker    def test_inlined_graph(self):
3025*da0073e9SAndroid Build Coastguard Worker        """
3026*da0073e9SAndroid Build Coastguard Worker        Check that the `inlined_graph` property correctly returns an inlined
3027*da0073e9SAndroid Build Coastguard Worker        graph, both through function calls and method calls.
3028*da0073e9SAndroid Build Coastguard Worker        """
3029*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
3030*da0073e9SAndroid Build Coastguard Worker        def foo(x):
3031*da0073e9SAndroid Build Coastguard Worker            return torch.add(x, x)
3032*da0073e9SAndroid Build Coastguard Worker
3033*da0073e9SAndroid Build Coastguard Worker        class MyNestedMod(torch.nn.Module):
3034*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
3035*da0073e9SAndroid Build Coastguard Worker                return torch.sub(x, x)
3036*da0073e9SAndroid Build Coastguard Worker
3037*da0073e9SAndroid Build Coastguard Worker
3038*da0073e9SAndroid Build Coastguard Worker        class MyMod(torch.nn.Module):
3039*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
3040*da0073e9SAndroid Build Coastguard Worker                super().__init__()
3041*da0073e9SAndroid Build Coastguard Worker                self.nested = MyNestedMod()
3042*da0073e9SAndroid Build Coastguard Worker
3043*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
3044*da0073e9SAndroid Build Coastguard Worker                x = self.nested(x)  # sub
3045*da0073e9SAndroid Build Coastguard Worker                x = foo(x)  # add
3046*da0073e9SAndroid Build Coastguard Worker                return torch.mul(x, x)
3047*da0073e9SAndroid Build Coastguard Worker
3048*da0073e9SAndroid Build Coastguard Worker        m = torch.jit.script(MyMod())
3049*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::sub") \
3050*da0073e9SAndroid Build Coastguard Worker            .check("aten::add") \
3051*da0073e9SAndroid Build Coastguard Worker            .check("aten::mul") \
3052*da0073e9SAndroid Build Coastguard Worker            .run(m.inlined_graph)
3053*da0073e9SAndroid Build Coastguard Worker
3054*da0073e9SAndroid Build Coastguard Worker    def test_static_method_on_module(self):
3055*da0073e9SAndroid Build Coastguard Worker        """
3056*da0073e9SAndroid Build Coastguard Worker        Check that the `@staticmethod` annotation on a function on a module works.
3057*da0073e9SAndroid Build Coastguard Worker        """
3058*da0073e9SAndroid Build Coastguard Worker        class MyCell(torch.nn.Module):
3059*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3060*da0073e9SAndroid Build Coastguard Worker            def do_it(x, h):
3061*da0073e9SAndroid Build Coastguard Worker                new_h = torch.tanh(x + h)
3062*da0073e9SAndroid Build Coastguard Worker                return new_h, new_h
3063*da0073e9SAndroid Build Coastguard Worker
3064*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, h):
3065*da0073e9SAndroid Build Coastguard Worker                return self.do_it(x, h)
3066*da0073e9SAndroid Build Coastguard Worker
3067*da0073e9SAndroid Build Coastguard Worker        my_cell = torch.jit.script(MyCell())
3068*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
3069*da0073e9SAndroid Build Coastguard Worker        h = torch.rand(3, 4)
3070*da0073e9SAndroid Build Coastguard Worker        jitted_cell = my_cell(x, h)
3071*da0073e9SAndroid Build Coastguard Worker        non_jitted_cell = MyCell().do_it(x, h)
3072*da0073e9SAndroid Build Coastguard Worker
3073*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(jitted_cell, non_jitted_cell)
3074*da0073e9SAndroid Build Coastguard Worker
3075*da0073e9SAndroid Build Coastguard Worker    def test_code_with_constants(self):
3076*da0073e9SAndroid Build Coastguard Worker        """
3077*da0073e9SAndroid Build Coastguard Worker        Check that the `code_with_constants` property correctly returns graph CONSTANTS in the
3078*da0073e9SAndroid Build Coastguard Worker        CONSTANTS.cN format used in the output of the `code` property.
3079*da0073e9SAndroid Build Coastguard Worker        """
3080*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
3081*da0073e9SAndroid Build Coastguard Worker        def foo(x=torch.ones(1)):
3082*da0073e9SAndroid Build Coastguard Worker            return x
3083*da0073e9SAndroid Build Coastguard Worker
3084*da0073e9SAndroid Build Coastguard Worker        class Moddy(torch.nn.Module):
3085*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
3086*da0073e9SAndroid Build Coastguard Worker                return foo()
3087*da0073e9SAndroid Build Coastguard Worker
3088*da0073e9SAndroid Build Coastguard Worker        m = torch.jit.script(Moddy())
3089*da0073e9SAndroid Build Coastguard Worker        src, CONSTANTS = m.code_with_constants
3090*da0073e9SAndroid Build Coastguard Worker
3091*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(CONSTANTS.c0, torch.ones(1))
3092*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(src, m.code)
3093*da0073e9SAndroid Build Coastguard Worker
3094*da0073e9SAndroid Build Coastguard Worker    def test_code_with_constants_restore(self):
3095*da0073e9SAndroid Build Coastguard Worker        """
3096*da0073e9SAndroid Build Coastguard Worker        Check that the `code_with_constants` property correctly works on restoration after save() + load()
3097*da0073e9SAndroid Build Coastguard Worker        """
3098*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
3099*da0073e9SAndroid Build Coastguard Worker        def foo(x=torch.ones(1)):
3100*da0073e9SAndroid Build Coastguard Worker            return x
3101*da0073e9SAndroid Build Coastguard Worker
3102*da0073e9SAndroid Build Coastguard Worker        class Moddy(torch.nn.Module):
3103*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
3104*da0073e9SAndroid Build Coastguard Worker                return foo()
3105*da0073e9SAndroid Build Coastguard Worker
3106*da0073e9SAndroid Build Coastguard Worker        m = torch.jit.script(Moddy())
3107*da0073e9SAndroid Build Coastguard Worker        src, CONSTANTS = m.code_with_constants
3108*da0073e9SAndroid Build Coastguard Worker        eic = self.getExportImportCopy(m)
3109*da0073e9SAndroid Build Coastguard Worker
3110*da0073e9SAndroid Build Coastguard Worker        src_eic, CONSTANTS_eic = eic.code_with_constants
3111*da0073e9SAndroid Build Coastguard Worker
3112*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(src, src_eic)
3113*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(CONSTANTS.c0, CONSTANTS_eic.c0)
3114*da0073e9SAndroid Build Coastguard Worker
3115*da0073e9SAndroid Build Coastguard Worker
3116*da0073e9SAndroid Build Coastguard Worker    def test_oneline_func(self):
3117*da0073e9SAndroid Build Coastguard Worker        def fn(x): return x  # noqa: E704
3118*da0073e9SAndroid Build Coastguard Worker
3119*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (torch.ones(2, 2), ))
3120*da0073e9SAndroid Build Coastguard Worker
3121*da0073e9SAndroid Build Coastguard Worker    def test_request_bailout(self):
3122*da0073e9SAndroid Build Coastguard Worker        with enable_profiling_mode_for_profiling_tests():
3123*da0073e9SAndroid Build Coastguard Worker
3124*da0073e9SAndroid Build Coastguard Worker            def fct_loop(x):
3125*da0073e9SAndroid Build Coastguard Worker                for i in range(3):
3126*da0073e9SAndroid Build Coastguard Worker                    x = torch.cat((x, x), 0)
3127*da0073e9SAndroid Build Coastguard Worker                return x
3128*da0073e9SAndroid Build Coastguard Worker
3129*da0073e9SAndroid Build Coastguard Worker            x = torch.ones(2, 3, 4, dtype=torch.float32)
3130*da0073e9SAndroid Build Coastguard Worker            expected = fct_loop(x)
3131*da0073e9SAndroid Build Coastguard Worker            jitted = torch.jit.script(fct_loop)
3132*da0073e9SAndroid Build Coastguard Worker            # profile
3133*da0073e9SAndroid Build Coastguard Worker            jitted(x)
3134*da0073e9SAndroid Build Coastguard Worker            # optimize
3135*da0073e9SAndroid Build Coastguard Worker            jitted(x)
3136*da0073e9SAndroid Build Coastguard Worker            dstate = jitted.get_debug_state()
3137*da0073e9SAndroid Build Coastguard Worker            eplan = get_execution_plan(dstate)
3138*da0073e9SAndroid Build Coastguard Worker            num_bailouts = eplan.code.num_bailouts()
3139*da0073e9SAndroid Build Coastguard Worker
3140*da0073e9SAndroid Build Coastguard Worker            for i in range(0, num_bailouts):
3141*da0073e9SAndroid Build Coastguard Worker                eplan.code.request_bailout(i)
3142*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(jitted(x), expected)
3143*da0073e9SAndroid Build Coastguard Worker
    @unittest.skip("bailouts are being deprecated")
    def test_dominated_bailout(self):
        """Check how many ``prim::BailOut`` guards the optimized graph keeps.

        Three scenarios are covered under the profiling executor:
        a guard dominated by an earlier, purely functional guard; a guard on a
        value mutated in place (``add_``); and a guard that follows a call to
        an ignored function that toggles grad mode.
        """
        with enable_profiling_mode_for_profiling_tests():
            # functional dominated guard
            @torch.jit.script
            def foo(x):
                dim = x.dim()
                if dim == 0:
                    y = int(x)
                else:
                    y = x.size()[dim - 1]
                return y

            x = torch.zeros(2)
            # Called twice: presumably the first run profiles and the second
            # runs the optimized graph that last_executed_optimized_graph returns.
            self.assertEqual(foo(x), 2)
            self.assertEqual(foo(x), 2)
            g = torch.jit.last_executed_optimized_graph()
            g_s = str(g)
            # Only inspect the graph text up to the first "return".
            g_s = g_s[0:g_s.find("return")]
            # Exactly one BailOut is expected in that prefix.
            FileCheck().check_count("prim::BailOut[", 1, exactly=True).run(g_s)

            # dominated guard of non-functional value
            @torch.jit.script
            def foo(x):
                dim = x.dim()
                x.add_(3)
                if dim == 0:
                    return 0
                else:
                    return x.size()[dim - 1]

            x = torch.zeros(2)
            self.assertEqual(foo(x), 2)
            self.assertEqual(foo(x), 2)
            g = torch.jit.last_executed_optimized_graph()
            # The in-place add_ must be followed directly by a fresh BailOut.
            FileCheck().check("prim::BailOut[").check("aten::add_").check_next("prim::BailOut[").check("return").run(g)

            with torch.enable_grad():
                # Python-only helpers (ignored by the compiler) that flip the
                # global grad mode from inside the scripted function.
                @torch.jit.ignore
                def disable_grad():
                    torch.set_grad_enabled(False)

                @torch.jit.ignore
                def enable_grad():
                    torch.set_grad_enabled(True)

                @torch.jit.script
                def foo(x):
                    x = x + 1
                    dim = x.dim()
                    disable_grad()
                    if dim == 0:
                        y = int(x)
                    else:
                        y = x.size()[dim - 1]
                    enable_grad()
                    return y

                x = torch.zeros(2, requires_grad=True)
                self.assertEqual(foo(x), 2)
                self.assertEqual(foo(x), 2)
                g = torch.jit.last_executed_optimized_graph()
                # there should still be a Bailout after disable_grad call
                FileCheck().check("disable_grad").check("BailOut[").check("BailoutTemplate").run(g)
3208*da0073e9SAndroid Build Coastguard Worker
3209*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Torchdynamo cannot correctly handle profiler.profile calls")
3210*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "skip if profiling isn't enabled")
3211*da0073e9SAndroid Build Coastguard Worker    def test_profiling_merge(self):
3212*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
3213*da0073e9SAndroid Build Coastguard Worker        def test_not_const(x):
3214*da0073e9SAndroid Build Coastguard Worker            if x.size(0) == 1:
3215*da0073e9SAndroid Build Coastguard Worker                return 1
3216*da0073e9SAndroid Build Coastguard Worker            else:
3217*da0073e9SAndroid Build Coastguard Worker                return 2
3218*da0073e9SAndroid Build Coastguard Worker
3219*da0073e9SAndroid Build Coastguard Worker        with enable_profiling_mode_for_profiling_tests():
3220*da0073e9SAndroid Build Coastguard Worker            with num_profiled_runs(2):
3221*da0073e9SAndroid Build Coastguard Worker                test_not_const(torch.rand([1, 2]))
3222*da0073e9SAndroid Build Coastguard Worker                test_not_const(torch.rand([2, 2]))
3223*da0073e9SAndroid Build Coastguard Worker
3224*da0073e9SAndroid Build Coastguard Worker                graph_str = torch.jit.last_executed_optimized_graph()
3225*da0073e9SAndroid Build Coastguard Worker                FileCheck().check("profiled_type=Float(*, 2, strides=[2, 1], requires_grad=0, device=cpu").run(graph_str)
3226*da0073e9SAndroid Build Coastguard Worker                FileCheck().check_not("profiled_type=Float(1, 2, strides=[2, 1], requires_grad=0, device=cpu").run(graph_str)
3227*da0073e9SAndroid Build Coastguard Worker
3228*da0073e9SAndroid Build Coastguard Worker
3229*da0073e9SAndroid Build Coastguard Worker    def test_nested_bailouts(self):
3230*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
3231*da0073e9SAndroid Build Coastguard Worker        def fct_loop(x):
3232*da0073e9SAndroid Build Coastguard Worker            for i in range(3):
3233*da0073e9SAndroid Build Coastguard Worker                x = torch.cat((x, x), 0)
3234*da0073e9SAndroid Build Coastguard Worker            return x
3235*da0073e9SAndroid Build Coastguard Worker
3236*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(2, 3, 4, dtype=torch.float32)
3237*da0073e9SAndroid Build Coastguard Worker        out = fct_loop(x)
3238*da0073e9SAndroid Build Coastguard Worker        jit_trace = torch.jit.trace(fct_loop, x)
3239*da0073e9SAndroid Build Coastguard Worker        out_trace = jit_trace(x)
3240*da0073e9SAndroid Build Coastguard Worker
3241*da0073e9SAndroid Build Coastguard Worker    def test_no_self_arg_ignore_function(self):
3242*da0073e9SAndroid Build Coastguard Worker        class MyModule(nn.Module):
3243*da0073e9SAndroid Build Coastguard Worker            @torch.jit.ignore  # noqa: B902
3244*da0073e9SAndroid Build Coastguard Worker            def call_np():  # noqa: B902
3245*da0073e9SAndroid Build Coastguard Worker                # type: () -> int
3246*da0073e9SAndroid Build Coastguard Worker                return np.random.choice(2, p=[.95, .05])
3247*da0073e9SAndroid Build Coastguard Worker
3248*da0073e9SAndroid Build Coastguard Worker            def forward(self):
3249*da0073e9SAndroid Build Coastguard Worker                return self.call_np()
3250*da0073e9SAndroid Build Coastguard Worker
3251*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(Exception, "does not have a self argument"):
3252*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(MyModule())
3253*da0073e9SAndroid Build Coastguard Worker
3254*da0073e9SAndroid Build Coastguard Worker    def test_loop_liveness(self):
3255*da0073e9SAndroid Build Coastguard Worker        with enable_profiling_mode_for_profiling_tests():
3256*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
3257*da0073e9SAndroid Build Coastguard Worker            def f(i):
3258*da0073e9SAndroid Build Coastguard Worker                # type: (int) -> Tensor
3259*da0073e9SAndroid Build Coastguard Worker                l = []
3260*da0073e9SAndroid Build Coastguard Worker                for n in [2, 1]:
3261*da0073e9SAndroid Build Coastguard Worker                    l.append(torch.zeros(n, i))
3262*da0073e9SAndroid Build Coastguard Worker
3263*da0073e9SAndroid Build Coastguard Worker                return l[0]
3264*da0073e9SAndroid Build Coastguard Worker
3265*da0073e9SAndroid Build Coastguard Worker            f(2)
3266*da0073e9SAndroid Build Coastguard Worker            f(1)
3267*da0073e9SAndroid Build Coastguard Worker
3268*da0073e9SAndroid Build Coastguard Worker    def test_bailout_loop_carried_deps_name_clash(self):
3269*da0073e9SAndroid Build Coastguard Worker        with enable_profiling_mode_for_profiling_tests():
3270*da0073e9SAndroid Build Coastguard Worker            NUM_ITERATIONS = 10
3271*da0073e9SAndroid Build Coastguard Worker
3272*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
3273*da0073e9SAndroid Build Coastguard Worker            def fct_loop(z, size):
3274*da0073e9SAndroid Build Coastguard Worker                # type: (int, int) -> Tuple[Tensor, List[int]]
3275*da0073e9SAndroid Build Coastguard Worker                counters = torch.jit.annotate(List[int], [])
3276*da0073e9SAndroid Build Coastguard Worker                j = 0
3277*da0073e9SAndroid Build Coastguard Worker                y = torch.ones(2)
3278*da0073e9SAndroid Build Coastguard Worker                for i in range(size):
3279*da0073e9SAndroid Build Coastguard Worker                    counters.append(i + j)
3280*da0073e9SAndroid Build Coastguard Worker                    y = torch.cat((y, torch.ones(z)), 0)
3281*da0073e9SAndroid Build Coastguard Worker                    j = j + 1
3282*da0073e9SAndroid Build Coastguard Worker                return y, counters
3283*da0073e9SAndroid Build Coastguard Worker
3284*da0073e9SAndroid Build Coastguard Worker            inputs = [1, 2, 3, 4]
3285*da0073e9SAndroid Build Coastguard Worker            expected = [x * 2 for x in range(NUM_ITERATIONS)]
3286*da0073e9SAndroid Build Coastguard Worker            for inp in inputs:
3287*da0073e9SAndroid Build Coastguard Worker                results = fct_loop(inp, NUM_ITERATIONS)
3288*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(results[1], expected)
3289*da0073e9SAndroid Build Coastguard Worker
3290*da0073e9SAndroid Build Coastguard Worker    def test_bailout_loop_counter_transition(self):
3291*da0073e9SAndroid Build Coastguard Worker        with enable_profiling_mode_for_profiling_tests():
3292*da0073e9SAndroid Build Coastguard Worker            NUM_ITERATIONS = 10
3293*da0073e9SAndroid Build Coastguard Worker
3294*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
3295*da0073e9SAndroid Build Coastguard Worker            def fct_loop(z, size):
3296*da0073e9SAndroid Build Coastguard Worker                # type: (int, int) -> Tuple[Tensor, List[int]]
3297*da0073e9SAndroid Build Coastguard Worker                counters = torch.jit.annotate(List[int], [])
3298*da0073e9SAndroid Build Coastguard Worker                y = torch.ones(2)
3299*da0073e9SAndroid Build Coastguard Worker                for i in range(size):
3300*da0073e9SAndroid Build Coastguard Worker                    counters.append(i)
3301*da0073e9SAndroid Build Coastguard Worker                    y = torch.cat((y, torch.ones(z)), 0)
3302*da0073e9SAndroid Build Coastguard Worker                return y, counters
3303*da0073e9SAndroid Build Coastguard Worker
3304*da0073e9SAndroid Build Coastguard Worker            inputs = [1, 2, 3, 4]
3305*da0073e9SAndroid Build Coastguard Worker            expected = list(range(NUM_ITERATIONS))
3306*da0073e9SAndroid Build Coastguard Worker            for inp in inputs:
3307*da0073e9SAndroid Build Coastguard Worker                results = fct_loop(inp, NUM_ITERATIONS)
3308*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(results[1], expected)
3309*da0073e9SAndroid Build Coastguard Worker
3310*da0073e9SAndroid Build Coastguard Worker    def test_ignored_method_binding(self):
3311*da0073e9SAndroid Build Coastguard Worker        class Bar(torch.nn.Module):
3312*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
3313*da0073e9SAndroid Build Coastguard Worker                super().__init__()
3314*da0073e9SAndroid Build Coastguard Worker                self.x : int = 0
3315*da0073e9SAndroid Build Coastguard Worker
3316*da0073e9SAndroid Build Coastguard Worker            @torch.jit.export
3317*da0073e9SAndroid Build Coastguard Worker            def setx(self, x : int):
3318*da0073e9SAndroid Build Coastguard Worker                self.x = x
3319*da0073e9SAndroid Build Coastguard Worker
3320*da0073e9SAndroid Build Coastguard Worker            @torch.jit.export
3321*da0073e9SAndroid Build Coastguard Worker            def getx(self):
3322*da0073e9SAndroid Build Coastguard Worker                return self.x
3323*da0073e9SAndroid Build Coastguard Worker
3324*da0073e9SAndroid Build Coastguard Worker            @torch.jit.ignore
3325*da0073e9SAndroid Build Coastguard Worker            def ignored_getx(self):
3326*da0073e9SAndroid Build Coastguard Worker                return self.x
3327*da0073e9SAndroid Build Coastguard Worker
3328*da0073e9SAndroid Build Coastguard Worker        b = Bar()
3329*da0073e9SAndroid Build Coastguard Worker        b.setx(123)
3330*da0073e9SAndroid Build Coastguard Worker        sb = torch.jit.script(b)
3331*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(sb.getx(), 123)
3332*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(sb.ignored_getx(), 123)
3333*da0073e9SAndroid Build Coastguard Worker
3334*da0073e9SAndroid Build Coastguard Worker        sb.setx(456)
3335*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(sb.getx(), 456)
3336*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(sb.ignored_getx(), 456)
3337*da0073e9SAndroid Build Coastguard Worker
3338*da0073e9SAndroid Build Coastguard Worker    def test_set_attribute_through_optional(self):
3339*da0073e9SAndroid Build Coastguard Worker        class A(torch.nn.Module):
3340*da0073e9SAndroid Build Coastguard Worker            __annotations__ = {"x": Optional[torch.Tensor]}
3341*da0073e9SAndroid Build Coastguard Worker
3342*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
3343*da0073e9SAndroid Build Coastguard Worker                super().__init__()
3344*da0073e9SAndroid Build Coastguard Worker                self.x = None
3345*da0073e9SAndroid Build Coastguard Worker
3346*da0073e9SAndroid Build Coastguard Worker            @torch.jit.ignore
3347*da0073e9SAndroid Build Coastguard Worker            def foo(self):
3348*da0073e9SAndroid Build Coastguard Worker                if self.x is None:
3349*da0073e9SAndroid Build Coastguard Worker                    self.x = torch.tensor([3])
3350*da0073e9SAndroid Build Coastguard Worker                return self.x
3351*da0073e9SAndroid Build Coastguard Worker
3352*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
3353*da0073e9SAndroid Build Coastguard Worker                a = self.foo()
3354*da0073e9SAndroid Build Coastguard Worker                return x + 1
3355*da0073e9SAndroid Build Coastguard Worker
3356*da0073e9SAndroid Build Coastguard Worker        m = torch.jit.script(A())
3357*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m.x, None)
3358*da0073e9SAndroid Build Coastguard Worker        m(torch.rand(1))
3359*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m.x, torch.tensor([3]))
3360*da0073e9SAndroid Build Coastguard Worker
3361*da0073e9SAndroid Build Coastguard Worker    def test_mutate_constant(self):
3362*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
3363*da0073e9SAndroid Build Coastguard Worker            __constants__ = ["foo"]
3364*da0073e9SAndroid Build Coastguard Worker
3365*da0073e9SAndroid Build Coastguard Worker            def __init__(self, foo):
3366*da0073e9SAndroid Build Coastguard Worker                super().__init__()
3367*da0073e9SAndroid Build Coastguard Worker                self.foo = foo
3368*da0073e9SAndroid Build Coastguard Worker
3369*da0073e9SAndroid Build Coastguard Worker        m = M(5)
3370*da0073e9SAndroid Build Coastguard Worker        # m has a constant attribute, but we can't
3371*da0073e9SAndroid Build Coastguard Worker        # assign to it
3372*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
3373*da0073e9SAndroid Build Coastguard Worker            m.foo = 6
3374*da0073e9SAndroid Build Coastguard Worker
3375*da0073e9SAndroid Build Coastguard Worker    def test_class_attribute(self):
3376*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
3377*da0073e9SAndroid Build Coastguard Worker            FOO = 0
3378*da0073e9SAndroid Build Coastguard Worker
3379*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
3380*da0073e9SAndroid Build Coastguard Worker                super().__init__()
3381*da0073e9SAndroid Build Coastguard Worker                self.foo = self.FOO
3382*da0073e9SAndroid Build Coastguard Worker        m = M()
3383*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m.foo, M.FOO)
3384*da0073e9SAndroid Build Coastguard Worker
3385*da0073e9SAndroid Build Coastguard Worker    def test_class_attribute_in_script(self):
3386*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
3387*da0073e9SAndroid Build Coastguard Worker            FOO = 0
3388*da0073e9SAndroid Build Coastguard Worker
3389*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
3390*da0073e9SAndroid Build Coastguard Worker            def forward(self):
3391*da0073e9SAndroid Build Coastguard Worker                return self.FOO
3392*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
3393*da0073e9SAndroid Build Coastguard Worker            M()
3394*da0073e9SAndroid Build Coastguard Worker
3395*da0073e9SAndroid Build Coastguard Worker    def test_not_initialized_err(self):
3396*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
3397*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
3398*da0073e9SAndroid Build Coastguard Worker                self.foo = torch.rand(2, 3)
3399*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
3400*da0073e9SAndroid Build Coastguard Worker            M()
3401*da0073e9SAndroid Build Coastguard Worker
3402*da0073e9SAndroid Build Coastguard Worker    def test_attribute_in_init(self):
3403*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
3404*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
3405*da0073e9SAndroid Build Coastguard Worker                super().__init__()
3406*da0073e9SAndroid Build Coastguard Worker                self.foo = torch.jit.Attribute(0.1, float)
3407*da0073e9SAndroid Build Coastguard Worker                # we should be able to use self.foo as a float here
3408*da0073e9SAndroid Build Coastguard Worker                assert 0.0 < self.foo
3409*da0073e9SAndroid Build Coastguard Worker        M()
3410*da0073e9SAndroid Build Coastguard Worker
3411*da0073e9SAndroid Build Coastguard Worker    def test_scriptable_fn_as_attr(self):
3412*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
3413*da0073e9SAndroid Build Coastguard Worker            def __init__(self, fn):
3414*da0073e9SAndroid Build Coastguard Worker                super().__init__()
3415*da0073e9SAndroid Build Coastguard Worker                self.fn = fn
3416*da0073e9SAndroid Build Coastguard Worker
3417*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
3418*da0073e9SAndroid Build Coastguard Worker                return self.fn(x)
3419*da0073e9SAndroid Build Coastguard Worker
3420*da0073e9SAndroid Build Coastguard Worker        m = M(torch.sigmoid)
3421*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(2, 3)
3422*da0073e9SAndroid Build Coastguard Worker        self.checkModule(m, (inp, ))
3423*da0073e9SAndroid Build Coastguard Worker
3424*da0073e9SAndroid Build Coastguard Worker    def test_sequence_parsing(self):
3425*da0073e9SAndroid Build Coastguard Worker        tests = [
3426*da0073e9SAndroid Build Coastguard Worker            ("return [x, x,]", True),
3427*da0073e9SAndroid Build Coastguard Worker            ("return [x x]", "expected ]"),
3428*da0073e9SAndroid Build Coastguard Worker            ("return x, x,", True),
3429*da0073e9SAndroid Build Coastguard Worker            ("return bar(x, x,)", True),
3430*da0073e9SAndroid Build Coastguard Worker            ("return bar()", "Argument x not provided"),
3431*da0073e9SAndroid Build Coastguard Worker            ("for a, b, in x, x,:\n        pass", "List of iterables"),
3432*da0073e9SAndroid Build Coastguard Worker            ("a, b, = x, x,\n    return a + b", True)
3433*da0073e9SAndroid Build Coastguard Worker        ]
3434*da0073e9SAndroid Build Coastguard Worker        for exp, result in tests:
3435*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit()
3436*da0073e9SAndroid Build Coastguard Worker            full = f"""
3437*da0073e9SAndroid Build Coastguard Workerdef bar(x, y):
3438*da0073e9SAndroid Build Coastguard Worker    return x + y
3439*da0073e9SAndroid Build Coastguard Workerdef foo(x):
3440*da0073e9SAndroid Build Coastguard Worker    {exp}
3441*da0073e9SAndroid Build Coastguard Worker            """
3442*da0073e9SAndroid Build Coastguard Worker            if isinstance(result, str):
3443*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, result):
3444*da0073e9SAndroid Build Coastguard Worker                    cu.define(full)
3445*da0073e9SAndroid Build Coastguard Worker            else:
3446*da0073e9SAndroid Build Coastguard Worker                cu.define(full)
3447*da0073e9SAndroid Build Coastguard Worker
    def test_namedtuple_python(self):
        """Python namedtuples work as return types of @torch.jit.unused code.

        Covers a free function called from a scripted function and an unused
        method called from a module's forward; both must compile.
        """
        # Declared global so the scripting resolver can find these names.
        global MyTuple, MyMod  # see [local resolution in python]
        MyTuple = namedtuple('MyTuple', ['a'])

        @torch.jit.unused
        def fn():
            # type: () -> MyTuple
            return MyTuple(1)

        # Only check compilation
        @torch.jit.script
        def fn2():
            # type: () -> MyTuple
            return fn()

        # The compiled graph should refer to the tuple as a NamedTuple type.
        FileCheck().check("NamedTuple").run(fn2.graph)

        class MyMod(torch.nn.Module):
            @torch.jit.unused
            def fn(self):
                # type: () -> MyTuple
                return MyTuple(1)

            def forward(self, x):
                # The else branch is never taken at runtime, but it still has to
                # type-check against the unused method's declared return type.
                if 1 == 1:
                    return MyTuple(torch.rand(2, 3))
                else:
                    return self.fn()

        # shouldn't throw a type error
        torch.jit.script(MyMod())
3479*da0073e9SAndroid Build Coastguard Worker
3480*da0073e9SAndroid Build Coastguard Worker    def test_unused_decorator(self):
3481*da0073e9SAndroid Build Coastguard Worker        class MyMod(torch.nn.Module):
3482*da0073e9SAndroid Build Coastguard Worker            @torch.jit.unused
3483*da0073e9SAndroid Build Coastguard Worker            @torch.no_grad()
3484*da0073e9SAndroid Build Coastguard Worker            def fn(self, x):
3485*da0073e9SAndroid Build Coastguard Worker                # type: (Tensor) -> int
3486*da0073e9SAndroid Build Coastguard Worker                return next(x)  # invalid, but should be ignored
3487*da0073e9SAndroid Build Coastguard Worker
3488*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
3489*da0073e9SAndroid Build Coastguard Worker                return self.fn(x)
3490*da0073e9SAndroid Build Coastguard Worker
3491*da0073e9SAndroid Build Coastguard Worker        torch.jit.script(MyMod())
3492*da0073e9SAndroid Build Coastguard Worker
3493*da0073e9SAndroid Build Coastguard Worker    @_inline_everything
3494*da0073e9SAndroid Build Coastguard Worker    def test_lazy_script(self):
3495*da0073e9SAndroid Build Coastguard Worker        def untraceable(x):
3496*da0073e9SAndroid Build Coastguard Worker            if x.ndim > 2:
3497*da0073e9SAndroid Build Coastguard Worker                print("hello")
3498*da0073e9SAndroid Build Coastguard Worker            else:
3499*da0073e9SAndroid Build Coastguard Worker                print("goodbye")
3500*da0073e9SAndroid Build Coastguard Worker            return x + 2
3501*da0073e9SAndroid Build Coastguard Worker
3502*da0073e9SAndroid Build Coastguard Worker        # Non-working example
3503*da0073e9SAndroid Build Coastguard Worker        def fn(x):
3504*da0073e9SAndroid Build Coastguard Worker            return untraceable(x)
3505*da0073e9SAndroid Build Coastguard Worker
3506*da0073e9SAndroid Build Coastguard Worker        with self.capture_stdout():
3507*da0073e9SAndroid Build Coastguard Worker            traced_bad = torch.jit.trace(fn, [torch.ones(2, 2)])
3508*da0073e9SAndroid Build Coastguard Worker
3509*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("goodbye").check_not("hello").run(traced_bad.graph)
3510*da0073e9SAndroid Build Coastguard Worker
3511*da0073e9SAndroid Build Coastguard Worker        # Working example
3512*da0073e9SAndroid Build Coastguard Worker        untraceable = torch.jit.script_if_tracing(untraceable)
3513*da0073e9SAndroid Build Coastguard Worker
3514*da0073e9SAndroid Build Coastguard Worker        def fn2(x):
3515*da0073e9SAndroid Build Coastguard Worker            return untraceable(x)
3516*da0073e9SAndroid Build Coastguard Worker
3517*da0073e9SAndroid Build Coastguard Worker        with self.capture_stdout():
3518*da0073e9SAndroid Build Coastguard Worker            traced = torch.jit.trace(fn, [torch.ones(2, 2)])
3519*da0073e9SAndroid Build Coastguard Worker
3520*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("goodbye").run(traced.graph)
3521*da0073e9SAndroid Build Coastguard Worker
3522*da0073e9SAndroid Build Coastguard Worker        def foo(x: int):
3523*da0073e9SAndroid Build Coastguard Worker            return x + 1
3524*da0073e9SAndroid Build Coastguard Worker
3525*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script_if_tracing
3526*da0073e9SAndroid Build Coastguard Worker        def fee(x: int = 2):
3527*da0073e9SAndroid Build Coastguard Worker            return foo(1) + x
3528*da0073e9SAndroid Build Coastguard Worker
3529*da0073e9SAndroid Build Coastguard Worker        # test directly compiling function
3530*da0073e9SAndroid Build Coastguard Worker        fee_compiled = torch.jit.script(fee)
3531*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fee_compiled(), fee())
3532*da0073e9SAndroid Build Coastguard Worker
3533*da0073e9SAndroid Build Coastguard Worker        # test compiling it within another function
3534*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
3535*da0073e9SAndroid Build Coastguard Worker        def hum():
3536*da0073e9SAndroid Build Coastguard Worker            return fee(x=3)
3537*da0073e9SAndroid Build Coastguard Worker
3538*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(hum(), 5)
3539*da0073e9SAndroid Build Coastguard Worker
def test_big_int_literals(self):
    """Integer literals must fit in a signed 64-bit int to be scripted.

    INT64_MAX compiles and matches eager; INT64_MAX + 1 and a much larger
    literal are both rejected with an "out of range" RuntimeError.
    """
    def ok():
        # signed 64 bit max
        a = 9223372036854775807
        return a

    def toobig():
        a = 9223372036854775808
        return a

    def waytoobig():
        a = 99999999999999999999
        return a

    self.checkScript(ok, [])

    # Each overflowing literal is checked independently, in the same order.
    for overflowing_fn in (toobig, waytoobig):
        with self.assertRaisesRegex(RuntimeError, "out of range"):
            torch.jit.script(overflowing_fn)
3561*da0073e9SAndroid Build Coastguard Worker
def test_hex_literals(self):
    """Hexadecimal integer literals compile and are range-checked.

    Covers lower-case digits, upper-case digits and a negated literal, then
    checks that the signed 64-bit boundary is enforced for hex spellings.
    """
    def test1():
        return 0xaaaaaa

    def test2():
        # Same value as test1, spelled with upper-case hex digits so the two
        # cases are not identical (the digits avoid 'E' on purpose, since an
        # 'e'/'E' in a numeric token can be mistaken for an exponent).
        return 0xAAAAAA

    def test3():
        return -0xaaaaaa

    self.checkScript(test1, [])
    self.checkScript(test2, [])
    self.checkScript(test3, [])

    def ok():
        a = 0x7FFFFFFFFFFFFFFF
        return a

    def toobig():
        a = 0xFFFFFFFFFFFFFFFF
        return a

    def waytoobig():
        a = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF
        return a

    self.checkScript(ok, [])

    with self.assertRaisesRegex(RuntimeError, "out of range"):
        torch.jit.script(toobig)

    with self.assertRaisesRegex(RuntimeError, "out of range"):
        torch.jit.script(waytoobig)
3595*da0073e9SAndroid Build Coastguard Worker
3596*da0073e9SAndroid Build Coastguard Worker    def test_big_float_literals(self):
3597*da0073e9SAndroid Build Coastguard Worker        def ok():
3598*da0073e9SAndroid Build Coastguard Worker            # Python interprets this as inf
3599*da0073e9SAndroid Build Coastguard Worker            a = 1.2E400
3600*da0073e9SAndroid Build Coastguard Worker            return a
3601*da0073e9SAndroid Build Coastguard Worker
3602*da0073e9SAndroid Build Coastguard Worker        def check(fn):
3603*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(fn() == ok())
3604*da0073e9SAndroid Build Coastguard Worker
3605*da0073e9SAndroid Build Coastguard Worker        # checkScript doesn't work since assertEqual doesn't consider
3606*da0073e9SAndroid Build Coastguard Worker        # `inf` == `inf`
3607*da0073e9SAndroid Build Coastguard Worker        check(torch.jit.script(ok))
3608*da0073e9SAndroid Build Coastguard Worker
3609*da0073e9SAndroid Build Coastguard Worker        cu = torch.jit.CompilationUnit()
3610*da0073e9SAndroid Build Coastguard Worker        cu.define(dedent(inspect.getsource(ok)))
3611*da0073e9SAndroid Build Coastguard Worker        check(cu.ok)
3612*da0073e9SAndroid Build Coastguard Worker
def _test_device_type(self, dest):
    """Script a fn returning ``(device.type, device.index)`` and check it.

    ``dest`` is passed to ``Tensor.to`` to obtain the device under test;
    the check itself is delegated to ``checkScript``.
    """
    def fn(x):
        # type: (Device) -> Tuple[str, Optional[int]]
        return x.type, x.index

    self.checkScript(fn, [torch.ones(2).to(dest).device])
3620*da0073e9SAndroid Build Coastguard Worker
def test_device_type(self):
    """Run the device type/index scripting check on the CPU."""
    self._test_device_type(dest='cpu')
3623*da0073e9SAndroid Build Coastguard Worker
@unittest.skipIf(not RUN_CUDA, "Requires CUDA")
def test_device_type_cuda(self):
    """Run the device type/index scripting check on CUDA (skipped without it)."""
    self._test_device_type(dest='cuda')
3627*da0073e9SAndroid Build Coastguard Worker
3628*da0073e9SAndroid Build Coastguard Worker    def test_string_device_implicit_conversion(self):
3629*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
3630*da0073e9SAndroid Build Coastguard Worker        def fn(x: torch.device):
3631*da0073e9SAndroid Build Coastguard Worker            return x
3632*da0073e9SAndroid Build Coastguard Worker
3633*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn("cpu"), torch.device("cpu"))
3634*da0073e9SAndroid Build Coastguard Worker
3635*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Expected one of"):
3636*da0073e9SAndroid Build Coastguard Worker            fn("invalid_device")
3637*da0073e9SAndroid Build Coastguard Worker
3638*da0073e9SAndroid Build Coastguard Worker    def test_eval_python(self):
3639*da0073e9SAndroid Build Coastguard Worker        def _test(m):
3640*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(m(torch.ones(2, 2)))
3641*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(m.training)
3642*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(m._c.getattr('training'))
3643*da0073e9SAndroid Build Coastguard Worker
3644*da0073e9SAndroid Build Coastguard Worker            m.eval()
3645*da0073e9SAndroid Build Coastguard Worker
3646*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(m.training)
3647*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(m._c.getattr('training'))
3648*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(m(torch.ones(2, 2)))
3649*da0073e9SAndroid Build Coastguard Worker
3650*da0073e9SAndroid Build Coastguard Worker            buffer = io.BytesIO()
3651*da0073e9SAndroid Build Coastguard Worker            torch.jit.save(m, buffer)
3652*da0073e9SAndroid Build Coastguard Worker            buffer.seek(0)
3653*da0073e9SAndroid Build Coastguard Worker
3654*da0073e9SAndroid Build Coastguard Worker            loaded = torch.jit.load(buffer)
3655*da0073e9SAndroid Build Coastguard Worker
3656*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(loaded.training)
3657*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(loaded._c.getattr('training'))
3658*da0073e9SAndroid Build Coastguard Worker
3659*da0073e9SAndroid Build Coastguard Worker        class M(nn.Module):
3660*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
3661*da0073e9SAndroid Build Coastguard Worker                return self.training
3662*da0073e9SAndroid Build Coastguard Worker
3663*da0073e9SAndroid Build Coastguard Worker        class OldM(torch.jit.ScriptModule):
3664*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
3665*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
3666*da0073e9SAndroid Build Coastguard Worker                return self.training
3667*da0073e9SAndroid Build Coastguard Worker
3668*da0073e9SAndroid Build Coastguard Worker        _test(torch.jit.script(M()))
3669*da0073e9SAndroid Build Coastguard Worker        _test(OldM())
3670*da0073e9SAndroid Build Coastguard Worker
def test_inherit_method(self):
    """Script methods are inherited; a missing callee fails at construction.

    ``A.forward`` calls ``self.bar`` which only subclasses define, and with
    multiple inheritance the class listed first supplies ``bar``.
    """
    class A(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            return x + self.bar(x)

    class B(A):
        @torch.jit.script_method
        def bar(self, x):
            return x * x

    # A never defines bar, so instantiating it must fail.
    with self.assertRaisesRegex(RuntimeError, 'attribute'):
        A()

    sample = torch.rand(3, 4)
    self.assertEqual(B()(sample), sample + sample * sample)

    class C(torch.jit.ScriptModule):
        @torch.jit.script_method
        def bar(self, x):
            return x

    class D(C, B):
        def __init__(self) -> None:
            super().__init__()

    # C comes first in D's bases, so its identity bar is the one used.
    self.assertEqual(D()(sample), sample + sample)
3699*da0073e9SAndroid Build Coastguard Worker
def test_tensor_subclasses(self):
    """Legacy typed-tensor annotations (``torch.LongTensor`` etc.) script.

    Each annotation is compiled from a source template and run on a matching
    tensor. Annotating a real function with a Tensor subclass must emit a
    warning that TorchScript treats it as a plain ``Tensor``.
    """
    def check_subclass(x, tensor):
        template = dedent("""
            def func(input: {}) -> {}:
                return torch.zeros((input.shape[0], 1), dtype=input.dtype)
            """)

        self._check_code(template.format(x, x), "func", [tensor])

    check_subclass("torch.LongTensor", torch.LongTensor([[1, 2], [3, 4]]))
    check_subclass("torch.DoubleTensor", torch.DoubleTensor([[1.2, 2.3], [3.4, 4.5]]))
    check_subclass("torch.IntTensor", torch.IntTensor([[1, 2], [3, 4]]))
    check_subclass("torch.BoolTensor", torch.BoolTensor([[False, True], [True, False]]))

    def check_subclass_warn(input: torch.LongTensor) -> torch.LongTensor:
        return torch.zeros((input.shape[0], 1), dtype=input.dtype)

    with warnings.catch_warnings(record=True) as warns:
        # Force recording: an inherited "ignore"/"once" filter would otherwise
        # leave ``warns`` empty and the check below would fail spuriously.
        warnings.simplefilter("always")
        torch.jit.script(check_subclass_warn)
    # Other warnings may be recorded as well, so search all of them instead
    # of assuming the relevant one is first (an empty list now fails the
    # FileCheck with a clear message rather than an IndexError).
    recorded = "\n".join(str(w.message) for w in warns)
    FileCheck().check("TorchScript will treat type annotations of Tensor").run(recorded)
3720*da0073e9SAndroid Build Coastguard Worker
def test_first_class_module(self):
    """A script method can rebind a module attribute to its argument."""
    class Foo(torch.jit.ScriptModule):
        def __init__(self) -> None:
            super().__init__()
            self.foo = nn.Parameter(torch.rand(3, 4))

        @torch.jit.script_method
        def forward(self, input):
            self.foo = input
            return self.foo

    module = Foo()
    replacement = torch.rand(3, 4)
    module.forward(replacement)
    # After forward, the attribute must hold the tensor that was passed in.
    self.assertEqual(replacement, module.foo)
3735*da0073e9SAndroid Build Coastguard Worker
@_tmp_donotuse_dont_inline_everything
def test_first_class_calls(self):
    """Script classes can be constructed and called from nested script fns."""
    @torch.jit.script
    class Foo:
        def __init__(self, x):
            self.bar = x

        def stuff(self, x):
            return self.bar + x

    @torch.jit.script
    def foo(x):
        return x * x + Foo(x).stuff(2 * x)

    @torch.jit.script
    def bar(x):
        return foo(x) * foo(x)

    x = torch.rand(3, 4)
    actual = bar(x)
    # foo(x) = x*x + (x + 2*x) = x*x + 3*x, and bar squares it.
    inner = x * x + 3 * x
    self.assertEqual(actual, inner * inner)
3756*da0073e9SAndroid Build Coastguard Worker
def test_static_methods(self):
    """Static methods on ``nn.Module`` subclasses are callable when scripted.

    ``N.forward`` also calls a static method of a different class (``M``).
    """
    class M(nn.Module):
        @staticmethod
        def my_method(x):
            return x + 100

        def forward(self, x):
            return x + M.my_method(x)

    class N(nn.Module):
        @staticmethod
        def my_method(x):
            return x * 100

        def forward(self, x):
            return x - M.my_method(x) + N.my_method(x)

    # Each module gets its own freshly created input tuple.
    for module_cls in (M, N):
        self.checkModule(module_cls(), (torch.ones(2, 2),))
3777*da0073e9SAndroid Build Coastguard Worker
def test_invalid_prefix_annotation(self):
    """Malformed type-comment prefixes are rejected with a descriptive error.

    Covers a missing space after ``#``, whitespace before the colon and
    extra whitespace between ``#`` and ``type``. Output printed during
    scripting is captured so it does not leak into the test log.
    """
    with self.assertRaisesRegex(RuntimeError, "annotation prefix in line"), self.capture_stdout():
        @torch.jit.script
        def invalid_prefix_annotation1(a):
            #type: (Int) -> Int # noqa: E265
            return a + 2

    with self.assertRaisesRegex(RuntimeError, "annotation prefix in line"), self.capture_stdout():
        @torch.jit.script
        def invalid_prefix_annotation2(a):
            #type   : (Int) -> Int # noqa: E265
            return a + 2

    with self.assertRaisesRegex(RuntimeError, "annotation prefix in line"), self.capture_stdout():
        @torch.jit.script
        def invalid_prefix_annotation3(a):
            #     type: (Int) -> Int
            return a + 2
3799*da0073e9SAndroid Build Coastguard Worker
def test_builtin_function_attributes(self):
    """A builtin op stored as a module attribute can be called in ``forward``."""
    class Add(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # Keep a reference to the builtin rather than calling torch.add directly.
            self.add = torch.add

        def forward(self, input):
            return self.add(input, input)

    module = Add()
    self.checkModule(module, [torch.randn(2, 2)])
3810*da0073e9SAndroid Build Coastguard Worker
3811*da0073e9SAndroid Build Coastguard Worker    def test_pybind_type_comparisons(self):
3812*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
3813*da0073e9SAndroid Build Coastguard Worker        def f():
3814*da0073e9SAndroid Build Coastguard Worker            return None
3815*da0073e9SAndroid Build Coastguard Worker
3816*da0073e9SAndroid Build Coastguard Worker        node = list(f.graph.nodes())[0]
3817*da0073e9SAndroid Build Coastguard Worker        t = node.outputsAt(0).type()
3818*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(t)
3819*da0073e9SAndroid Build Coastguard Worker
@unittest.skipIf(IS_WINDOWS, 'TODO: need to fix the test case')
def test_unmatched_type_annotation(self):
    """A type comment with the wrong number of arguments is reported.

    Both decorator-style and call-style ``torch.jit.script`` are covered.
    The error must name the count mismatch, and its source excerpt must point
    at the offending ``def`` line.
    """
    arity_msg = re.escape("Number of type annotations (2) did not match the number of function parameters (1):")
    invalid2_excerpt = 'def invalid2\\(a\\):\n\\s*~+\\.*\\s+<--- HERE\n\\s+# type: \\(Int, Int\\) -> Int\n\\s+return a \\+ 2'
    invalid4_excerpt = 'def invalid4\\(a\\):\n\\s*~+\\.*\\s+<--- HERE\n\\s+# type: \\(Int, Int\\) -> Int\n\\s+return a \\+ 2'

    # Decorator form: the error surfaces while the decorator is applied.
    with self.assertRaisesRegex(RuntimeError, arity_msg):
        @torch.jit.script
        def invalid1(a):
            # type: (Int, Int) -> Int
            return a + 2

    with self.assertRaisesRegex(RuntimeError, invalid2_excerpt):
        @torch.jit.script
        def invalid2(a):
            # type: (Int, Int) -> Int
            return a + 2

    # Call form: the error surfaces from the explicit torch.jit.script call.
    with self.assertRaisesRegex(RuntimeError, arity_msg):
        def invalid3(a):
            # type: (Int, Int) -> Int
            return a + 2
        torch.jit.script(invalid3)

    with self.assertRaisesRegex(RuntimeError, invalid4_excerpt):
        def invalid4(a):
            # type: (Int, Int) -> Int
            return a + 2
        torch.jit.script(invalid4)
3848*da0073e9SAndroid Build Coastguard Worker
3849*da0073e9SAndroid Build Coastguard Worker    def test_calls_in_type_annotations(self):
3850*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Type annotation should not contain calls"):
3851*da0073e9SAndroid Build Coastguard Worker            def spooky(a):
3852*da0073e9SAndroid Build Coastguard Worker                # type: print("Hello") -> Tensor # noqa: F723
3853*da0073e9SAndroid Build Coastguard Worker                return a + 2
3854*da0073e9SAndroid Build Coastguard Worker            print(torch.__file__)
3855*da0073e9SAndroid Build Coastguard Worker            torch.jit.annotations.get_signature(spooky, None, 1, True)
3856*da0073e9SAndroid Build Coastguard Worker
3857*da0073e9SAndroid Build Coastguard Worker    def test_is_optional(self):
3858*da0073e9SAndroid Build Coastguard Worker        ann = Union[List[int], List[float]]
3859*da0073e9SAndroid Build Coastguard Worker        torch._jit_internal.is_optional(ann)
3860*da0073e9SAndroid Build Coastguard Worker
def test_interpreter_fuzz(self):
    """Fuzz the interpreter's stack handling with random expression programs.

    Each iteration generates a random straight-line function ``f`` built
    from ``torch.rand``, ``+``, unary ``-``, ``*`` and ``torch.tanh``. It is
    run once via plain ``exec`` and once through
    ``torch.jit.CompilationUnit``, each inside ``freeze_rng_state()``, and
    the two results are compared. An assert in the interpreter's stack
    manipulation code ensures individual operators are not reordered.
    """
    import builtins

    templates = [
        "torch.rand(3, 4)",
        "({} + {})",
        "-{}",
        "({} * {})",
        "torch.tanh({})",
        "VAR {}",
    ]

    def gen_code():
        lines = ['def f():']
        # Sub-expressions not yet consumed by a larger expression.
        pending = []
        num_vars = 0

        def take_operand():
            # Pick uniformly among pending expressions and existing variables.
            # A picked expression is consumed (swap with last, then pop); a
            # variable can be referenced any number of times.
            choice = random.randrange(0, len(pending) + num_vars)
            if choice >= len(pending):
                return f'v{choice - len(pending)}'
            picked = pending[choice]
            pending[choice] = pending[-1]
            pending.pop()
            return picked

        for _ in range(50):
            # Redraw until the template needs no more operands than exist.
            while True:
                template = random.choice(templates)
                arity = template.count('{}')
                if arity <= len(pending) + num_vars:
                    break

            if 'VAR' in template:
                lines.append(f'  v{num_vars} = {take_operand()}')
                num_vars += 1
            else:
                operands = [take_operand() for _ in range(arity)]
                pending.append(template.format(*operands))

        returned = ''.join(f'v{k},' for k in range(num_vars))
        lines.append(f'  return ({returned})\n')
        return '\n'.join(lines)

    for _ in range(100):
        namespace = {'torch': torch}
        code = gen_code()
        builtins.exec(code, namespace, None)
        unit = torch.jit.CompilationUnit(code)
        with freeze_rng_state():
            eager_out = namespace['f']()
        with freeze_rng_state():
            script_out = unit.f()
        self.assertEqual(eager_out, script_out)
3919*da0073e9SAndroid Build Coastguard Worker
3920*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
3921*da0073e9SAndroid Build Coastguard Worker    def test_cpp_module_iterator(self):
3922*da0073e9SAndroid Build Coastguard Worker        a = nn.Module()
3923*da0073e9SAndroid Build Coastguard Worker        a.name = 'a'
3924*da0073e9SAndroid Build Coastguard Worker        a.p = nn.Parameter(torch.rand(3, 4))
3925*da0073e9SAndroid Build Coastguard Worker        a.foo = nn.Module()
3926*da0073e9SAndroid Build Coastguard Worker        a.foo.name = 'foo'
3927*da0073e9SAndroid Build Coastguard Worker        a.foo.b = nn.Buffer(torch.rand(1, 1))
3928*da0073e9SAndroid Build Coastguard Worker        a.foo.bar = nn.Module()
3929*da0073e9SAndroid Build Coastguard Worker        a.foo.bar.name = 'bar'
3930*da0073e9SAndroid Build Coastguard Worker        a.foo.bar.an_int = 4
3931*da0073e9SAndroid Build Coastguard Worker        a.another = nn.Module()
3932*da0073e9SAndroid Build Coastguard Worker        a.another.name = 'another'
3933*da0073e9SAndroid Build Coastguard Worker        sa = torch.jit.script(a)
3934*da0073e9SAndroid Build Coastguard Worker        result = torch._C._jit_debug_module_iterators(sa._c)
3935*da0073e9SAndroid Build Coastguard Worker
3936*da0073e9SAndroid Build Coastguard Worker        def replace(e):
3937*da0073e9SAndroid Build Coastguard Worker            if e is a.p:
3938*da0073e9SAndroid Build Coastguard Worker                return 'P'
3939*da0073e9SAndroid Build Coastguard Worker            elif e is a.foo.b:
3940*da0073e9SAndroid Build Coastguard Worker                return 'B'
3941*da0073e9SAndroid Build Coastguard Worker            elif isinstance(e, torch._C.ScriptModule):
3942*da0073e9SAndroid Build Coastguard Worker                return e.getattr('name')
3943*da0073e9SAndroid Build Coastguard Worker
3944*da0073e9SAndroid Build Coastguard Worker            return e
3945*da0073e9SAndroid Build Coastguard Worker        for v in result.values():
3946*da0073e9SAndroid Build Coastguard Worker            for i in range(len(v)):
3947*da0073e9SAndroid Build Coastguard Worker                if isinstance(v[i], tuple):
3948*da0073e9SAndroid Build Coastguard Worker                    n, v2 = v[i]
3949*da0073e9SAndroid Build Coastguard Worker                    v[i] = (n, replace(v2))
3950*da0073e9SAndroid Build Coastguard Worker                else:
3951*da0073e9SAndroid Build Coastguard Worker                    v[i] = replace(v[i])
3952*da0073e9SAndroid Build Coastguard Worker            # module type creation is not deterministic, so we have to sort
3953*da0073e9SAndroid Build Coastguard Worker            # the result
3954*da0073e9SAndroid Build Coastguard Worker            v.sort()
3955*da0073e9SAndroid Build Coastguard Worker        expected = {'buffers': [],
3956*da0073e9SAndroid Build Coastguard Worker                    'buffers_r': ['B'],
3957*da0073e9SAndroid Build Coastguard Worker                    'children': ['another', 'foo'],
3958*da0073e9SAndroid Build Coastguard Worker                    'modules': ['a', 'another', 'bar', 'foo'],
3959*da0073e9SAndroid Build Coastguard Worker                    'named_attributes': [('_is_full_backward_hook', None),
3960*da0073e9SAndroid Build Coastguard Worker                                         ('another', 'another'),
3961*da0073e9SAndroid Build Coastguard Worker                                         ('foo', 'foo'),
3962*da0073e9SAndroid Build Coastguard Worker                                         ('name', 'a'),
3963*da0073e9SAndroid Build Coastguard Worker                                         ('p', 'P'),
3964*da0073e9SAndroid Build Coastguard Worker                                         ('training', True)],
3965*da0073e9SAndroid Build Coastguard Worker                    'named_attributes_r': [('_is_full_backward_hook', None),
3966*da0073e9SAndroid Build Coastguard Worker                                           ('another', 'another'),
3967*da0073e9SAndroid Build Coastguard Worker                                           ('another._is_full_backward_hook', None),
3968*da0073e9SAndroid Build Coastguard Worker                                           ('another.name', 'another'),
3969*da0073e9SAndroid Build Coastguard Worker                                           ('another.training', True),
3970*da0073e9SAndroid Build Coastguard Worker                                           ('foo', 'foo'),
3971*da0073e9SAndroid Build Coastguard Worker                                           ('foo._is_full_backward_hook', None),
3972*da0073e9SAndroid Build Coastguard Worker                                           ('foo.b', 'B'),
3973*da0073e9SAndroid Build Coastguard Worker                                           ('foo.bar', 'bar'),
3974*da0073e9SAndroid Build Coastguard Worker                                           ('foo.bar._is_full_backward_hook', None),
3975*da0073e9SAndroid Build Coastguard Worker                                           ('foo.bar.an_int', 4),
3976*da0073e9SAndroid Build Coastguard Worker                                           ('foo.bar.name', 'bar'),
3977*da0073e9SAndroid Build Coastguard Worker                                           ('foo.bar.training', True),
3978*da0073e9SAndroid Build Coastguard Worker                                           ('foo.name', 'foo'),
3979*da0073e9SAndroid Build Coastguard Worker                                           ('foo.training', True),
3980*da0073e9SAndroid Build Coastguard Worker                                           ('name', 'a'),
3981*da0073e9SAndroid Build Coastguard Worker                                           ('p', 'P'),
3982*da0073e9SAndroid Build Coastguard Worker                                           ('training', True)],
3983*da0073e9SAndroid Build Coastguard Worker                    'named_buffers': [],
3984*da0073e9SAndroid Build Coastguard Worker                    'named_buffers_r': [('foo.b', 'B')],
3985*da0073e9SAndroid Build Coastguard Worker                    'named_children': [('another', 'another'), ('foo', 'foo')],
3986*da0073e9SAndroid Build Coastguard Worker                    'named_modules': [('', 'a'),
3987*da0073e9SAndroid Build Coastguard Worker                                      ('another', 'another'),
3988*da0073e9SAndroid Build Coastguard Worker                                      ('foo', 'foo'),
3989*da0073e9SAndroid Build Coastguard Worker                                      ('foo.bar', 'bar')],
3990*da0073e9SAndroid Build Coastguard Worker                    'named_parameters': [('p', 'P')],
3991*da0073e9SAndroid Build Coastguard Worker                    'named_parameters_r': [('p', 'P')],
3992*da0073e9SAndroid Build Coastguard Worker                    'parameters': ['P'],
3993*da0073e9SAndroid Build Coastguard Worker                    'parameters_r': ['P']}
3994*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, result)
3995*da0073e9SAndroid Build Coastguard Worker
3996*da0073e9SAndroid Build Coastguard Worker    def test_parameter_order(self):
3997*da0073e9SAndroid Build Coastguard Worker        m = nn.Module()
3998*da0073e9SAndroid Build Coastguard Worker        for i, name in enumerate(string.ascii_letters):
3999*da0073e9SAndroid Build Coastguard Worker            setattr(m, name, nn.Parameter(torch.tensor([float(i)])))
4000*da0073e9SAndroid Build Coastguard Worker        ms = torch.jit.script(m)
4001*da0073e9SAndroid Build Coastguard Worker        print(torch.cat(list(m.parameters())))
4002*da0073e9SAndroid Build Coastguard Worker        print(torch.cat(list(ms.parameters())))
4003*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(m.parameters()), list(ms.parameters()))
4004*da0073e9SAndroid Build Coastguard Worker
4005*da0073e9SAndroid Build Coastguard Worker    def test_python_op_builtins(self):
4006*da0073e9SAndroid Build Coastguard Worker        @torch.jit.unused
4007*da0073e9SAndroid Build Coastguard Worker        def fn(x):
4008*da0073e9SAndroid Build Coastguard Worker            # type: (List[int]) -> int
4009*da0073e9SAndroid Build Coastguard Worker            return sum(x)
4010*da0073e9SAndroid Build Coastguard Worker
4011*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
4012*da0073e9SAndroid Build Coastguard Worker        def script_fn(x):
4013*da0073e9SAndroid Build Coastguard Worker            # type: (List[int]) -> int
4014*da0073e9SAndroid Build Coastguard Worker            return fn(x)
4015*da0073e9SAndroid Build Coastguard Worker
4016*da0073e9SAndroid Build Coastguard Worker    def test_submodule_twice(self):
4017*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
4018*da0073e9SAndroid Build Coastguard Worker        def foo(x):
4019*da0073e9SAndroid Build Coastguard Worker            return x * x
4020*da0073e9SAndroid Build Coastguard Worker
4021*da0073e9SAndroid Build Coastguard Worker        class What(torch.jit.ScriptModule):
4022*da0073e9SAndroid Build Coastguard Worker            def __init__(self, x):
4023*da0073e9SAndroid Build Coastguard Worker                super().__init__()
4024*da0073e9SAndroid Build Coastguard Worker                self.foo = x
4025*da0073e9SAndroid Build Coastguard Worker        a = What(foo)
4026*da0073e9SAndroid Build Coastguard Worker        c = What(foo)
4027*da0073e9SAndroid Build Coastguard Worker
4028*da0073e9SAndroid Build Coastguard Worker    def test_training_param(self):
4029*da0073e9SAndroid Build Coastguard Worker        class What(torch.jit.ScriptModule):
4030*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
4031*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
4032*da0073e9SAndroid Build Coastguard Worker                # type: (int) -> int
4033*da0073e9SAndroid Build Coastguard Worker                if self.training:
4034*da0073e9SAndroid Build Coastguard Worker                    r = x
4035*da0073e9SAndroid Build Coastguard Worker                else:
4036*da0073e9SAndroid Build Coastguard Worker                    r = x + 4
4037*da0073e9SAndroid Build Coastguard Worker                # check double use of training
4038*da0073e9SAndroid Build Coastguard Worker                if self.training:
4039*da0073e9SAndroid Build Coastguard Worker                    r = r + 1
4040*da0073e9SAndroid Build Coastguard Worker                return r
4041*da0073e9SAndroid Build Coastguard Worker
4042*da0073e9SAndroid Build Coastguard Worker        w = What()
4043*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(4, w(3))
4044*da0073e9SAndroid Build Coastguard Worker        w.train(False)
4045*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(7, w(3))
4046*da0073e9SAndroid Build Coastguard Worker        self.assertFalse("training" in w.state_dict())
4047*da0073e9SAndroid Build Coastguard Worker
4048*da0073e9SAndroid Build Coastguard Worker    def test_class_as_attribute(self):
4049*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
4050*da0073e9SAndroid Build Coastguard Worker        class Foo321:
4051*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
4052*da0073e9SAndroid Build Coastguard Worker                self.x = 3
4053*da0073e9SAndroid Build Coastguard Worker
4054*da0073e9SAndroid Build Coastguard Worker        class FooBar1234(torch.nn.Module):
4055*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
4056*da0073e9SAndroid Build Coastguard Worker                super().__init__()
4057*da0073e9SAndroid Build Coastguard Worker                self.f = Foo321()
4058*da0073e9SAndroid Build Coastguard Worker
4059*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
4060*da0073e9SAndroid Build Coastguard Worker                return x + self.f.x
4061*da0073e9SAndroid Build Coastguard Worker
4062*da0073e9SAndroid Build Coastguard Worker        scripted = torch.jit.script(FooBar1234())
4063*da0073e9SAndroid Build Coastguard Worker        eic = self.getExportImportCopy(scripted)
4064*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
4065*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted(x), eic(x))
4066*da0073e9SAndroid Build Coastguard Worker
4067*da0073e9SAndroid Build Coastguard Worker    def test_module_str(self):
4068*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
4069*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
4070*da0073e9SAndroid Build Coastguard Worker                return torch.relu(x)
4071*da0073e9SAndroid Build Coastguard Worker
4072*da0073e9SAndroid Build Coastguard Worker        f = torch.jit.script(Foo())
4073*da0073e9SAndroid Build Coastguard Worker
4074*da0073e9SAndroid Build Coastguard Worker        str_f = str(f._c)
4075*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(str_f.startswith('ScriptObject'))
4076*da0073e9SAndroid Build Coastguard Worker        self.assertTrue('__torch__.' in str_f)
4077*da0073e9SAndroid Build Coastguard Worker        self.assertTrue('.Foo' in str_f)
4078*da0073e9SAndroid Build Coastguard Worker
    def test_jitter_bug(self):
        # Compile-only check: both functions below are scripted but never
        # called, so the test passes as long as compilation does not raise.
        # Presumably a regression test for a past compiler bug (per the test
        # name); the original bug is not described here.
        #
        # fn2 binds `_stride` to a fresh list literal in one branch and to the
        # `kernel_size` argument in the other, then prints both lists.
        @torch.jit.script
        def fn2(input, kernel_size):
            # type: (Tensor, List[int]) -> Tensor
            if kernel_size[0] > 1:
                _stride = [2]
            else:
                _stride = kernel_size
            print(_stride, kernel_size)
            return input

        # fn calls fn2, passing a list literal as `kernel_size`.
        @torch.jit.script
        def fn(input):
            # type: (Tensor) -> Tensor
            return fn2(input, [1])
4094*da0073e9SAndroid Build Coastguard Worker
4095*da0073e9SAndroid Build Coastguard Worker    def test_parser_kwargonly(self):
4096*da0073e9SAndroid Build Coastguard Worker        cu = torch.jit.CompilationUnit('''
4097*da0073e9SAndroid Build Coastguard Worker            def foo(x, *, y) -> Tuple[Tensor, Tensor]:
4098*da0073e9SAndroid Build Coastguard Worker                return x, x
4099*da0073e9SAndroid Build Coastguard Worker            def bar(x):
4100*da0073e9SAndroid Build Coastguard Worker                return foo(x, y=x)
4101*da0073e9SAndroid Build Coastguard Worker        ''')
4102*da0073e9SAndroid Build Coastguard Worker        self.assertTrue('*' in str(cu.foo.schema))
4103*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "not provided"):
4104*da0073e9SAndroid Build Coastguard Worker            torch.jit.CompilationUnit('''
4105*da0073e9SAndroid Build Coastguard Worker                def foo(x, *, y) -> Tuple[Tensor, Tensor]:
4106*da0073e9SAndroid Build Coastguard Worker                    return x, x
4107*da0073e9SAndroid Build Coastguard Worker                def bar(x):
4108*da0073e9SAndroid Build Coastguard Worker                    return foo(x, x)
4109*da0073e9SAndroid Build Coastguard Worker            ''')
4110*da0073e9SAndroid Build Coastguard Worker
4111*da0073e9SAndroid Build Coastguard Worker    def test_annoying_doubles(self):
4112*da0073e9SAndroid Build Coastguard Worker        mod = types.ModuleType("temp")
4113*da0073e9SAndroid Build Coastguard Worker        mod.inf = float("inf")
4114*da0073e9SAndroid Build Coastguard Worker        mod.ninf = float("-inf")
4115*da0073e9SAndroid Build Coastguard Worker        mod.nan = float("nan")
4116*da0073e9SAndroid Build Coastguard Worker
4117*da0073e9SAndroid Build Coastguard Worker        with torch._jit_internal._disable_emit_hooks():
4118*da0073e9SAndroid Build Coastguard Worker            class Foo(torch.jit.ScriptModule):
4119*da0073e9SAndroid Build Coastguard Worker                @torch.jit.script_method
4120*da0073e9SAndroid Build Coastguard Worker                def forward(self):
4121*da0073e9SAndroid Build Coastguard Worker                    return math.pi, 0.1, mod.inf, mod.ninf, 2.225073858507201e-308, mod.nan
4122*da0073e9SAndroid Build Coastguard Worker
4123*da0073e9SAndroid Build Coastguard Worker            foo = Foo()
4124*da0073e9SAndroid Build Coastguard Worker            buffer = io.BytesIO()
4125*da0073e9SAndroid Build Coastguard Worker            torch.jit.save(foo, buffer)
4126*da0073e9SAndroid Build Coastguard Worker
4127*da0073e9SAndroid Build Coastguard Worker            buffer.seek(0)
4128*da0073e9SAndroid Build Coastguard Worker            foo_loaded = torch.jit.load(buffer)
4129*da0073e9SAndroid Build Coastguard Worker
4130*da0073e9SAndroid Build Coastguard Worker            r = foo()
4131*da0073e9SAndroid Build Coastguard Worker            r2 = foo_loaded()
4132*da0073e9SAndroid Build Coastguard Worker            # use precise assert, we are checking floating point details
4133*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(r[:-1] == r2[:-1])
4134*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(math.isnan(r[-1]) and math.isnan(r2[-1]))
4135*da0073e9SAndroid Build Coastguard Worker
4136*da0073e9SAndroid Build Coastguard Worker    def test_type_annotate(self):
4137*da0073e9SAndroid Build Coastguard Worker
4138*da0073e9SAndroid Build Coastguard Worker        def foo(a):
4139*da0073e9SAndroid Build Coastguard Worker            return torch.jit.annotate(torch.Tensor, a)
4140*da0073e9SAndroid Build Coastguard Worker
4141*da0073e9SAndroid Build Coastguard Worker        self.checkScript(foo, (torch.rand(3),))
4142*da0073e9SAndroid Build Coastguard Worker
4143*da0073e9SAndroid Build Coastguard Worker        def bar():
4144*da0073e9SAndroid Build Coastguard Worker            a = torch.jit.annotate(List[int], [])
4145*da0073e9SAndroid Build Coastguard Worker            for _ in range(10):
4146*da0073e9SAndroid Build Coastguard Worker                a.append(4)
4147*da0073e9SAndroid Build Coastguard Worker            return a
4148*da0073e9SAndroid Build Coastguard Worker
4149*da0073e9SAndroid Build Coastguard Worker        self.checkScript(bar, ())
4150*da0073e9SAndroid Build Coastguard Worker
4151*da0073e9SAndroid Build Coastguard Worker        def baz(a):
4152*da0073e9SAndroid Build Coastguard Worker            return torch.jit.annotate(float, a)
4153*da0073e9SAndroid Build Coastguard Worker        self.checkScript(baz, (torch.rand(()),))
4154*da0073e9SAndroid Build Coastguard Worker
4155*da0073e9SAndroid Build Coastguard Worker        # test annotate none types
4156*da0073e9SAndroid Build Coastguard Worker        def annotate_none():
4157*da0073e9SAndroid Build Coastguard Worker            return torch.jit.annotate(Optional[torch.Tensor], None)
4158*da0073e9SAndroid Build Coastguard Worker
4159*da0073e9SAndroid Build Coastguard Worker        self.checkScript(annotate_none, ())
4160*da0073e9SAndroid Build Coastguard Worker
4161*da0073e9SAndroid Build Coastguard Worker
4162*da0073e9SAndroid Build Coastguard Worker    def test_robust_op_resolution(self):
4163*da0073e9SAndroid Build Coastguard Worker        neg = torch.add  # misleading name to make sure we resolve by function
4164*da0073e9SAndroid Build Coastguard Worker
4165*da0073e9SAndroid Build Coastguard Worker        def stuff(x):
4166*da0073e9SAndroid Build Coastguard Worker            return neg(x, x)
4167*da0073e9SAndroid Build Coastguard Worker
4168*da0073e9SAndroid Build Coastguard Worker        a = (torch.rand(3),)
4169*da0073e9SAndroid Build Coastguard Worker        self.checkScript(stuff, a)
4170*da0073e9SAndroid Build Coastguard Worker
4171*da0073e9SAndroid Build Coastguard Worker    def test_nested_aug_assign(self):
4172*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
4173*da0073e9SAndroid Build Coastguard Worker        class SomeClass:
4174*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
4175*da0073e9SAndroid Build Coastguard Worker                self.num = 99
4176*da0073e9SAndroid Build Coastguard Worker
4177*da0073e9SAndroid Build Coastguard Worker            def __iadd__(self, x):
4178*da0073e9SAndroid Build Coastguard Worker                # type: (int)
4179*da0073e9SAndroid Build Coastguard Worker                self.num += x
4180*da0073e9SAndroid Build Coastguard Worker                return self
4181*da0073e9SAndroid Build Coastguard Worker
4182*da0073e9SAndroid Build Coastguard Worker            def __eq__(self, other):
4183*da0073e9SAndroid Build Coastguard Worker                # type: (SomeClass) -> bool
4184*da0073e9SAndroid Build Coastguard Worker                return self.num == other.num
4185*da0073e9SAndroid Build Coastguard Worker
4186*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
4187*da0073e9SAndroid Build Coastguard Worker        class SomeOutOfPlaceClass:
4188*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
4189*da0073e9SAndroid Build Coastguard Worker                self.num = 99
4190*da0073e9SAndroid Build Coastguard Worker
4191*da0073e9SAndroid Build Coastguard Worker            def __add__(self, x):
4192*da0073e9SAndroid Build Coastguard Worker                # type: (int)
4193*da0073e9SAndroid Build Coastguard Worker                self.num = x
4194*da0073e9SAndroid Build Coastguard Worker                return self
4195*da0073e9SAndroid Build Coastguard Worker
4196*da0073e9SAndroid Build Coastguard Worker            def __eq__(self, other):
4197*da0073e9SAndroid Build Coastguard Worker                # type: (SomeClass) -> bool
4198*da0073e9SAndroid Build Coastguard Worker                return self.num == other.num
4199*da0073e9SAndroid Build Coastguard Worker
4200*da0073e9SAndroid Build Coastguard Worker        class Child(nn.Module):
4201*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
4202*da0073e9SAndroid Build Coastguard Worker                super().__init__()
4203*da0073e9SAndroid Build Coastguard Worker                self.x = 2
4204*da0073e9SAndroid Build Coastguard Worker                self.o = SomeClass()
4205*da0073e9SAndroid Build Coastguard Worker                self.oop = SomeOutOfPlaceClass()
4206*da0073e9SAndroid Build Coastguard Worker                self.list = [1, 2, 3]
4207*da0073e9SAndroid Build Coastguard Worker
4208*da0073e9SAndroid Build Coastguard Worker        class A(nn.Module):
4209*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
4210*da0073e9SAndroid Build Coastguard Worker                super().__init__()
4211*da0073e9SAndroid Build Coastguard Worker                self.child = Child()
4212*da0073e9SAndroid Build Coastguard Worker
4213*da0073e9SAndroid Build Coastguard Worker            def forward(self):
4214*da0073e9SAndroid Build Coastguard Worker                self.child.x += 1
4215*da0073e9SAndroid Build Coastguard Worker                self.child.o += 5
4216*da0073e9SAndroid Build Coastguard Worker                self.child.oop += 5
4217*da0073e9SAndroid Build Coastguard Worker                some_list = [1, 2]
4218*da0073e9SAndroid Build Coastguard Worker                self.child.list += some_list
4219*da0073e9SAndroid Build Coastguard Worker                self.child.list *= 2
4220*da0073e9SAndroid Build Coastguard Worker                return self.child.x, self.child.o, self.child.list, self.child.oop
4221*da0073e9SAndroid Build Coastguard Worker
4222*da0073e9SAndroid Build Coastguard Worker        a = A()
4223*da0073e9SAndroid Build Coastguard Worker        sa = torch.jit.script(A())
4224*da0073e9SAndroid Build Coastguard Worker        eager_result = a()
4225*da0073e9SAndroid Build Coastguard Worker        script_result = sa()
4226*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(eager_result, script_result)
4227*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.child.x, sa.child.x)
4228*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.child.o, sa.child.o)
4229*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.child.list, sa.child.list)
4230*da0073e9SAndroid Build Coastguard Worker
4231*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
4232*da0073e9SAndroid Build Coastguard Worker        class SomeNonAddableClass:
4233*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
4234*da0073e9SAndroid Build Coastguard Worker                self.num = 99
4235*da0073e9SAndroid Build Coastguard Worker
4236*da0073e9SAndroid Build Coastguard Worker            def __eq__(self, other):
4237*da0073e9SAndroid Build Coastguard Worker                # type: (SomeClass) -> bool
4238*da0073e9SAndroid Build Coastguard Worker                return self.num == other.num
4239*da0073e9SAndroid Build Coastguard Worker
4240*da0073e9SAndroid Build Coastguard Worker        # with self.assertRaisesRegex(RuntimeError, "")
4241*da0073e9SAndroid Build Coastguard Worker        class A(nn.Module):
4242*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
4243*da0073e9SAndroid Build Coastguard Worker                super().__init__()
4244*da0073e9SAndroid Build Coastguard Worker                self.x = SomeNonAddableClass()
4245*da0073e9SAndroid Build Coastguard Worker
4246*da0073e9SAndroid Build Coastguard Worker            def forward(self):
4247*da0073e9SAndroid Build Coastguard Worker                self.x += SomeNonAddableClass()
4248*da0073e9SAndroid Build Coastguard Worker                return self.x
4249*da0073e9SAndroid Build Coastguard Worker
4250*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Cannot emit inplace op"):
4251*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(A())
4252*da0073e9SAndroid Build Coastguard Worker
4253*da0073e9SAndroid Build Coastguard Worker    def test_var_aug_assign(self):
4254*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
4255*da0073e9SAndroid Build Coastguard Worker        class SomeNonAddableClass:
4256*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
4257*da0073e9SAndroid Build Coastguard Worker                self.num = 99
4258*da0073e9SAndroid Build Coastguard Worker
4259*da0073e9SAndroid Build Coastguard Worker            def __eq__(self, other):
4260*da0073e9SAndroid Build Coastguard Worker                # type: (SomeNonAddableClass) -> bool
4261*da0073e9SAndroid Build Coastguard Worker                return self.num == other.num
4262*da0073e9SAndroid Build Coastguard Worker
4263*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Cannot emit inplace op"):
4264*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
4265*da0073e9SAndroid Build Coastguard Worker            def fn():
4266*da0073e9SAndroid Build Coastguard Worker                a = SomeNonAddableClass()
4267*da0073e9SAndroid Build Coastguard Worker                a += SomeNonAddableClass()
4268*da0073e9SAndroid Build Coastguard Worker                return a
4269*da0073e9SAndroid Build Coastguard Worker
4270*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
4271*da0073e9SAndroid Build Coastguard Worker        class SomeClass:
4272*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
4273*da0073e9SAndroid Build Coastguard Worker                self.num = 99
4274*da0073e9SAndroid Build Coastguard Worker
4275*da0073e9SAndroid Build Coastguard Worker            def __iadd__(self, x):
4276*da0073e9SAndroid Build Coastguard Worker                # type: (int)
4277*da0073e9SAndroid Build Coastguard Worker                self.num += x
4278*da0073e9SAndroid Build Coastguard Worker                return self
4279*da0073e9SAndroid Build Coastguard Worker
4280*da0073e9SAndroid Build Coastguard Worker            def __eq__(self, other):
4281*da0073e9SAndroid Build Coastguard Worker                # type: (SomeClass) -> bool
4282*da0073e9SAndroid Build Coastguard Worker                return self.num == other.num
4283*da0073e9SAndroid Build Coastguard Worker
4284*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
4285*da0073e9SAndroid Build Coastguard Worker        class SomeOutOfPlaceClass:
4286*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
4287*da0073e9SAndroid Build Coastguard Worker                self.num = 99
4288*da0073e9SAndroid Build Coastguard Worker
4289*da0073e9SAndroid Build Coastguard Worker            def __add__(self, x):
4290*da0073e9SAndroid Build Coastguard Worker                # type: (int)
4291*da0073e9SAndroid Build Coastguard Worker                self.num = x
4292*da0073e9SAndroid Build Coastguard Worker                return self
4293*da0073e9SAndroid Build Coastguard Worker
4294*da0073e9SAndroid Build Coastguard Worker            def __eq__(self, other):
4295*da0073e9SAndroid Build Coastguard Worker                # type: (SomeClass) -> bool
4296*da0073e9SAndroid Build Coastguard Worker                return self.num == other.num
4297*da0073e9SAndroid Build Coastguard Worker
4298*da0073e9SAndroid Build Coastguard Worker        def fn2():
4299*da0073e9SAndroid Build Coastguard Worker            a = SomeClass()
4300*da0073e9SAndroid Build Coastguard Worker            a_copy = a
4301*da0073e9SAndroid Build Coastguard Worker            a += 20
4302*da0073e9SAndroid Build Coastguard Worker            assert a is a_copy
4303*da0073e9SAndroid Build Coastguard Worker            b = SomeOutOfPlaceClass()
4304*da0073e9SAndroid Build Coastguard Worker            b_copy = b
4305*da0073e9SAndroid Build Coastguard Worker            b += 99
4306*da0073e9SAndroid Build Coastguard Worker            assert b is b_copy
4307*da0073e9SAndroid Build Coastguard Worker            c = [1, 2, 3]
4308*da0073e9SAndroid Build Coastguard Worker            c_copy = c
4309*da0073e9SAndroid Build Coastguard Worker            c *= 2
4310*da0073e9SAndroid Build Coastguard Worker            assert c is c_copy
4311*da0073e9SAndroid Build Coastguard Worker            c += [4, 5, 6]
4312*da0073e9SAndroid Build Coastguard Worker            d = torch.ones(2, 2)
4313*da0073e9SAndroid Build Coastguard Worker            d_copy = d
4314*da0073e9SAndroid Build Coastguard Worker            d += torch.ones(2, 2)
4315*da0073e9SAndroid Build Coastguard Worker            assert d is d_copy
4316*da0073e9SAndroid Build Coastguard Worker            return a, b, c, d
4317*da0073e9SAndroid Build Coastguard Worker
4318*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn2, [])
4319*da0073e9SAndroid Build Coastguard Worker
4320*da0073e9SAndroid Build Coastguard Worker    def test_nested_list_construct(self):
4321*da0073e9SAndroid Build Coastguard Worker        def foo():
4322*da0073e9SAndroid Build Coastguard Worker            return [[4]] + [[4, 5]]
4323*da0073e9SAndroid Build Coastguard Worker        self.checkScript(foo, ())
4324*da0073e9SAndroid Build Coastguard Worker
4325*da0073e9SAndroid Build Coastguard Worker    def test_file_line_error(self):
4326*da0073e9SAndroid Build Coastguard Worker        def foobar(xyz):
4327*da0073e9SAndroid Build Coastguard Worker            return torch.blargh(xyz)
4328*da0073e9SAndroid Build Coastguard Worker
4329*da0073e9SAndroid Build Coastguard Worker        _, lineno = inspect.getsourcelines(foobar)
4330*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, f'test_jit.py", line {lineno + 1}'):
4331*da0073e9SAndroid Build Coastguard Worker            scripted = torch.jit.script(foobar)
4332*da0073e9SAndroid Build Coastguard Worker
4333*da0073e9SAndroid Build Coastguard Worker    def test_file_line_error_class_defn(self):
4334*da0073e9SAndroid Build Coastguard Worker        class FooBar:
4335*da0073e9SAndroid Build Coastguard Worker            def baz(self, xyz):
4336*da0073e9SAndroid Build Coastguard Worker                return torch.blargh(xyz)
4337*da0073e9SAndroid Build Coastguard Worker
4338*da0073e9SAndroid Build Coastguard Worker        _, lineno = inspect.getsourcelines(FooBar)
4339*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, f'test_jit.py", line {lineno + 2}'):
4340*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(FooBar)
4341*da0073e9SAndroid Build Coastguard Worker
4342*da0073e9SAndroid Build Coastguard Worker    def test_file_line_graph(self):
4343*da0073e9SAndroid Build Coastguard Worker        def foobar(xyz):
4344*da0073e9SAndroid Build Coastguard Worker            return torch.neg(xyz)
4345*da0073e9SAndroid Build Coastguard Worker
4346*da0073e9SAndroid Build Coastguard Worker        scripted = torch.jit.script(foobar)
4347*da0073e9SAndroid Build Coastguard Worker
4348*da0073e9SAndroid Build Coastguard Worker        _, lineno = inspect.getsourcelines(foobar)
4349*da0073e9SAndroid Build Coastguard Worker        fc = FileCheck().check(f'test_jit.py:{lineno + 1}:19')
4350*da0073e9SAndroid Build Coastguard Worker        fc.run(scripted.graph)
4351*da0073e9SAndroid Build Coastguard Worker        fc.run(str(scripted.graph))
4352*da0073e9SAndroid Build Coastguard Worker
4353*da0073e9SAndroid Build Coastguard Worker    def test_file_line_save_load(self):
4354*da0073e9SAndroid Build Coastguard Worker        class Scripted(torch.jit.ScriptModule):
4355*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
4356*da0073e9SAndroid Build Coastguard Worker            def forward(self, xyz):
4357*da0073e9SAndroid Build Coastguard Worker                return torch.neg(xyz)
4358*da0073e9SAndroid Build Coastguard Worker
4359*da0073e9SAndroid Build Coastguard Worker        scripted = Scripted()
4360*da0073e9SAndroid Build Coastguard Worker
4361*da0073e9SAndroid Build Coastguard Worker        # NB: not using getExportImportCopy because that takes a different
4362*da0073e9SAndroid Build Coastguard Worker        # code path that calls CompilationUnit._import rather than
4363*da0073e9SAndroid Build Coastguard Worker        # going through the full save/load pathway
4364*da0073e9SAndroid Build Coastguard Worker        buffer = scripted.save_to_buffer()
4365*da0073e9SAndroid Build Coastguard Worker        bytesio = io.BytesIO(buffer)
4366*da0073e9SAndroid Build Coastguard Worker        scripted = torch.jit.load(bytesio)
4367*da0073e9SAndroid Build Coastguard Worker
4368*da0073e9SAndroid Build Coastguard Worker        _, lineno = inspect.getsourcelines(Scripted)
4369*da0073e9SAndroid Build Coastguard Worker        fc = FileCheck().check(f':{lineno + 3}')
4370*da0073e9SAndroid Build Coastguard Worker        fc.run(scripted.graph)
4371*da0073e9SAndroid Build Coastguard Worker        fc.run(str(scripted.graph))
4372*da0073e9SAndroid Build Coastguard Worker
4373*da0073e9SAndroid Build Coastguard Worker    def test_file_line_string(self):
4374*da0073e9SAndroid Build Coastguard Worker        scripted = torch.jit.CompilationUnit('''
4375*da0073e9SAndroid Build Coastguard Workerdef foo(xyz):
4376*da0073e9SAndroid Build Coastguard Worker    return torch.neg(xyz)
4377*da0073e9SAndroid Build Coastguard Worker        ''')
4378*da0073e9SAndroid Build Coastguard Worker
4379*da0073e9SAndroid Build Coastguard Worker        fc = FileCheck().check('<string>:3:11')
4380*da0073e9SAndroid Build Coastguard Worker        fc.run(scripted.foo.graph)
4381*da0073e9SAndroid Build Coastguard Worker        fc.run(str(scripted.foo.graph))
4382*da0073e9SAndroid Build Coastguard Worker
4383*da0073e9SAndroid Build Coastguard Worker    @skipIfCrossRef
4384*da0073e9SAndroid Build Coastguard Worker    def test_file_line_trace(self):
4385*da0073e9SAndroid Build Coastguard Worker        def foobar(xyz):
4386*da0073e9SAndroid Build Coastguard Worker            return torch.neg(xyz)
4387*da0073e9SAndroid Build Coastguard Worker
4388*da0073e9SAndroid Build Coastguard Worker        scripted = torch.jit.trace(foobar, (torch.rand(3, 4)))
4389*da0073e9SAndroid Build Coastguard Worker
4390*da0073e9SAndroid Build Coastguard Worker        _, lineno = inspect.getsourcelines(foobar)
4391*da0073e9SAndroid Build Coastguard Worker        fc = FileCheck().check(f'test_jit.py:{lineno + 1}:0')
4392*da0073e9SAndroid Build Coastguard Worker        fc.run(scripted.graph)
4393*da0073e9SAndroid Build Coastguard Worker        fc.run(str(scripted.graph))
4394*da0073e9SAndroid Build Coastguard Worker
4395*da0073e9SAndroid Build Coastguard Worker    def test_serialized_source_ranges(self):
4396*da0073e9SAndroid Build Coastguard Worker
4397*da0073e9SAndroid Build Coastguard Worker        class FooTest(torch.jit.ScriptModule):
4398*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
4399*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, w):
4400*da0073e9SAndroid Build Coastguard Worker                return torch.mm(x, w.t())
4401*da0073e9SAndroid Build Coastguard Worker
4402*da0073e9SAndroid Build Coastguard Worker        ft = FooTest()
4403*da0073e9SAndroid Build Coastguard Worker        loaded = self.getExportImportCopy(ft)
4404*da0073e9SAndroid Build Coastguard Worker        _, lineno = inspect.getsourcelines(FooTest)
4405*da0073e9SAndroid Build Coastguard Worker
4406*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, f'test_jit.py", line {lineno + 3}'):
4407*da0073e9SAndroid Build Coastguard Worker            loaded(torch.rand(3, 4), torch.rand(30, 40))
4408*da0073e9SAndroid Build Coastguard Worker
4409*da0073e9SAndroid Build Coastguard Worker    def test_serialized_source_ranges_graph(self):
4410*da0073e9SAndroid Build Coastguard Worker
4411*da0073e9SAndroid Build Coastguard Worker        class FooTest3(torch.jit.ScriptModule):
4412*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
4413*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, w):
4414*da0073e9SAndroid Build Coastguard Worker                return torch.mm(x, w.t())
4415*da0073e9SAndroid Build Coastguard Worker
4416*da0073e9SAndroid Build Coastguard Worker        ft = FooTest3()
4417*da0073e9SAndroid Build Coastguard Worker        loaded = self.getExportImportCopy(ft)
4418*da0073e9SAndroid Build Coastguard Worker        _, lineno = inspect.getsourcelines(FooTest3)
4419*da0073e9SAndroid Build Coastguard Worker
4420*da0073e9SAndroid Build Coastguard Worker        fc = FileCheck().check(f'test_jit.py:{lineno + 3}')
4421*da0073e9SAndroid Build Coastguard Worker        fc.run(loaded.graph)
4422*da0073e9SAndroid Build Coastguard Worker
4423*da0073e9SAndroid Build Coastguard Worker    def test_serialized_source_ranges2(self):
4424*da0073e9SAndroid Build Coastguard Worker
4425*da0073e9SAndroid Build Coastguard Worker        class FooTest2(torch.jit.ScriptModule):
4426*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
4427*da0073e9SAndroid Build Coastguard Worker            def forward(self):
4428*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError('foo')
4429*da0073e9SAndroid Build Coastguard Worker
4430*da0073e9SAndroid Build Coastguard Worker        _, lineno = inspect.getsourcelines(FooTest2)
4431*da0073e9SAndroid Build Coastguard Worker
4432*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(torch.jit.Error, f'test_jit.py", line {lineno + 3}'):
4433*da0073e9SAndroid Build Coastguard Worker            ft = FooTest2()
4434*da0073e9SAndroid Build Coastguard Worker            loaded = self.getExportImportCopy(ft)
4435*da0073e9SAndroid Build Coastguard Worker            loaded()
4436*da0073e9SAndroid Build Coastguard Worker
4437*da0073e9SAndroid Build Coastguard Worker    def test_serialized_source_ranges_dont_jitter(self):
4438*da0073e9SAndroid Build Coastguard Worker        class FooTest3(torch.jit.ScriptModule):
4439*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
4440*da0073e9SAndroid Build Coastguard Worker            def forward(self, lim):
4441*da0073e9SAndroid Build Coastguard Worker                first = 1
4442*da0073e9SAndroid Build Coastguard Worker                second = 1
4443*da0073e9SAndroid Build Coastguard Worker                i = 1
4444*da0073e9SAndroid Build Coastguard Worker                somenum = 5
4445*da0073e9SAndroid Build Coastguard Worker                dontmutateme = 3
4446*da0073e9SAndroid Build Coastguard Worker                third = 0
4447*da0073e9SAndroid Build Coastguard Worker                while bool(i < lim):
4448*da0073e9SAndroid Build Coastguard Worker                    third = first + second
4449*da0073e9SAndroid Build Coastguard Worker                    first = second
4450*da0073e9SAndroid Build Coastguard Worker                    second = third
4451*da0073e9SAndroid Build Coastguard Worker                    j = 0
4452*da0073e9SAndroid Build Coastguard Worker                    while j < 10:
4453*da0073e9SAndroid Build Coastguard Worker                        somenum = somenum * 2
4454*da0073e9SAndroid Build Coastguard Worker                        j = j + 1
4455*da0073e9SAndroid Build Coastguard Worker                    i = i + j
4456*da0073e9SAndroid Build Coastguard Worker                    i = i + dontmutateme
4457*da0073e9SAndroid Build Coastguard Worker
4458*da0073e9SAndroid Build Coastguard Worker                st = second + third
4459*da0073e9SAndroid Build Coastguard Worker                fs = first + second
4460*da0073e9SAndroid Build Coastguard Worker                return third, st, fs
4461*da0073e9SAndroid Build Coastguard Worker
4462*da0073e9SAndroid Build Coastguard Worker        ft3 = FooTest3()
4463*da0073e9SAndroid Build Coastguard Worker
4464*da0073e9SAndroid Build Coastguard Worker        def debug_records_from_mod(self, mod):
4465*da0073e9SAndroid Build Coastguard Worker            buffer = io.BytesIO()
4466*da0073e9SAndroid Build Coastguard Worker            torch.jit.save(ft3, buffer)
4467*da0073e9SAndroid Build Coastguard Worker            buffer.seek(0)
4468*da0073e9SAndroid Build Coastguard Worker            archive = zipfile.ZipFile(buffer)
4469*da0073e9SAndroid Build Coastguard Worker            files = filter(lambda x: x.startswith('archive/code/'), archive.namelist())
4470*da0073e9SAndroid Build Coastguard Worker            debug_files = list(filter(lambda f: f.endswith('.debug_pkl'), files))
4471*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(debug_files), 1)
4472*da0073e9SAndroid Build Coastguard Worker            debug_file = archive.open(debug_files[0])
4473*da0073e9SAndroid Build Coastguard Worker            return pickle.load(debug_file), buffer
4474*da0073e9SAndroid Build Coastguard Worker
4475*da0073e9SAndroid Build Coastguard Worker        records1, buffer = debug_records_from_mod(self, ft3)
4476*da0073e9SAndroid Build Coastguard Worker
4477*da0073e9SAndroid Build Coastguard Worker        buffer.seek(0)
4478*da0073e9SAndroid Build Coastguard Worker        loaded = torch.jit.load(buffer)
4479*da0073e9SAndroid Build Coastguard Worker        records2, buffer = debug_records_from_mod(self, loaded)
4480*da0073e9SAndroid Build Coastguard Worker
4481*da0073e9SAndroid Build Coastguard Worker        buffer.seek(0)
4482*da0073e9SAndroid Build Coastguard Worker        loaded2 = torch.jit.load(buffer)
4483*da0073e9SAndroid Build Coastguard Worker        records3, _ = debug_records_from_mod(self, loaded2)
4484*da0073e9SAndroid Build Coastguard Worker
4485*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(records1, records2)
4486*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(records2, records3)
4487*da0073e9SAndroid Build Coastguard Worker
4488*da0073e9SAndroid Build Coastguard Worker    def test_serialized_source_ranges_no_dups(self):
4489*da0073e9SAndroid Build Coastguard Worker        class FooTest3(torch.jit.ScriptModule):
4490*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
4491*da0073e9SAndroid Build Coastguard Worker            def forward(self, lim):
4492*da0073e9SAndroid Build Coastguard Worker                first = 1
4493*da0073e9SAndroid Build Coastguard Worker                second = 1
4494*da0073e9SAndroid Build Coastguard Worker                i = 1
4495*da0073e9SAndroid Build Coastguard Worker                somenum = 5
4496*da0073e9SAndroid Build Coastguard Worker                dontmutateme = 3
4497*da0073e9SAndroid Build Coastguard Worker                third = 0
4498*da0073e9SAndroid Build Coastguard Worker                while bool(i < lim):
4499*da0073e9SAndroid Build Coastguard Worker                    third = first + second
4500*da0073e9SAndroid Build Coastguard Worker                    first = second
4501*da0073e9SAndroid Build Coastguard Worker                    second = third
4502*da0073e9SAndroid Build Coastguard Worker                    j = 0
4503*da0073e9SAndroid Build Coastguard Worker                    while j < 10:
4504*da0073e9SAndroid Build Coastguard Worker                        somenum = somenum * 2
4505*da0073e9SAndroid Build Coastguard Worker                        j = j + 1
4506*da0073e9SAndroid Build Coastguard Worker                    i = i + j
4507*da0073e9SAndroid Build Coastguard Worker                    i = i + dontmutateme
4508*da0073e9SAndroid Build Coastguard Worker
4509*da0073e9SAndroid Build Coastguard Worker                st = second + third
4510*da0073e9SAndroid Build Coastguard Worker                fs = first + second
4511*da0073e9SAndroid Build Coastguard Worker                return third, st, fs
4512*da0073e9SAndroid Build Coastguard Worker
4513*da0073e9SAndroid Build Coastguard Worker        ft3 = FooTest3()
4514*da0073e9SAndroid Build Coastguard Worker
4515*da0073e9SAndroid Build Coastguard Worker        def debug_records_from_mod(mod):
4516*da0073e9SAndroid Build Coastguard Worker            buffer = io.BytesIO()
4517*da0073e9SAndroid Build Coastguard Worker            torch.jit.save(ft3, buffer)
4518*da0073e9SAndroid Build Coastguard Worker            buffer.seek(0)
4519*da0073e9SAndroid Build Coastguard Worker            archive = zipfile.ZipFile(buffer)
4520*da0073e9SAndroid Build Coastguard Worker            files = list(filter(lambda x: x.startswith('archive/code/'), archive.namelist()))
4521*da0073e9SAndroid Build Coastguard Worker            debug_files = filter(lambda f: f.endswith('.debug_pkl'), files)
4522*da0073e9SAndroid Build Coastguard Worker            debug_files = (archive.open(f) for f in debug_files)
4523*da0073e9SAndroid Build Coastguard Worker            debug_files = (pickle.load(f) for f in debug_files)
4524*da0073e9SAndroid Build Coastguard Worker            debug_files = (f[2] for f in debug_files)
4525*da0073e9SAndroid Build Coastguard Worker            return list(debug_files)
4526*da0073e9SAndroid Build Coastguard Worker
4527*da0073e9SAndroid Build Coastguard Worker        debug_files = debug_records_from_mod(ft3)
4528*da0073e9SAndroid Build Coastguard Worker        for debug_file in debug_files:
4529*da0073e9SAndroid Build Coastguard Worker            for i in range(len(debug_file) - 1):
4530*da0073e9SAndroid Build Coastguard Worker                offset, source_range_tag, source_range = debug_file[i]
4531*da0073e9SAndroid Build Coastguard Worker                offset2, source_range_tag2, source_range2 = debug_file[i + 1]
4532*da0073e9SAndroid Build Coastguard Worker                self.assertNotEqual(source_range, source_range2)
4533*da0073e9SAndroid Build Coastguard Worker
4534*da0073e9SAndroid Build Coastguard Worker    def test_circular_dependency(self):
4535*da0073e9SAndroid Build Coastguard Worker        """
4536*da0073e9SAndroid Build Coastguard Worker        https://github.com/pytorch/pytorch/issues/25871
4537*da0073e9SAndroid Build Coastguard Worker        """
4538*da0073e9SAndroid Build Coastguard Worker        class A(torch.jit.ScriptModule):
4539*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
4540*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
4541*da0073e9SAndroid Build Coastguard Worker                return x
4542*da0073e9SAndroid Build Coastguard Worker
4543*da0073e9SAndroid Build Coastguard Worker        class B(torch.jit.ScriptModule):
4544*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
4545*da0073e9SAndroid Build Coastguard Worker                super().__init__()
4546*da0073e9SAndroid Build Coastguard Worker                self.foo = torch.nn.ModuleList([A()])
4547*da0073e9SAndroid Build Coastguard Worker
4548*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
4549*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
4550*da0073e9SAndroid Build Coastguard Worker                for f in self.foo:
4551*da0073e9SAndroid Build Coastguard Worker                    x = f(x)
4552*da0073e9SAndroid Build Coastguard Worker                return x
4553*da0073e9SAndroid Build Coastguard Worker
4554*da0073e9SAndroid Build Coastguard Worker        class C(torch.jit.ScriptModule):
4555*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
4556*da0073e9SAndroid Build Coastguard Worker                super().__init__()
4557*da0073e9SAndroid Build Coastguard Worker                self.foo = torch.nn.Sequential(B())
4558*da0073e9SAndroid Build Coastguard Worker
4559*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
4560*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
4561*da0073e9SAndroid Build Coastguard Worker                for f in self.foo:
4562*da0073e9SAndroid Build Coastguard Worker                    x = f(x)
4563*da0073e9SAndroid Build Coastguard Worker                return x
4564*da0073e9SAndroid Build Coastguard Worker        self.getExportImportCopy(C())
4565*da0073e9SAndroid Build Coastguard Worker
4566*da0073e9SAndroid Build Coastguard Worker    def test_serialize_long_lines(self):
4567*da0073e9SAndroid Build Coastguard Worker        class OrderModuleLong(torch.nn.Module):
4568*da0073e9SAndroid Build Coastguard Worker            def forward(self, long_arg_name: List[torch.Tensor]):
4569*da0073e9SAndroid Build Coastguard Worker                return [(long_arg_name[1],), (long_arg_name[0].argmax(),)]
4570*da0073e9SAndroid Build Coastguard Worker        src = str(torch.jit.script(OrderModuleLong()).code)
4571*da0073e9SAndroid Build Coastguard Worker        # make long_arg_name[1] does not get reordered after the argmax
4572*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("long_arg_name[1]").check("argmax").run(src)
4573*da0073e9SAndroid Build Coastguard Worker
4574*da0073e9SAndroid Build Coastguard Worker    def test_tensor_shape(self):
4575*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(34, 56, 78)
4576*da0073e9SAndroid Build Coastguard Worker
4577*da0073e9SAndroid Build Coastguard Worker        def f(x):
4578*da0073e9SAndroid Build Coastguard Worker            return x.shape
4579*da0073e9SAndroid Build Coastguard Worker
4580*da0073e9SAndroid Build Coastguard Worker        self.checkScript(f, (x,))
4581*da0073e9SAndroid Build Coastguard Worker
4582*da0073e9SAndroid Build Coastguard Worker
4583*da0073e9SAndroid Build Coastguard Worker    def test_block_input_grad_in_loop(self):
4584*da0073e9SAndroid Build Coastguard Worker
4585*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 3, requires_grad=False)
4586*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(3, 3, requires_grad=True)
4587*da0073e9SAndroid Build Coastguard Worker
4588*da0073e9SAndroid Build Coastguard Worker        def grad_in_loop(x, y):
4589*da0073e9SAndroid Build Coastguard Worker            for i in range(100):
4590*da0073e9SAndroid Build Coastguard Worker                x = y @ x
4591*da0073e9SAndroid Build Coastguard Worker            return x
4592*da0073e9SAndroid Build Coastguard Worker
4593*da0073e9SAndroid Build Coastguard Worker        scripted = torch.jit.script(grad_in_loop)
4594*da0073e9SAndroid Build Coastguard Worker        outer = scripted.graph_for(x, y)
4595*da0073e9SAndroid Build Coastguard Worker        loop = outer.findNode("prim::Loop")
4596*da0073e9SAndroid Build Coastguard Worker        loop_block = next(loop.blocks())
4597*da0073e9SAndroid Build Coastguard Worker        param_node = loop_block.paramNode()
4598*da0073e9SAndroid Build Coastguard Worker        x_value = list(param_node.outputs())[1]
4599*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x_value.requires_grad())
4600*da0073e9SAndroid Build Coastguard Worker
4601*da0073e9SAndroid Build Coastguard Worker    def test_tensor_grad(self):
4602*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 4, requires_grad=True)
4603*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(3, 4, requires_grad=False)
4604*da0073e9SAndroid Build Coastguard Worker
4605*da0073e9SAndroid Build Coastguard Worker        def f_requires_grad(x):
4606*da0073e9SAndroid Build Coastguard Worker            return x.requires_grad
4607*da0073e9SAndroid Build Coastguard Worker
4608*da0073e9SAndroid Build Coastguard Worker        self.checkScript(f_requires_grad, (x,))
4609*da0073e9SAndroid Build Coastguard Worker        self.checkScript(f_requires_grad, (y,))
4610*da0073e9SAndroid Build Coastguard Worker
4611*da0073e9SAndroid Build Coastguard Worker        def f_grad(x):
4612*da0073e9SAndroid Build Coastguard Worker            return x.grad
4613*da0073e9SAndroid Build Coastguard Worker
4614*da0073e9SAndroid Build Coastguard Worker        x.sum().backward()
4615*da0073e9SAndroid Build Coastguard Worker        self.checkScript(f_grad, (x,))
4616*da0073e9SAndroid Build Coastguard Worker        self.checkScript(f_grad, (y,))
4617*da0073e9SAndroid Build Coastguard Worker
4618*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "shape analysis is only enabled in Legacy")
4619*da0073e9SAndroid Build Coastguard Worker    def test_prim_grad_undefined(self):
4620*da0073e9SAndroid Build Coastguard Worker
4621*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(2)
4622*da0073e9SAndroid Build Coastguard Worker
4623*da0073e9SAndroid Build Coastguard Worker        def f_grad(x):
4624*da0073e9SAndroid Build Coastguard Worker            return x.grad
4625*da0073e9SAndroid Build Coastguard Worker
4626*da0073e9SAndroid Build Coastguard Worker        scripted = self.checkScript(f_grad, (x,))
4627*da0073e9SAndroid Build Coastguard Worker        g = scripted.graph_for(x)
4628*da0073e9SAndroid Build Coastguard Worker
4629*da0073e9SAndroid Build Coastguard Worker        prim_grad_node = g.findNode("prim::grad")
4630*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(next(prim_grad_node.outputs()).type().undefined() is None)
4631*da0073e9SAndroid Build Coastguard Worker
4632*da0073e9SAndroid Build Coastguard Worker    def test_tensor_data(self):
4633*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 4, requires_grad=True)
4634*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(4, 5)
4635*da0073e9SAndroid Build Coastguard Worker
4636*da0073e9SAndroid Build Coastguard Worker        def f_data(x):
4637*da0073e9SAndroid Build Coastguard Worker            return x.data
4638*da0073e9SAndroid Build Coastguard Worker
4639*da0073e9SAndroid Build Coastguard Worker        scripted_f_data = torch.jit.script(f_data)
4640*da0073e9SAndroid Build Coastguard Worker
4641*da0073e9SAndroid Build Coastguard Worker        scripted_x = scripted_f_data(x)
4642*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_x, f_data(x))
4643*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_x.requires_grad, False)
4644*da0073e9SAndroid Build Coastguard Worker
4645*da0073e9SAndroid Build Coastguard Worker        scripted_y = scripted_f_data(y)
4646*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_y, f_data(y))
4647*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_x.requires_grad, False)
4648*da0073e9SAndroid Build Coastguard Worker
4649*da0073e9SAndroid Build Coastguard Worker    def test_tensor_dtype(self):
4650*da0073e9SAndroid Build Coastguard Worker        x_byte = torch.empty(34, 56, 78, dtype=torch.uint8)
4651*da0073e9SAndroid Build Coastguard Worker        x_long = torch.empty(34, 56, 78, dtype=torch.long)
4652*da0073e9SAndroid Build Coastguard Worker        x_float32 = torch.empty(34, 56, 78, dtype=torch.float32)
4653*da0073e9SAndroid Build Coastguard Worker
4654*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
4655*da0073e9SAndroid Build Coastguard Worker        def byte(x):
4656*da0073e9SAndroid Build Coastguard Worker            return x.dtype == torch.uint8
4657*da0073e9SAndroid Build Coastguard Worker
4658*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
4659*da0073e9SAndroid Build Coastguard Worker        def long(x):
4660*da0073e9SAndroid Build Coastguard Worker            return x.dtype == torch.long
4661*da0073e9SAndroid Build Coastguard Worker
4662*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
4663*da0073e9SAndroid Build Coastguard Worker        def float32(x):
4664*da0073e9SAndroid Build Coastguard Worker            return x.dtype == torch.float32
4665*da0073e9SAndroid Build Coastguard Worker
4666*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(byte(x_byte))
4667*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(byte(x_long))
4668*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(byte(x_float32))
4669*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(long(x_byte))
4670*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(long(x_long))
4671*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(long(x_float32))
4672*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(float32(x_byte))
4673*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(float32(x_long))
4674*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(float32(x_float32))
4675*da0073e9SAndroid Build Coastguard Worker
4676*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "device tests require CUDA")
4677*da0073e9SAndroid Build Coastguard Worker    def test_tensor_device(self):
4678*da0073e9SAndroid Build Coastguard Worker        cpu = torch.empty(34, 56, 78, device='cpu')
4679*da0073e9SAndroid Build Coastguard Worker        gpu = torch.empty(34, 56, 78, device='cuda')
4680*da0073e9SAndroid Build Coastguard Worker
4681*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
4682*da0073e9SAndroid Build Coastguard Worker        def same_device(x, y):
4683*da0073e9SAndroid Build Coastguard Worker            return x.device == y.device
4684*da0073e9SAndroid Build Coastguard Worker
4685*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same_device(cpu, cpu))
4686*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same_device(gpu, gpu))
4687*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(same_device(cpu, gpu))
4688*da0073e9SAndroid Build Coastguard Worker
4689*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "device tests require CUDA")
4690*da0073e9SAndroid Build Coastguard Worker    def test_tensor_to_device(self):
4691*da0073e9SAndroid Build Coastguard Worker        def to_device(x):
4692*da0073e9SAndroid Build Coastguard Worker            return x.to(device="cuda").to(device=torch.device("cpu"))
4693*da0073e9SAndroid Build Coastguard Worker
4694*da0073e9SAndroid Build Coastguard Worker        self.checkScript(to_device, (torch.ones(3, 4),))
4695*da0073e9SAndroid Build Coastguard Worker
4696*da0073e9SAndroid Build Coastguard Worker    def test_tensor_to_cpu(self):
4697*da0073e9SAndroid Build Coastguard Worker        def to_cpu(x):
4698*da0073e9SAndroid Build Coastguard Worker            return x.cpu()
4699*da0073e9SAndroid Build Coastguard Worker
4700*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(3, 4)
4701*da0073e9SAndroid Build Coastguard Worker        script_fn = torch.jit.script(to_cpu)
4702*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(to_cpu(x).device, script_fn(x).device)
4703*da0073e9SAndroid Build Coastguard Worker        self.checkScript(to_cpu, (x,))
4704*da0073e9SAndroid Build Coastguard Worker
4705*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "device tests require CUDA")
4706*da0073e9SAndroid Build Coastguard Worker    def test_tensor_to_cuda(self):
4707*da0073e9SAndroid Build Coastguard Worker        def to_cuda(x):
4708*da0073e9SAndroid Build Coastguard Worker            return x.cuda()
4709*da0073e9SAndroid Build Coastguard Worker
4710*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(3, 4)
4711*da0073e9SAndroid Build Coastguard Worker        script_fn = torch.jit.script(to_cuda)
4712*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(to_cuda(x).device, script_fn(x).device)
4713*da0073e9SAndroid Build Coastguard Worker        self.checkScript(to_cuda, (x,))
4714*da0073e9SAndroid Build Coastguard Worker
4715*da0073e9SAndroid Build Coastguard Worker    def test_generic_list_errors(self):
4716*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "previously matched to type"):
4717*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
4718*da0073e9SAndroid Build Coastguard Worker            def foo(x):
4719*da0073e9SAndroid Build Coastguard Worker                return [[x]] + [[1]]
4720*da0073e9SAndroid Build Coastguard Worker
4721*da0073e9SAndroid Build Coastguard Worker    def test_script_cu(self):
4722*da0073e9SAndroid Build Coastguard Worker        cu = torch.jit.CompilationUnit('''
4723*da0073e9SAndroid Build Coastguard Worker            def foo(a):
4724*da0073e9SAndroid Build Coastguard Worker                b = a
4725*da0073e9SAndroid Build Coastguard Worker                return b
4726*da0073e9SAndroid Build Coastguard Worker        ''')
4727*da0073e9SAndroid Build Coastguard Worker        a = Variable(torch.rand(1))
4728*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a, cu.foo(a))
4729*da0073e9SAndroid Build Coastguard Worker
4730*da0073e9SAndroid Build Coastguard Worker    # because the compilation unit ingests python strings
4731*da0073e9SAndroid Build Coastguard Worker    # to use an escape sequence escape the backslash (\\n = \n)
4732*da0073e9SAndroid Build Coastguard Worker    def test_string_cu(self):
4733*da0073e9SAndroid Build Coastguard Worker        cu = torch.jit.CompilationUnit('''
4734*da0073e9SAndroid Build Coastguard Worker            def foo(a):
4735*da0073e9SAndroid Build Coastguard Worker                print(a, """a\\n\tb\\n""", 2, "a\
4736*da0073e9SAndroid Build Coastguard Workera")
4737*da0073e9SAndroid Build Coastguard Worker                return a
4738*da0073e9SAndroid Build Coastguard Worker        ''')
4739*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aa").check("a\\n\\tb\\n").run(str(cu.foo.graph))
4740*da0073e9SAndroid Build Coastguard Worker
4741*da0073e9SAndroid Build Coastguard Worker    def test_function_compilation_caching(self):
4742*da0073e9SAndroid Build Coastguard Worker        def fun():
4743*da0073e9SAndroid Build Coastguard Worker            return 1 + 2
4744*da0073e9SAndroid Build Coastguard Worker
4745*da0073e9SAndroid Build Coastguard Worker        fun_compiled = torch.jit.script(fun)
4746*da0073e9SAndroid Build Coastguard Worker        # python wrapper around the script function is a different pointer,
4747*da0073e9SAndroid Build Coastguard Worker        # but the underlying script function graph is the same
4748*da0073e9SAndroid Build Coastguard Worker        self.assertIs(fun_compiled.graph, torch.jit.script(fun).graph)
4749*da0073e9SAndroid Build Coastguard Worker
4750*da0073e9SAndroid Build Coastguard Worker        def fun():
4751*da0073e9SAndroid Build Coastguard Worker            return 3 + 4
4752*da0073e9SAndroid Build Coastguard Worker
4753*da0073e9SAndroid Build Coastguard Worker        num_ref_counts = sys.getrefcount(fun)
4754*da0073e9SAndroid Build Coastguard Worker
4755*da0073e9SAndroid Build Coastguard Worker        # caching doesn't get tripped up by same qualname
4756*da0073e9SAndroid Build Coastguard Worker        fun_compiled_2 = torch.jit.script(fun)
4757*da0073e9SAndroid Build Coastguard Worker        self.assertIsNot(fun_compiled, fun_compiled_2)
4758*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fun_compiled_2(), 7)
4759*da0073e9SAndroid Build Coastguard Worker
4760*da0073e9SAndroid Build Coastguard Worker        # caching doesnt increase refcounts to function (holds weak reference)
4761*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(sys.getrefcount(fun), num_ref_counts)
4762*da0073e9SAndroid Build Coastguard Worker
4763*da0073e9SAndroid Build Coastguard Worker    def test_string_ops(self):
4764*da0073e9SAndroid Build Coastguard Worker        def foo():
4765*da0073e9SAndroid Build Coastguard Worker            a = "a" + "b"
4766*da0073e9SAndroid Build Coastguard Worker            return a + a, "ab" == "b", "ab" != "b", "ab" == "ab", "ab" != "ab"
4767*da0073e9SAndroid Build Coastguard Worker
4768*da0073e9SAndroid Build Coastguard Worker        self.checkScript(foo, ())
4769*da0073e9SAndroid Build Coastguard Worker
4770*da0073e9SAndroid Build Coastguard Worker    def test_string_sorted(self):
4771*da0073e9SAndroid Build Coastguard Worker        def foo(strs: List[str]):
4772*da0073e9SAndroid Build Coastguard Worker            return sorted(strs)
4773*da0073e9SAndroid Build Coastguard Worker
4774*da0073e9SAndroid Build Coastguard Worker        FileCheck() \
4775*da0073e9SAndroid Build Coastguard Worker            .check("graph") \
4776*da0073e9SAndroid Build Coastguard Worker            .check_next("str[] = aten::sorted") \
4777*da0073e9SAndroid Build Coastguard Worker            .check_next("return") \
4778*da0073e9SAndroid Build Coastguard Worker            .run(str(torch.jit.script(foo).graph))
4779*da0073e9SAndroid Build Coastguard Worker
4780*da0073e9SAndroid Build Coastguard Worker        inputs = ["str3", "str2", "str1"]
4781*da0073e9SAndroid Build Coastguard Worker        self.checkScript(foo, (inputs,))
4782*da0073e9SAndroid Build Coastguard Worker
4783*da0073e9SAndroid Build Coastguard Worker    def test_string_sort(self):
4784*da0073e9SAndroid Build Coastguard Worker        def foo(strs: List[str]):
4785*da0073e9SAndroid Build Coastguard Worker            strs.sort()
4786*da0073e9SAndroid Build Coastguard Worker            return strs
4787*da0073e9SAndroid Build Coastguard Worker
4788*da0073e9SAndroid Build Coastguard Worker        inputs = ["str3", "str2", "str1"]
4789*da0073e9SAndroid Build Coastguard Worker        self.checkScript(foo, (inputs,))
4790*da0073e9SAndroid Build Coastguard Worker
4791*da0073e9SAndroid Build Coastguard Worker    def test_tuple_sorted(self):
4792*da0073e9SAndroid Build Coastguard Worker        def foo(tups: List[Tuple[int, int]]):
4793*da0073e9SAndroid Build Coastguard Worker            return sorted(tups)
4794*da0073e9SAndroid Build Coastguard Worker
4795*da0073e9SAndroid Build Coastguard Worker        inputs = [(1, 2), (0, 2), (1, 3)]
4796*da0073e9SAndroid Build Coastguard Worker        self.checkScript(foo, (inputs,))
4797*da0073e9SAndroid Build Coastguard Worker
4798*da0073e9SAndroid Build Coastguard Worker    def test_tuple_sort(self):
4799*da0073e9SAndroid Build Coastguard Worker        def foo(tups: List[Tuple[int, int]]):
4800*da0073e9SAndroid Build Coastguard Worker            tups.sort()
4801*da0073e9SAndroid Build Coastguard Worker            return tups
4802*da0073e9SAndroid Build Coastguard Worker
4803*da0073e9SAndroid Build Coastguard Worker        inputs = [(1, 2), (0, 2), (1, 3)]
4804*da0073e9SAndroid Build Coastguard Worker        self.checkScript(foo, (inputs,))
4805*da0073e9SAndroid Build Coastguard Worker
4806*da0073e9SAndroid Build Coastguard Worker    def test_tuple_sort_reverse(self):
4807*da0073e9SAndroid Build Coastguard Worker        def foo(tups: List[Tuple[int, int]]):
4808*da0073e9SAndroid Build Coastguard Worker            tups.sort(reverse=True)
4809*da0073e9SAndroid Build Coastguard Worker            return tups
4810*da0073e9SAndroid Build Coastguard Worker
4811*da0073e9SAndroid Build Coastguard Worker        inputs = [(1, 2), (0, 2), (1, 3)]
4812*da0073e9SAndroid Build Coastguard Worker        self.checkScript(foo, (inputs,))
4813*da0073e9SAndroid Build Coastguard Worker
4814*da0073e9SAndroid Build Coastguard Worker    def test_tuple_unsortable_element_type(self):
4815*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
4816*da0073e9SAndroid Build Coastguard Worker        def foo():
4817*da0073e9SAndroid Build Coastguard Worker            tups = [({1: 2}, {2: 3})]
4818*da0073e9SAndroid Build Coastguard Worker            tups.sort()
4819*da0073e9SAndroid Build Coastguard Worker            return tups
4820*da0073e9SAndroid Build Coastguard Worker
4821*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegexWithHighlight(RuntimeError, "are not sortable", "tups.sort"):
4822*da0073e9SAndroid Build Coastguard Worker            foo()
4823*da0073e9SAndroid Build Coastguard Worker
4824*da0073e9SAndroid Build Coastguard Worker    def test_tuple_unsortable_diff_type(self):
4825*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
4826*da0073e9SAndroid Build Coastguard Worker        def foo(inputs: List[Any]):
4827*da0073e9SAndroid Build Coastguard Worker            inputs.sort()
4828*da0073e9SAndroid Build Coastguard Worker            return inputs
4829*da0073e9SAndroid Build Coastguard Worker
4830*da0073e9SAndroid Build Coastguard Worker        inputs = [(1, 2), ("foo", "bar")]
4831*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegexWithHighlight(RuntimeError, "Only values of same type can be compared", "inputs.sort"):
4832*da0073e9SAndroid Build Coastguard Worker            foo(inputs)
4833*da0073e9SAndroid Build Coastguard Worker
4834*da0073e9SAndroid Build Coastguard Worker    def test_tuple_nested_sort(self):
4835*da0073e9SAndroid Build Coastguard Worker        def foo(inputs: List[Tuple[int, Tuple[int, str]]]):
4836*da0073e9SAndroid Build Coastguard Worker            inputs.sort()
4837*da0073e9SAndroid Build Coastguard Worker            return inputs
4838*da0073e9SAndroid Build Coastguard Worker
4839*da0073e9SAndroid Build Coastguard Worker        inputs = [(1, (2, "foo")), (1, (2, "bar")), (1, (0, "bar"))]
4840*da0073e9SAndroid Build Coastguard Worker        self.checkScript(foo, (inputs,))
4841*da0073e9SAndroid Build Coastguard Worker
4842*da0073e9SAndroid Build Coastguard Worker    def test_tuple_unsortable_nested_diff_type(self):
4843*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
4844*da0073e9SAndroid Build Coastguard Worker        def foo(inputs: List[Any]):
4845*da0073e9SAndroid Build Coastguard Worker            inputs.sort()
4846*da0073e9SAndroid Build Coastguard Worker            return inputs
4847*da0073e9SAndroid Build Coastguard Worker
4848*da0073e9SAndroid Build Coastguard Worker        inputs = [(1, (2, 3)), (2, ("foo", "bar"))]
4849*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegexWithHighlight(RuntimeError, "Only values of same type can be compared", "inputs.sort"):
4850*da0073e9SAndroid Build Coastguard Worker            foo(inputs)
4851*da0073e9SAndroid Build Coastguard Worker
4852*da0073e9SAndroid Build Coastguard Worker    def test_string_new_line(self):
4853*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "expected a valid token*"):
4854*da0073e9SAndroid Build Coastguard Worker            torch.jit.CompilationUnit('''
4855*da0073e9SAndroid Build Coastguard Worker            def test_while(a):
4856*da0073e9SAndroid Build Coastguard Worker                print("
4857*da0073e9SAndroid Build Coastguard Worker                    a")
4858*da0073e9SAndroid Build Coastguard Worker                return a
4859*da0073e9SAndroid Build Coastguard Worker            ''')
4860*da0073e9SAndroid Build Coastguard Worker
4861*da0073e9SAndroid Build Coastguard Worker    def test_string_single_escape(self):
4862*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "expected a valid token*"):
4863*da0073e9SAndroid Build Coastguard Worker            torch.jit.CompilationUnit('''
4864*da0073e9SAndroid Build Coastguard Worker            def test_while(a):
4865*da0073e9SAndroid Build Coastguard Worker                print("\\")
4866*da0073e9SAndroid Build Coastguard Worker                return a
4867*da0073e9SAndroid Build Coastguard Worker            ''')
4868*da0073e9SAndroid Build Coastguard Worker
4869*da0073e9SAndroid Build Coastguard Worker    def test_script_annotation(self):
4870*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
4871*da0073e9SAndroid Build Coastguard Worker        def foo(a):
4872*da0073e9SAndroid Build Coastguard Worker            return a + a + a
4873*da0073e9SAndroid Build Coastguard Worker        s = Variable(torch.rand(2))
4874*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(s + s + s, foo(s))
4875*da0073e9SAndroid Build Coastguard Worker
4876*da0073e9SAndroid Build Coastguard Worker    def test_torch_pow(self):
4877*da0073e9SAndroid Build Coastguard Worker        def func(a, b):
4878*da0073e9SAndroid Build Coastguard Worker            return pow(a, b)
4879*da0073e9SAndroid Build Coastguard Worker
4880*da0073e9SAndroid Build Coastguard Worker        def func2(a, b, c, d):
4881*da0073e9SAndroid Build Coastguard Worker            return pow(pow(c + a, b), d)
4882*da0073e9SAndroid Build Coastguard Worker
4883*da0073e9SAndroid Build Coastguard Worker        def func3(a : int, b : float):
4884*da0073e9SAndroid Build Coastguard Worker            # type: (int, float) -> float
4885*da0073e9SAndroid Build Coastguard Worker            return pow(a, b)
4886*da0073e9SAndroid Build Coastguard Worker
4887*da0073e9SAndroid Build Coastguard Worker        def func4():
4888*da0073e9SAndroid Build Coastguard Worker            # type: () -> float
4889*da0073e9SAndroid Build Coastguard Worker            return pow(2, -2)
4890*da0073e9SAndroid Build Coastguard Worker
4891*da0073e9SAndroid Build Coastguard Worker        def func5(x, y):
4892*da0073e9SAndroid Build Coastguard Worker            return pow(x.item(), y.item())
4893*da0073e9SAndroid Build Coastguard Worker
4894*da0073e9SAndroid Build Coastguard Worker        def func6(a : int, b : int):
4895*da0073e9SAndroid Build Coastguard Worker            # type: (int, int) -> float
4896*da0073e9SAndroid Build Coastguard Worker            return pow(a, b)
4897*da0073e9SAndroid Build Coastguard Worker
4898*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(1)
4899*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(1)
4900*da0073e9SAndroid Build Coastguard Worker        c = torch.rand(1)
4901*da0073e9SAndroid Build Coastguard Worker        d = torch.rand(1)
4902*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, (a, b))
4903*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func2, (a, b, c, d))
4904*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func3, (4, -0.5))
4905*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func4, ())
4906*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func6, (2, 4))
4907*da0073e9SAndroid Build Coastguard Worker
4908*da0073e9SAndroid Build Coastguard Worker        inputs = [torch.tensor(2), torch.tensor(-2), torch.tensor(.5), torch.tensor(.2)]
4909*da0073e9SAndroid Build Coastguard Worker        for x in inputs:
4910*da0073e9SAndroid Build Coastguard Worker            for y in inputs:
4911*da0073e9SAndroid Build Coastguard Worker                if x < 0:
4912*da0073e9SAndroid Build Coastguard Worker                    continue
4913*da0073e9SAndroid Build Coastguard Worker                else:
4914*da0073e9SAndroid Build Coastguard Worker                    self.checkScript(func5, (x, y))
4915*da0073e9SAndroid Build Coastguard Worker
4916*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "device tests require CUDA")
4917*da0073e9SAndroid Build Coastguard Worker    def test_pow_scalar_backward_cuda(self):
4918*da0073e9SAndroid Build Coastguard Worker        # see that scalar exponent works with cuda base (#19253)
4919*da0073e9SAndroid Build Coastguard Worker        with enable_profiling_mode_for_profiling_tests():
4920*da0073e9SAndroid Build Coastguard Worker            for dtype in [torch.float, torch.double]:
4921*da0073e9SAndroid Build Coastguard Worker                @torch.jit.script
4922*da0073e9SAndroid Build Coastguard Worker                def func(a, b):
4923*da0073e9SAndroid Build Coastguard Worker                    # type: (Tensor, float) -> Tensor
4924*da0073e9SAndroid Build Coastguard Worker                    return (a * 2) ** b
4925*da0073e9SAndroid Build Coastguard Worker
4926*da0073e9SAndroid Build Coastguard Worker                a = torch.rand(1, requires_grad=True, device='cuda', dtype=dtype)
4927*da0073e9SAndroid Build Coastguard Worker                func(a, 1, profile_and_replay=True).backward()
4928*da0073e9SAndroid Build Coastguard Worker
4929*da0073e9SAndroid Build Coastguard Worker                @torch.jit.script
4930*da0073e9SAndroid Build Coastguard Worker                def func(a, b):
4931*da0073e9SAndroid Build Coastguard Worker                    # type: (float, Tensor) -> Tensor
4932*da0073e9SAndroid Build Coastguard Worker                    return a ** (b * 2 + 1)
4933*da0073e9SAndroid Build Coastguard Worker
4934*da0073e9SAndroid Build Coastguard Worker                a = torch.rand(1, requires_grad=True, device='cuda', dtype=dtype)
4935*da0073e9SAndroid Build Coastguard Worker                func(2, a, profile_and_replay=True).backward()
4936*da0073e9SAndroid Build Coastguard Worker
4937*da0073e9SAndroid Build Coastguard Worker    def _check_code(self, code_str, fn_name, inputs):
4938*da0073e9SAndroid Build Coastguard Worker        scope = {}
4939*da0073e9SAndroid Build Coastguard Worker        exec(code_str, globals(), scope)
4940*da0073e9SAndroid Build Coastguard Worker        cu = torch.jit.CompilationUnit(code_str)
4941*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cu.func(*inputs), scope[fn_name](*inputs))
4942*da0073e9SAndroid Build Coastguard Worker
4943*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, 'no CUDA')
4944*da0073e9SAndroid Build Coastguard Worker    def test_scriptmodule_releases_tensors_cuda(self):
4945*da0073e9SAndroid Build Coastguard Worker        with enable_profiling_mode_for_profiling_tests():
4946*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
4947*da0073e9SAndroid Build Coastguard Worker            def fn(x, y):
4948*da0073e9SAndroid Build Coastguard Worker                return x.sigmoid() * y.tanh()
4949*da0073e9SAndroid Build Coastguard Worker
4950*da0073e9SAndroid Build Coastguard Worker            def test(backward=False):
4951*da0073e9SAndroid Build Coastguard Worker                x = torch.randn(3, 3, dtype=torch.double, device='cuda', requires_grad=True)
4952*da0073e9SAndroid Build Coastguard Worker                y = torch.randn(3, 3, dtype=torch.double, device='cuda', requires_grad=True)
4953*da0073e9SAndroid Build Coastguard Worker                out = fn(x, y, profile_and_replay=True)
4954*da0073e9SAndroid Build Coastguard Worker                if backward:
4955*da0073e9SAndroid Build Coastguard Worker                    out.sum().backward()
4956*da0073e9SAndroid Build Coastguard Worker
4957*da0073e9SAndroid Build Coastguard Worker            with self.assertLeaksNoCudaTensors():
4958*da0073e9SAndroid Build Coastguard Worker                test()
4959*da0073e9SAndroid Build Coastguard Worker                test()
4960*da0073e9SAndroid Build Coastguard Worker                test()
4961*da0073e9SAndroid Build Coastguard Worker
4962*da0073e9SAndroid Build Coastguard Worker            if GRAPH_EXECUTOR != ProfilingMode.SIMPLE:
4963*da0073e9SAndroid Build Coastguard Worker                with self.assertLeaksNoCudaTensors():
4964*da0073e9SAndroid Build Coastguard Worker                    test(backward=True)
4965*da0073e9SAndroid Build Coastguard Worker                    test(backward=True)
4966*da0073e9SAndroid Build Coastguard Worker                    test(backward=True)
4967*da0073e9SAndroid Build Coastguard Worker
4968*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
4969*da0073e9SAndroid Build Coastguard Worker    def test_index(self):
4970*da0073e9SAndroid Build Coastguard Worker        def consec(size, start=0):
4971*da0073e9SAndroid Build Coastguard Worker            numel = torch.tensor(size).prod().item()
4972*da0073e9SAndroid Build Coastguard Worker            return torch.arange(numel).view(size)
4973*da0073e9SAndroid Build Coastguard Worker
4974*da0073e9SAndroid Build Coastguard Worker        def consec_list(size):
4975*da0073e9SAndroid Build Coastguard Worker            return list(range(size))
4976*da0073e9SAndroid Build Coastguard Worker
4977*da0073e9SAndroid Build Coastguard Worker        def random_string(size):
4978*da0073e9SAndroid Build Coastguard Worker            letters = string.ascii_lowercase
4979*da0073e9SAndroid Build Coastguard Worker            return "".join(random.choice(letters) for i in range(size))
4980*da0073e9SAndroid Build Coastguard Worker
4981*da0073e9SAndroid Build Coastguard Worker        def check_indexing(indexing, tensor):
4982*da0073e9SAndroid Build Coastguard Worker            template = dedent("""
4983*da0073e9SAndroid Build Coastguard Worker            def func(x):
4984*da0073e9SAndroid Build Coastguard Worker                return x{}
4985*da0073e9SAndroid Build Coastguard Worker            """)
4986*da0073e9SAndroid Build Coastguard Worker
4987*da0073e9SAndroid Build Coastguard Worker            self._check_code(template.format(indexing), "func", [tensor])
4988*da0073e9SAndroid Build Coastguard Worker
4989*da0073e9SAndroid Build Coastguard Worker        def check_dynamic_indexing(indexing, tensor, value1, value2):
4990*da0073e9SAndroid Build Coastguard Worker            value1 = torch.tensor(value1)
4991*da0073e9SAndroid Build Coastguard Worker            value2 = torch.tensor(value2)
4992*da0073e9SAndroid Build Coastguard Worker
4993*da0073e9SAndroid Build Coastguard Worker            template = dedent("""
4994*da0073e9SAndroid Build Coastguard Worker            def func(x, value1, value2):
4995*da0073e9SAndroid Build Coastguard Worker                i = int(value1)
4996*da0073e9SAndroid Build Coastguard Worker                j = int(value2)
4997*da0073e9SAndroid Build Coastguard Worker                return x{}
4998*da0073e9SAndroid Build Coastguard Worker            """)
4999*da0073e9SAndroid Build Coastguard Worker
5000*da0073e9SAndroid Build Coastguard Worker            self._check_code(template.format(indexing), "func", [tensor, value1, value2])
5001*da0073e9SAndroid Build Coastguard Worker
5002*da0073e9SAndroid Build Coastguard Worker        # Torchscript assumes type Tensor by default, so we need this explicit
5003*da0073e9SAndroid Build Coastguard Worker        # declaration.
5004*da0073e9SAndroid Build Coastguard Worker        def check_indexing_list_int(indexing, list):
5005*da0073e9SAndroid Build Coastguard Worker            template = dedent("""
5006*da0073e9SAndroid Build Coastguard Worker            def func(x):
5007*da0073e9SAndroid Build Coastguard Worker                # type: (List[int]) -> Any
5008*da0073e9SAndroid Build Coastguard Worker                return x{}
5009*da0073e9SAndroid Build Coastguard Worker            """)
5010*da0073e9SAndroid Build Coastguard Worker
5011*da0073e9SAndroid Build Coastguard Worker            self._check_code(template.format(indexing), "func", [list])
5012*da0073e9SAndroid Build Coastguard Worker
5013*da0073e9SAndroid Build Coastguard Worker        def check_indexing_str(indexing, str):
5014*da0073e9SAndroid Build Coastguard Worker            template = dedent("""
5015*da0073e9SAndroid Build Coastguard Worker            def func(x):
5016*da0073e9SAndroid Build Coastguard Worker                # type: (str) -> Any
5017*da0073e9SAndroid Build Coastguard Worker                return x{}
5018*da0073e9SAndroid Build Coastguard Worker            """)
5019*da0073e9SAndroid Build Coastguard Worker
5020*da0073e9SAndroid Build Coastguard Worker            self._check_code(template.format(indexing), "func", [str])
5021*da0073e9SAndroid Build Coastguard Worker
5022*da0073e9SAndroid Build Coastguard Worker        # basic slices
5023*da0073e9SAndroid Build Coastguard Worker        check_indexing('[0]', consec((3, 3)))
5024*da0073e9SAndroid Build Coastguard Worker        check_indexing('[1]', consec((3, 3), 10))
5025*da0073e9SAndroid Build Coastguard Worker        check_indexing('[2]', consec((3, 3), 19))
5026*da0073e9SAndroid Build Coastguard Worker        check_indexing('[2]', consec((3,)))
5027*da0073e9SAndroid Build Coastguard Worker        check_indexing('[-1]', consec((3, 3), 19))
5028*da0073e9SAndroid Build Coastguard Worker        check_indexing('[0:2]', consec((3, 3, 3)))
5029*da0073e9SAndroid Build Coastguard Worker        check_indexing('[1:-1]', consec((3, 3, 3)))
5030*da0073e9SAndroid Build Coastguard Worker        check_indexing('[-3:-1]', consec((6, 3)))
5031*da0073e9SAndroid Build Coastguard Worker        check_indexing('[1:]', consec((3, 3)))
5032*da0073e9SAndroid Build Coastguard Worker        check_indexing('[:1]', consec((3, 3)))
5033*da0073e9SAndroid Build Coastguard Worker        check_indexing('[:]', consec((3, 2)))
5034*da0073e9SAndroid Build Coastguard Worker
5035*da0073e9SAndroid Build Coastguard Worker        # multi-dim: indexes
5036*da0073e9SAndroid Build Coastguard Worker        check_indexing('[0, 1]', consec((3, 3)))
5037*da0073e9SAndroid Build Coastguard Worker        check_indexing('[0, 1]', consec((3, 3, 2)))
5038*da0073e9SAndroid Build Coastguard Worker        check_indexing('[1, 0, 2]', consec((3, 3, 3)))
5039*da0073e9SAndroid Build Coastguard Worker        check_indexing('[2, -1]', consec((3, 3)))
5040*da0073e9SAndroid Build Coastguard Worker
5041*da0073e9SAndroid Build Coastguard Worker        # multi-dim: mixed slicing and indexing
5042*da0073e9SAndroid Build Coastguard Worker        check_indexing('[0, 1:2]', consec((3, 3)))
5043*da0073e9SAndroid Build Coastguard Worker        check_indexing('[0, :1]', consec((3, 3, 2)))
5044*da0073e9SAndroid Build Coastguard Worker        check_indexing('[1, 2:]', consec((3, 3, 3)))
5045*da0073e9SAndroid Build Coastguard Worker        check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3)))
5046*da0073e9SAndroid Build Coastguard Worker        check_indexing('[1:, -1, 0]', consec((3, 3, 3, 3)))
5047*da0073e9SAndroid Build Coastguard Worker        check_indexing('[-1, 2:, 1:2]', consec((3, 3, 3, 3)))
5048*da0073e9SAndroid Build Coastguard Worker        check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3)))
5049*da0073e9SAndroid Build Coastguard Worker        check_indexing('[-1, :, 0, 2]', consec((3, 3, 3, 3)))
5050*da0073e9SAndroid Build Coastguard Worker
5051*da0073e9SAndroid Build Coastguard Worker        # zero-sized slices
5052*da0073e9SAndroid Build Coastguard Worker        check_indexing('[0:0]', consec((2, 2)))
5053*da0073e9SAndroid Build Coastguard Worker        check_indexing('[0:0, 1]', consec((3, 3)))
5054*da0073e9SAndroid Build Coastguard Worker
5055*da0073e9SAndroid Build Coastguard Worker        # trivial expression usage
5056*da0073e9SAndroid Build Coastguard Worker        check_indexing('[1+1]', consec((3, 3)))
5057*da0073e9SAndroid Build Coastguard Worker        check_indexing('[1:(0 + 2)]', consec((3, 3, 3)))
5058*da0073e9SAndroid Build Coastguard Worker
5059*da0073e9SAndroid Build Coastguard Worker        # None for new dimensions
5060*da0073e9SAndroid Build Coastguard Worker        check_indexing('[None, 0]', consec((3, 3)))
5061*da0073e9SAndroid Build Coastguard Worker        check_indexing('[1, None]', consec((3, 3), 10))
5062*da0073e9SAndroid Build Coastguard Worker        check_indexing('[None, None, 2]', consec((3, 3), 19))
5063*da0073e9SAndroid Build Coastguard Worker        check_indexing('[None, 2, None]', consec((3,)))
5064*da0073e9SAndroid Build Coastguard Worker        check_indexing('[0:2, None]', consec((3, 3, 3)))
5065*da0073e9SAndroid Build Coastguard Worker        check_indexing('[None, 1:-1]', consec((3, 3, 3)))
5066*da0073e9SAndroid Build Coastguard Worker        check_indexing('[None, -3:-1, None]', consec((6, 3)))
5067*da0073e9SAndroid Build Coastguard Worker        check_indexing('[-1, None, 2:, None, 1:2]', consec((3, 3, 3, 3)))
5068*da0073e9SAndroid Build Coastguard Worker        check_indexing('[None, -1, None, 2:, None, 1:2, None]', consec((3, 3, 3, 3)))
5069*da0073e9SAndroid Build Coastguard Worker
5070*da0073e9SAndroid Build Coastguard Worker        # dynamic expression usage
5071*da0073e9SAndroid Build Coastguard Worker        check_dynamic_indexing("[i + j]", consec((3, 3)), 0, 1)
5072*da0073e9SAndroid Build Coastguard Worker        check_dynamic_indexing("[i:j, i]", consec((3, 3, 2)), 0, 2)
5073*da0073e9SAndroid Build Coastguard Worker
5074*da0073e9SAndroid Build Coastguard Worker        # positive striding
5075*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[0]', consec_list(6))
5076*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[1]', consec_list(7))
5077*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[2]', consec_list(8))
5078*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[2]', consec_list(9))
5079*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[-1]', consec_list(10))
5080*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[0:2]', consec_list(11))
5081*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[1:-1]', consec_list(12))
5082*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[-3:-1]', consec_list(13))
5083*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[1:]', consec_list(15))
5084*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[:1]', consec_list(16))
5085*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[:]', consec_list(17))
5086*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[::]', consec_list(0))
5087*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[1000::]', consec_list(0))
5088*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[:1000:]', consec_list(0))
5089*da0073e9SAndroid Build Coastguard Worker
5090*da0073e9SAndroid Build Coastguard Worker        # negative striding
5091*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[::-1]', consec_list(7))
5092*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[:3:-1]', consec_list(7))
5093*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[3::-1]', consec_list(7))
5094*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[1000::-1]', consec_list(7))
5095*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[3:0:-1]', consec_list(7))
5096*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[3:-1000:-1]', consec_list(7))
5097*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[0:0:-1]', consec_list(7))
5098*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[0:-1000:-1]', consec_list(7))
5099*da0073e9SAndroid Build Coastguard Worker
5100*da0073e9SAndroid Build Coastguard Worker        # only step is specified
5101*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[::-1]', consec_list(0))
5102*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[::-1]', consec_list(7))
5103*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[::-2]', consec_list(7))
5104*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[::2]', consec_list(7))
5105*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[::42]', consec_list(7))
5106*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[::-42]', consec_list(7))
5107*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[::42]', consec_list(0))
5108*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[::-42]', consec_list(0))
5109*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[::9223372036854775807]', consec_list(42))
5110*da0073e9SAndroid Build Coastguard Worker        check_indexing_list_int('[::-9223372036854775807]', consec_list(42))
5111*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "out of bounds"):
5112*da0073e9SAndroid Build Coastguard Worker            check_indexing_list_int('[::-9223372036854775808]', consec_list(42))
5113*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "should have non-zero step"):
5114*da0073e9SAndroid Build Coastguard Worker            check_indexing_list_int('[::0]', consec_list(42))
5115*da0073e9SAndroid Build Coastguard Worker
5116*da0073e9SAndroid Build Coastguard Worker        # striding strings
5117*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[0]', random_string(6))
5118*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[1]', random_string(7))
5119*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[2]', random_string(8))
5120*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[2]', random_string(9))
5121*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[-1]', random_string(10))
5122*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[0:2]', random_string(11))
5123*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[1:-1]', random_string(12))
5124*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[-3:-1]', random_string(13))
5125*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[1:]', random_string(15))
5126*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[:1]', random_string(16))
5127*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[:]', random_string(17))
5128*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[::]', random_string(0))
5129*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[1000::]', random_string(0))
5130*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[:1000:]', random_string(0))
5131*da0073e9SAndroid Build Coastguard Worker
5132*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[::-1]', random_string(7))
5133*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[:3:-1]', random_string(7))
5134*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[3::-1]', random_string(7))
5135*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[1000::-1]', random_string(7))
5136*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[3:0:-1]', random_string(7))
5137*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[3:-1000:-1]', random_string(7))
5138*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[0:0:-1]', random_string(7))
5139*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[0:-1000:-1]', random_string(7))
5140*da0073e9SAndroid Build Coastguard Worker
5141*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[::-1]', random_string(0))
5142*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[::-1]', random_string(7))
5143*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[::-2]', random_string(7))
5144*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[::2]', random_string(7))
5145*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[::42]', random_string(7))
5146*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[::-42]', random_string(7))
5147*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[::42]', random_string(0))
5148*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[::-42]', random_string(0))
5149*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[::9223372036854775807]', random_string(42))
5150*da0073e9SAndroid Build Coastguard Worker        check_indexing_str('[::-9223372036854775807]', random_string(42))
5151*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "out of bounds"):
5152*da0073e9SAndroid Build Coastguard Worker            check_indexing_str('[::-9223372036854775808]', random_string(42))
5153*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "should have non-zero step"):
5154*da0073e9SAndroid Build Coastguard Worker            check_indexing_str('[::0]', random_string(42))
5155*da0073e9SAndroid Build Coastguard Worker
5156*da0073e9SAndroid Build Coastguard Worker    def test_module_copy_with_attributes(self):
5157*da0073e9SAndroid Build Coastguard Worker        class Vocabulary(torch.jit.ScriptModule):
5158*da0073e9SAndroid Build Coastguard Worker            def __init__(self, vocab_list):
5159*da0073e9SAndroid Build Coastguard Worker                super().__init__()
5160*da0073e9SAndroid Build Coastguard Worker                self._vocab = torch.jit.Attribute(vocab_list, List[str])
5161*da0073e9SAndroid Build Coastguard Worker                self.some_idx = torch.jit.Attribute(2, int)
5162*da0073e9SAndroid Build Coastguard Worker                self.idx = torch.jit.Attribute(
5163*da0073e9SAndroid Build Coastguard Worker                    {word: i for i, word in enumerate(vocab_list)}, Dict[str, int]
5164*da0073e9SAndroid Build Coastguard Worker                )
5165*da0073e9SAndroid Build Coastguard Worker
5166*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
5167*da0073e9SAndroid Build Coastguard Worker            def lookup_indices_1d(self, values):
5168*da0073e9SAndroid Build Coastguard Worker                # type: (List[str]) -> List[int]
5169*da0073e9SAndroid Build Coastguard Worker                result = torch.jit.annotate(List[int], [])
5170*da0073e9SAndroid Build Coastguard Worker                # Direct list iteration not supported
5171*da0073e9SAndroid Build Coastguard Worker                for i in range(len(values)):
5172*da0073e9SAndroid Build Coastguard Worker                    value = values[i]
5173*da0073e9SAndroid Build Coastguard Worker                    result.append(self.idx.get(value, self.some_idx))
5174*da0073e9SAndroid Build Coastguard Worker                return result
5175*da0073e9SAndroid Build Coastguard Worker
5176*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
5177*da0073e9SAndroid Build Coastguard Worker            def forward(self, values):
5178*da0073e9SAndroid Build Coastguard Worker                # type: (List[List[str]]) -> List[List[int]]
5179*da0073e9SAndroid Build Coastguard Worker                result = torch.jit.annotate(List[List[int]], [])
5180*da0073e9SAndroid Build Coastguard Worker                # Direct list iteration not supported
5181*da0073e9SAndroid Build Coastguard Worker                for i in range(len(values)):
5182*da0073e9SAndroid Build Coastguard Worker                    result.append(self.lookup_indices_1d(values[i]))
5183*da0073e9SAndroid Build Coastguard Worker                return result
5184*da0073e9SAndroid Build Coastguard Worker
5185*da0073e9SAndroid Build Coastguard Worker        v = Vocabulary(list('uabcdefg'))
5186*da0073e9SAndroid Build Coastguard Worker        v.__copy__()
5187*da0073e9SAndroid Build Coastguard Worker
5188*da0073e9SAndroid Build Coastguard Worker    def test_tuple_to_opt_list(self):
5189*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
5190*da0073e9SAndroid Build Coastguard Worker        def foo(x):
5191*da0073e9SAndroid Build Coastguard Worker            # type: (Optional[List[int]]) -> int
5192*da0073e9SAndroid Build Coastguard Worker            return 1
5193*da0073e9SAndroid Build Coastguard Worker
5194*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
5195*da0073e9SAndroid Build Coastguard Worker        def tuple_call():
5196*da0073e9SAndroid Build Coastguard Worker            return foo((1, 2))
5197*da0073e9SAndroid Build Coastguard Worker
5198*da0073e9SAndroid Build Coastguard Worker    def test_keyword(self):
5199*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
5200*da0073e9SAndroid Build Coastguard Worker        def func(x):
5201*da0073e9SAndroid Build Coastguard Worker            return torch.sum(x, dim=0)
5202*da0073e9SAndroid Build Coastguard Worker
5203*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(10, dtype=torch.float, requires_grad=True)
5204*da0073e9SAndroid Build Coastguard Worker        y = func(x)
5205*da0073e9SAndroid Build Coastguard Worker        y2 = torch.sum(x, dim=0)
5206*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, y2)
5207*da0073e9SAndroid Build Coastguard Worker
5208*da0073e9SAndroid Build Coastguard Worker    def test_constant_pooling_none(self):
5209*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
5210*da0073e9SAndroid Build Coastguard Worker        def typed_nones(a=None, b=None, c=None):
5211*da0073e9SAndroid Build Coastguard Worker            # type: (Optional[int], Optional[bool], Optional[Tensor]) -> Tuple[Optional[int], Optional[bool], Optional[Tensor]]
5212*da0073e9SAndroid Build Coastguard Worker            return a, b, c
5213*da0073e9SAndroid Build Coastguard Worker
5214*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
5215*da0073e9SAndroid Build Coastguard Worker        def test(a):
5216*da0073e9SAndroid Build Coastguard Worker            # type: (bool) -> None
5217*da0073e9SAndroid Build Coastguard Worker            if a:
5218*da0073e9SAndroid Build Coastguard Worker                print(typed_nones())
5219*da0073e9SAndroid Build Coastguard Worker            else:
5220*da0073e9SAndroid Build Coastguard Worker                print(typed_nones())
5221*da0073e9SAndroid Build Coastguard Worker
5222*da0073e9SAndroid Build Coastguard Worker        graph_str = str(test.graph)
5223*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(graph_str.count("NoneType = prim::Constant") == 1)
5224*da0073e9SAndroid Build Coastguard Worker
5225*da0073e9SAndroid Build Coastguard Worker    def test_constant_pooling_same_identity(self):
5226*da0073e9SAndroid Build Coastguard Worker        def foo():
5227*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor([4])
5228*da0073e9SAndroid Build Coastguard Worker            b = (a,)
5229*da0073e9SAndroid Build Coastguard Worker            index = len(a) - 1
5230*da0073e9SAndroid Build Coastguard Worker            c = b[index]
5231*da0073e9SAndroid Build Coastguard Worker            d = b[index]
5232*da0073e9SAndroid Build Coastguard Worker            return c, d
5233*da0073e9SAndroid Build Coastguard Worker
5234*da0073e9SAndroid Build Coastguard Worker        foo_script = torch.jit.script(foo)
5235*da0073e9SAndroid Build Coastguard Worker        self.run_pass('constant_propagation', foo_script.graph)
5236*da0073e9SAndroid Build Coastguard Worker        self.run_pass('constant_pooling', foo_script.graph)
5237*da0073e9SAndroid Build Coastguard Worker        # even though the c & d escape scope, we are still able
5238*da0073e9SAndroid Build Coastguard Worker        # pool them into one constant because they are the same object
5239*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_count("prim::Constant", 1, exactly=True).run(foo_script.graph)
5240*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(), foo_script())
5241*da0073e9SAndroid Build Coastguard Worker
5242*da0073e9SAndroid Build Coastguard Worker    def test_constant_pooling_introduce_aliasing(self):
5243*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
5244*da0073e9SAndroid Build Coastguard Worker        def foo():
5245*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor(1)
5246*da0073e9SAndroid Build Coastguard Worker            b = torch.tensor(1)
5247*da0073e9SAndroid Build Coastguard Worker            return a, b
5248*da0073e9SAndroid Build Coastguard Worker
5249*da0073e9SAndroid Build Coastguard Worker        self.run_pass('constant_propagation', foo.graph)
5250*da0073e9SAndroid Build Coastguard Worker        self.run_pass('constant_pooling', foo.graph)
5251*da0073e9SAndroid Build Coastguard Worker        # dont pool constants bc it would introduce observable alias relationship changing
5252*da0073e9SAndroid Build Coastguard Worker        a, b = foo()
5253*da0073e9SAndroid Build Coastguard Worker        self.assertIsNot(a, b)
5254*da0073e9SAndroid Build Coastguard Worker
5255*da0073e9SAndroid Build Coastguard Worker    def test_literal(self):
5256*da0073e9SAndroid Build Coastguard Worker        def func1(a, b):
5257*da0073e9SAndroid Build Coastguard Worker            c = a, b
5258*da0073e9SAndroid Build Coastguard Worker            d, e = c
5259*da0073e9SAndroid Build Coastguard Worker            return d + e
5260*da0073e9SAndroid Build Coastguard Worker
5261*da0073e9SAndroid Build Coastguard Worker        def func2(a, b):
5262*da0073e9SAndroid Build Coastguard Worker            c = a, (a, b)
5263*da0073e9SAndroid Build Coastguard Worker            d, e = c
5264*da0073e9SAndroid Build Coastguard Worker            f, g = e
5265*da0073e9SAndroid Build Coastguard Worker            return d + f + g
5266*da0073e9SAndroid Build Coastguard Worker
5267*da0073e9SAndroid Build Coastguard Worker        def func3(a, b):
5268*da0073e9SAndroid Build Coastguard Worker            # type: (float, float) -> float
5269*da0073e9SAndroid Build Coastguard Worker            c = 0., (0., 0.)
5270*da0073e9SAndroid Build Coastguard Worker            x = True
5271*da0073e9SAndroid Build Coastguard Worker            while x:
5272*da0073e9SAndroid Build Coastguard Worker                x = False
5273*da0073e9SAndroid Build Coastguard Worker                c = a, (a, b)
5274*da0073e9SAndroid Build Coastguard Worker            d, e = c
5275*da0073e9SAndroid Build Coastguard Worker            f, g = e
5276*da0073e9SAndroid Build Coastguard Worker            return d + f + g
5277*da0073e9SAndroid Build Coastguard Worker
5278*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(1, requires_grad=True)
5279*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(1, requires_grad=True)
5280*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func1, (a, b), optimize=True)
5281*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func2, (a, b), optimize=True)
5282*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func3, (a.item(), b.item()), optimize=True)
5283*da0073e9SAndroid Build Coastguard Worker
5284*da0073e9SAndroid Build Coastguard Worker    def test_expand(self):
5285*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
5286*da0073e9SAndroid Build Coastguard Worker        def func(x, y):
5287*da0073e9SAndroid Build Coastguard Worker            return x + y
5288*da0073e9SAndroid Build Coastguard Worker
5289*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(2, 3, dtype=torch.float, requires_grad=True)
5290*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(3, dtype=torch.float, requires_grad=True)
5291*da0073e9SAndroid Build Coastguard Worker        out = func(x, y)
5292*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(func(x, y), x + y)
5293*da0073e9SAndroid Build Coastguard Worker
5294*da0073e9SAndroid Build Coastguard Worker        grad = torch.randn(2, 3, dtype=torch.float)
5295*da0073e9SAndroid Build Coastguard Worker        out.backward(grad)
5296*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, grad)
5297*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad, grad.sum(dim=0))
5298*da0073e9SAndroid Build Coastguard Worker
5299*da0073e9SAndroid Build Coastguard Worker    def test_sum(self):
5300*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
5301*da0073e9SAndroid Build Coastguard Worker        def func(x):
5302*da0073e9SAndroid Build Coastguard Worker            return x.sum(dim=[4])
5303*da0073e9SAndroid Build Coastguard Worker
5304*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
5305*da0073e9SAndroid Build Coastguard Worker        def func2(x):
5306*da0073e9SAndroid Build Coastguard Worker            return x.sum(dim=4)
5307*da0073e9SAndroid Build Coastguard Worker
5308*da0073e9SAndroid Build Coastguard Worker        # test that shape analysis is written correctly for sum with OptionalIntArrayRef[1] dim argument
5309*da0073e9SAndroid Build Coastguard Worker        self.run_pass('constant_propagation', func.graph)
5310*da0073e9SAndroid Build Coastguard Worker        self.run_pass('constant_propagation', func2.graph)
5311*da0073e9SAndroid Build Coastguard Worker        g = _propagate_shapes(func.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
5312*da0073e9SAndroid Build Coastguard Worker        g2 = _propagate_shapes(func2.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
5313*da0073e9SAndroid Build Coastguard Worker
5314*da0073e9SAndroid Build Coastguard Worker    def test_cat(self):
5315*da0073e9SAndroid Build Coastguard Worker        with enable_profiling_mode_for_profiling_tests():
5316*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
5317*da0073e9SAndroid Build Coastguard Worker            def func(x):
5318*da0073e9SAndroid Build Coastguard Worker                return torch.cat((x, x), dim=0)
5319*da0073e9SAndroid Build Coastguard Worker
5320*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(10, dtype=torch.float, requires_grad=True)
5321*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(func(x, profile_and_replay=True), torch.cat((x, x), dim=0))
5322*da0073e9SAndroid Build Coastguard Worker
5323*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
5324*da0073e9SAndroid Build Coastguard Worker            def func2(x, y):
5325*da0073e9SAndroid Build Coastguard Worker                return torch.cat((x, x), y)
5326*da0073e9SAndroid Build Coastguard Worker
5327*da0073e9SAndroid Build Coastguard Worker            with disable_autodiff_subgraph_inlining():
5328*da0073e9SAndroid Build Coastguard Worker                for sizes in ((2, 2), (0, 2)):
5329*da0073e9SAndroid Build Coastguard Worker                    x = torch.rand(sizes).requires_grad_()
5330*da0073e9SAndroid Build Coastguard Worker                    y = torch.tensor(1)
5331*da0073e9SAndroid Build Coastguard Worker
5332*da0073e9SAndroid Build Coastguard Worker                    output = func2(x, y, profile_and_replay=True)
5333*da0073e9SAndroid Build Coastguard Worker                    output_ref = torch.cat((x, x), y)
5334*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(output, output_ref)
5335*da0073e9SAndroid Build Coastguard Worker
5336*da0073e9SAndroid Build Coastguard Worker                    if GRAPH_EXECUTOR != ProfilingMode.SIMPLE:
5337*da0073e9SAndroid Build Coastguard Worker                        self.assertAutodiffNode(func2.graph_for(x, y), True, ['aten::cat'], [])
5338*da0073e9SAndroid Build Coastguard Worker
5339*da0073e9SAndroid Build Coastguard Worker                        grad = torch.autograd.grad(output.sum(), x)
5340*da0073e9SAndroid Build Coastguard Worker                        grad_ref = torch.autograd.grad(output_ref.sum(), x)
5341*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(grad, grad_ref)
5342*da0073e9SAndroid Build Coastguard Worker
5343*da0073e9SAndroid Build Coastguard Worker    def test_cat_lifts(self):
5344*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
5345*da0073e9SAndroid Build Coastguard Worker        def foo(x):
5346*da0073e9SAndroid Build Coastguard Worker            return torch.cat([x, x], dim=1)
5347*da0073e9SAndroid Build Coastguard Worker
5348*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
5349*da0073e9SAndroid Build Coastguard Worker        def foo2(x):
5350*da0073e9SAndroid Build Coastguard Worker            return torch.cat([], dim=1)
5351*da0073e9SAndroid Build Coastguard Worker
5352*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
5353*da0073e9SAndroid Build Coastguard Worker        def foo3(x):
5354*da0073e9SAndroid Build Coastguard Worker            return torch.cat([x], dim=1)
5355*da0073e9SAndroid Build Coastguard Worker
5356*da0073e9SAndroid Build Coastguard Worker        for g in [foo.graph, foo2.graph, foo3.graph]:
5357*da0073e9SAndroid Build Coastguard Worker            FileCheck().check("int =").check("ListConstruct").check("aten::cat").run(str(g))
5358*da0073e9SAndroid Build Coastguard Worker
5359*da0073e9SAndroid Build Coastguard Worker    def test_stack(self):
5360*da0073e9SAndroid Build Coastguard Worker        with enable_profiling_mode_for_profiling_tests():
5361*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
5362*da0073e9SAndroid Build Coastguard Worker            def func(x):
5363*da0073e9SAndroid Build Coastguard Worker                return torch.stack((x, x), dim=1)
5364*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(10, 10)
5365*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(func(x, profile_and_replay=True), torch.stack((x, x), dim=1))
5366*da0073e9SAndroid Build Coastguard Worker
5367*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
5368*da0073e9SAndroid Build Coastguard Worker            def func2(x, y):
5369*da0073e9SAndroid Build Coastguard Worker                return torch.stack((x, y), dim=0)
5370*da0073e9SAndroid Build Coastguard Worker
5371*da0073e9SAndroid Build Coastguard Worker            with disable_autodiff_subgraph_inlining():
5372*da0073e9SAndroid Build Coastguard Worker                x = torch.randn([2, 2]).requires_grad_()
5373*da0073e9SAndroid Build Coastguard Worker                y = torch.randn([2, 2]).requires_grad_()
5374*da0073e9SAndroid Build Coastguard Worker
5375*da0073e9SAndroid Build Coastguard Worker                output = func2(x, y, profile_and_replay=True)
5376*da0073e9SAndroid Build Coastguard Worker                output_ref = torch.stack((x, y), 0)
5377*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(output, output_ref)
5378*da0073e9SAndroid Build Coastguard Worker                if GRAPH_EXECUTOR != ProfilingMode.SIMPLE:
5379*da0073e9SAndroid Build Coastguard Worker                    self.assertAutodiffNode(func2.graph_for(x, y), True, ['aten::stack'], [])
5380*da0073e9SAndroid Build Coastguard Worker
5381*da0073e9SAndroid Build Coastguard Worker                    grads = torch.autograd.grad(output.sum(), (x, y))
5382*da0073e9SAndroid Build Coastguard Worker                    grads_ref = torch.autograd.grad(output_ref.sum(), (x, y))
5383*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(grads, grads_ref)
5384*da0073e9SAndroid Build Coastguard Worker
5385*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY,
5386*da0073e9SAndroid Build Coastguard Worker                     "Profiling executor will be using different heuristics for constructing differentiable graphs")
5387*da0073e9SAndroid Build Coastguard Worker    def test_unbind(self):
5388*da0073e9SAndroid Build Coastguard Worker        with enable_profiling_mode_for_profiling_tests():
5389*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
5390*da0073e9SAndroid Build Coastguard Worker            def func(x, y):
5391*da0073e9SAndroid Build Coastguard Worker                # type: (Tensor, int) -> List[Tensor]
5392*da0073e9SAndroid Build Coastguard Worker                return torch.unbind(x, y)
5393*da0073e9SAndroid Build Coastguard Worker
5394*da0073e9SAndroid Build Coastguard Worker            with disable_autodiff_subgraph_inlining():
5395*da0073e9SAndroid Build Coastguard Worker                x = torch.rand([2, 2]).requires_grad_()
5396*da0073e9SAndroid Build Coastguard Worker                y = 0
5397*da0073e9SAndroid Build Coastguard Worker                outputs = func(x, y, profile_and_replay=True)
5398*da0073e9SAndroid Build Coastguard Worker                outputs_ref = torch.unbind(x, dim=y)
5399*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(outputs, outputs_ref)
5400*da0073e9SAndroid Build Coastguard Worker                self.assertAutodiffNode(func.graph_for(x, y), True, [], [])
5401*da0073e9SAndroid Build Coastguard Worker
5402*da0073e9SAndroid Build Coastguard Worker                grad = torch.autograd.grad(_sum_of_list(outputs), x)
5403*da0073e9SAndroid Build Coastguard Worker                grad_ref = torch.autograd.grad(_sum_of_list(outputs_ref), x)
5404*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(grad, grad_ref)
5405*da0073e9SAndroid Build Coastguard Worker
5406*da0073e9SAndroid Build Coastguard Worker
5407*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.PROFILING,
5408*da0073e9SAndroid Build Coastguard Worker                     "Profiling executor fails to recognize that tensors in a list require gradients")
5409*da0073e9SAndroid Build Coastguard Worker    def test_meshgrid(self):
5410*da0073e9SAndroid Build Coastguard Worker        with enable_profiling_mode_for_profiling_tests():
5411*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
5412*da0073e9SAndroid Build Coastguard Worker            def func(a):
5413*da0073e9SAndroid Build Coastguard Worker                # type: (List[Tensor]) -> List[Tensor]
5414*da0073e9SAndroid Build Coastguard Worker                return torch.meshgrid(a)
5415*da0073e9SAndroid Build Coastguard Worker            with disable_autodiff_subgraph_inlining():
5416*da0073e9SAndroid Build Coastguard Worker                a = torch.tensor([1.0, 2, 3]).requires_grad_()
5417*da0073e9SAndroid Build Coastguard Worker                b = torch.tensor([1.0, 2, 3, 4]).requires_grad_()
5418*da0073e9SAndroid Build Coastguard Worker                inputs = [a, b]
5419*da0073e9SAndroid Build Coastguard Worker
5420*da0073e9SAndroid Build Coastguard Worker                outputs_ref = torch.meshgrid(inputs)
5421*da0073e9SAndroid Build Coastguard Worker                outputs = func(inputs, profile_and_replay=True)
5422*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(outputs, outputs_ref)
5423*da0073e9SAndroid Build Coastguard Worker
5424*da0073e9SAndroid Build Coastguard Worker                if GRAPH_EXECUTOR != ProfilingMode.SIMPLE:
5425*da0073e9SAndroid Build Coastguard Worker                    self.assertAutodiffNode(func.graph_for(inputs), True, [], [])
5426*da0073e9SAndroid Build Coastguard Worker
5427*da0073e9SAndroid Build Coastguard Worker                    grads = torch.autograd.grad(_sum_of_list(outputs), inputs)
5428*da0073e9SAndroid Build Coastguard Worker                    grads_ref = torch.autograd.grad(_sum_of_list(outputs_ref), inputs)
5429*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(grads, grads_ref)
5430*da0073e9SAndroid Build Coastguard Worker
5431*da0073e9SAndroid Build Coastguard Worker    def test_tensor_len(self):
5432*da0073e9SAndroid Build Coastguard Worker        def func(x):
5433*da0073e9SAndroid Build Coastguard Worker            return len(x)
5434*da0073e9SAndroid Build Coastguard Worker
5435*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, [torch.ones(4, 5, 6)])
5436*da0073e9SAndroid Build Coastguard Worker
5437*da0073e9SAndroid Build Coastguard Worker    def test_func_call(self):
5438*da0073e9SAndroid Build Coastguard Worker        def add(a, b):
5439*da0073e9SAndroid Build Coastguard Worker            return a + b
5440*da0073e9SAndroid Build Coastguard Worker
5441*da0073e9SAndroid Build Coastguard Worker        def mul(a, x):
5442*da0073e9SAndroid Build Coastguard Worker            return a * x
5443*da0073e9SAndroid Build Coastguard Worker
5444*da0073e9SAndroid Build Coastguard Worker        def func(alpha, beta, x, y):
5445*da0073e9SAndroid Build Coastguard Worker            return add(mul(alpha, x), mul(beta, y))
5446*da0073e9SAndroid Build Coastguard Worker
5447*da0073e9SAndroid Build Coastguard Worker        alpha = torch.rand(1, dtype=torch.float, requires_grad=True)
5448*da0073e9SAndroid Build Coastguard Worker        beta = torch.rand(1, dtype=torch.float, requires_grad=True)
5449*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, dtype=torch.float, requires_grad=True)
5450*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(3, dtype=torch.float, requires_grad=True)
5451*da0073e9SAndroid Build Coastguard Worker
5452*da0073e9SAndroid Build Coastguard Worker        # NOTE: cannot optimize yet because broadcasts are not inserted before the fuser runs
5453*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, [alpha, beta, x, y], optimize=False)
5454*da0073e9SAndroid Build Coastguard Worker
5455*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("bailouts are being deprecated")
5456*da0073e9SAndroid Build Coastguard Worker    def test_profiling_graph_executor(self):
5457*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
5458*da0073e9SAndroid Build Coastguard Worker        def def_in_one_branch(x, z):
5459*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor, bool) -> float
5460*da0073e9SAndroid Build Coastguard Worker            y = x
5461*da0073e9SAndroid Build Coastguard Worker            if z is False:
5462*da0073e9SAndroid Build Coastguard Worker                y = x + 1
5463*da0073e9SAndroid Build Coastguard Worker
5464*da0073e9SAndroid Build Coastguard Worker            return y.sum()
5465*da0073e9SAndroid Build Coastguard Worker
5466*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(2, 3)
5467*da0073e9SAndroid Build Coastguard Worker
5468*da0073e9SAndroid Build Coastguard Worker        with enable_profiling_mode_for_profiling_tests():
5469*da0073e9SAndroid Build Coastguard Worker            # check prim::profile are inserted
5470*da0073e9SAndroid Build Coastguard Worker            profiled_graph_str = str(def_in_one_branch.graph_for(a, True))
5471*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_count("prim::profile", 4).run(profiled_graph_str)
5472*da0073e9SAndroid Build Coastguard Worker            # this call is optimized for
5473*da0073e9SAndroid Build Coastguard Worker            # the given shape of (2, 3)
5474*da0073e9SAndroid Build Coastguard Worker            def_in_one_branch(a, False)
5475*da0073e9SAndroid Build Coastguard Worker            # change shape to (3)
5476*da0073e9SAndroid Build Coastguard Worker            # so we go down a bailout path
5477*da0073e9SAndroid Build Coastguard Worker            a = torch.ones(3)
5478*da0073e9SAndroid Build Coastguard Worker            # check prim::BailOuts are inserted
5479*da0073e9SAndroid Build Coastguard Worker            bailout_graph_str = str(def_in_one_branch.graph_for(a, True))
5480*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_count("prim::BailOut", 3).run(bailout_graph_str)
5481*da0073e9SAndroid Build Coastguard Worker            # this triggers all 3 bailouts
5482*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(def_in_one_branch(a, False), 6.0)
5483*da0073e9SAndroid Build Coastguard Worker            # this triggers 2 bailouts
5484*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(def_in_one_branch(a, True), 3.0)
5485*da0073e9SAndroid Build Coastguard Worker
5486*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("bailouts are being deprecated")
5487*da0073e9SAndroid Build Coastguard Worker    def test_maxpool_guard_elimination(self):
5488*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
5489*da0073e9SAndroid Build Coastguard Worker        def my_maxpool(x):
5490*da0073e9SAndroid Build Coastguard Worker            return F.max_pool1d(x, kernel_size=[1]) + torch.ones([32, 32, 32])
5491*da0073e9SAndroid Build Coastguard Worker
5492*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(32, 32, 32)
5493*da0073e9SAndroid Build Coastguard Worker
5494*da0073e9SAndroid Build Coastguard Worker        with enable_profiling_mode_for_profiling_tests():
5495*da0073e9SAndroid Build Coastguard Worker            my_maxpool(a)
5496*da0073e9SAndroid Build Coastguard Worker            bailout_graph_str = str(my_maxpool.graph_for(a))
5497*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_count("prim::BailOut", 1).run(bailout_graph_str)
5498*da0073e9SAndroid Build Coastguard Worker
5499*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("bailouts are being deprecated")
5500*da0073e9SAndroid Build Coastguard Worker    def test_slice_guard_elimination(self):
5501*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
5502*da0073e9SAndroid Build Coastguard Worker        def my_slice(x):
5503*da0073e9SAndroid Build Coastguard Worker            return x[0:16:2] + x[0:16:2]
5504*da0073e9SAndroid Build Coastguard Worker
5505*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(32, 4)
5506*da0073e9SAndroid Build Coastguard Worker
5507*da0073e9SAndroid Build Coastguard Worker        with enable_profiling_mode_for_profiling_tests():
5508*da0073e9SAndroid Build Coastguard Worker            my_slice(a)
5509*da0073e9SAndroid Build Coastguard Worker            bailout_graph_str = str(my_slice.graph_for(a))
5510*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_count("prim::BailOut", 1).run(bailout_graph_str)
5511*da0073e9SAndroid Build Coastguard Worker
5512*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("bailouts are being deprecated")
5513*da0073e9SAndroid Build Coastguard Worker    def test_unsqueeze_guard_elimination(self):
5514*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
5515*da0073e9SAndroid Build Coastguard Worker        def my_unsqueeze(x):
5516*da0073e9SAndroid Build Coastguard Worker            return torch.unsqueeze(x, 0) + torch.unsqueeze(x, 0)
5517*da0073e9SAndroid Build Coastguard Worker
5518*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(32, 4)
5519*da0073e9SAndroid Build Coastguard Worker
5520*da0073e9SAndroid Build Coastguard Worker        with enable_profiling_mode_for_profiling_tests():
5521*da0073e9SAndroid Build Coastguard Worker            my_unsqueeze(a)
5522*da0073e9SAndroid Build Coastguard Worker            bailout_graph_str = str(my_unsqueeze.graph_for(a))
5523*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_count("prim::BailOut", 2).run(bailout_graph_str)
5524*da0073e9SAndroid Build Coastguard Worker
5525*da0073e9SAndroid Build Coastguard Worker    def test_resize_input_ops(self):
5526*da0073e9SAndroid Build Coastguard Worker        # resize_ and resize_as resize the input tensor. because our shape analysis
5527*da0073e9SAndroid Build Coastguard Worker        # is flow invariant, we set any Tensor that can alias a resized Tensor
5528*da0073e9SAndroid Build Coastguard Worker        # to the base Tensor Type, without size information.
5529*da0073e9SAndroid Build Coastguard Worker
5530*da0073e9SAndroid Build Coastguard Worker        # testing that value which is an input of a graph gets handled
5531*da0073e9SAndroid Build Coastguard Worker        def out_op_graph_input():
5532*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
5533*da0073e9SAndroid Build Coastguard Worker            def test(x, y, z):
5534*da0073e9SAndroid Build Coastguard Worker                torch.mul(x, y, out=z)
5535*da0073e9SAndroid Build Coastguard Worker                return z
5536*da0073e9SAndroid Build Coastguard Worker
5537*da0073e9SAndroid Build Coastguard Worker            graph = _propagate_shapes(test.graph,
5538*da0073e9SAndroid Build Coastguard Worker                                      (torch.zeros(2, 1), torch.zeros(1, 2), torch.zeros(1, 1, 1)), False)
5539*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(next(graph.outputs()).type() == TensorType.get())
5540*da0073e9SAndroid Build Coastguard Worker        out_op_graph_input()
5541*da0073e9SAndroid Build Coastguard Worker
5542*da0073e9SAndroid Build Coastguard Worker        def test_resize():
5543*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
5544*da0073e9SAndroid Build Coastguard Worker            def test(x):
5545*da0073e9SAndroid Build Coastguard Worker                after_resize_alias = torch.zeros([2])
5546*da0073e9SAndroid Build Coastguard Worker                for _i in range(5):
5547*da0073e9SAndroid Build Coastguard Worker                    b = x + 1
5548*da0073e9SAndroid Build Coastguard Worker                    f = [1]
5549*da0073e9SAndroid Build Coastguard Worker                    before_resize_alias = b.sub_(1)
5550*da0073e9SAndroid Build Coastguard Worker                    # for i in range(10):
5551*da0073e9SAndroid Build Coastguard Worker                    f.append(1)
5552*da0073e9SAndroid Build Coastguard Worker                    b.resize_(f)
5553*da0073e9SAndroid Build Coastguard Worker                    after_resize_alias = b.add_(1)
5554*da0073e9SAndroid Build Coastguard Worker                return after_resize_alias
5555*da0073e9SAndroid Build Coastguard Worker
5556*da0073e9SAndroid Build Coastguard Worker            self.run_pass('constant_propagation', test.graph)
5557*da0073e9SAndroid Build Coastguard Worker            g = _propagate_shapes(test.graph, (torch.zeros(1, 1),), False)
5558*da0073e9SAndroid Build Coastguard Worker            resize_node = g.findNode("aten::resize_")
5559*da0073e9SAndroid Build Coastguard Worker            # first input and output of b.resize_ is b
5560*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(next(resize_node.inputs()).type() == TensorType.get())
5561*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(next(resize_node.outputs()).type() == TensorType.get())
5562*da0073e9SAndroid Build Coastguard Worker
5563*da0073e9SAndroid Build Coastguard Worker            # correctly propagates to b alias set
5564*da0073e9SAndroid Build Coastguard Worker            before_resize = g.findNode("aten::sub_")
5565*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(next(before_resize.outputs()).type() == TensorType.get())
5566*da0073e9SAndroid Build Coastguard Worker
5567*da0073e9SAndroid Build Coastguard Worker            after_resize = g.findNode("aten::add_")
5568*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(next(after_resize.outputs()).type() == TensorType.get())
5569*da0073e9SAndroid Build Coastguard Worker
5570*da0073e9SAndroid Build Coastguard Worker        test_resize()
5571*da0073e9SAndroid Build Coastguard Worker
5572*da0073e9SAndroid Build Coastguard Worker        def test_resize_as():
5573*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
5574*da0073e9SAndroid Build Coastguard Worker            def test(x):
5575*da0073e9SAndroid Build Coastguard Worker                b = torch.zeros([2, 2])
5576*da0073e9SAndroid Build Coastguard Worker                b.resize_as_(x)
5577*da0073e9SAndroid Build Coastguard Worker                return b
5578*da0073e9SAndroid Build Coastguard Worker
5579*da0073e9SAndroid Build Coastguard Worker            g = test.graph
5580*da0073e9SAndroid Build Coastguard Worker            self.run_pass('constant_propagation', g)
5581*da0073e9SAndroid Build Coastguard Worker            g = _propagate_shapes(test.graph, (torch.zeros(1, 1),), False)
5582*da0073e9SAndroid Build Coastguard Worker
5583*da0073e9SAndroid Build Coastguard Worker            # x doesn't alias a resized op so it shouldn't be set to base Tensor type
5584*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(next(g.inputs()).type() != TensorType.get())
5585*da0073e9SAndroid Build Coastguard Worker            # return is resized
5586*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(next(g.outputs()).type() == TensorType.get())
5587*da0073e9SAndroid Build Coastguard Worker
5588*da0073e9SAndroid Build Coastguard Worker        test_resize_as()
5589*da0073e9SAndroid Build Coastguard Worker
5590*da0073e9SAndroid Build Coastguard Worker    def test_uninitialized(self):
5591*da0073e9SAndroid Build Coastguard Worker        graph_str = """graph():
5592*da0073e9SAndroid Build Coastguard Worker          %1 : int = prim::Uninitialized()
5593*da0073e9SAndroid Build Coastguard Worker          %2 : int = prim::Constant[value=1]()
5594*da0073e9SAndroid Build Coastguard Worker          %3 : int = aten::add(%1, %2)
5595*da0073e9SAndroid Build Coastguard Worker          return (%3)
5596*da0073e9SAndroid Build Coastguard Worker        """
5597*da0073e9SAndroid Build Coastguard Worker        g = parse_ir(graph_str)
5598*da0073e9SAndroid Build Coastguard Worker        m = self.createFunctionFromGraph(g)
5599*da0073e9SAndroid Build Coastguard Worker        self.getExportImportCopy(m)
5600*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "expected int"):
5601*da0073e9SAndroid Build Coastguard Worker            m()
5602*da0073e9SAndroid Build Coastguard Worker
5603*da0073e9SAndroid Build Coastguard Worker
5604*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't use requires_grad information")
5605*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.PROFILING, "Peeling is now disabled")
5606*da0073e9SAndroid Build Coastguard Worker    def test_requires_grad_loop(self):
5607*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
5608*da0073e9SAndroid Build Coastguard Worker        def test(x, y, z):
5609*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor, Tensor, int) -> Tensor
5610*da0073e9SAndroid Build Coastguard Worker            for _ in range(z):
5611*da0073e9SAndroid Build Coastguard Worker                x = y
5612*da0073e9SAndroid Build Coastguard Worker            return x
5613*da0073e9SAndroid Build Coastguard Worker
5614*da0073e9SAndroid Build Coastguard Worker        # x requires grad, y does not
5615*da0073e9SAndroid Build Coastguard Worker        # testing that requires grad analysis correctly exits, with its input
5616*da0073e9SAndroid Build Coastguard Worker        # to the loop (x) requiring grad and its output to the loop not requiring grad
5617*da0073e9SAndroid Build Coastguard Worker        # and the output of the node conservatively setting grad to true
5618*da0073e9SAndroid Build Coastguard Worker
5619*da0073e9SAndroid Build Coastguard Worker        inps = (torch.tensor(1.0, requires_grad=True), torch.tensor(1), 10)
5620*da0073e9SAndroid Build Coastguard Worker        test(*inps, profile_and_replay=True)
5621*da0073e9SAndroid Build Coastguard Worker
5622*da0073e9SAndroid Build Coastguard Worker        graph = test.graph_for(*inps)
5623*da0073e9SAndroid Build Coastguard Worker        loop = graph.findNode("prim::Loop")
5624*da0073e9SAndroid Build Coastguard Worker        loop_body = next(loop.blocks())
5625*da0073e9SAndroid Build Coastguard Worker        loop_inputs = list(loop_body.inputs())
5626*da0073e9SAndroid Build Coastguard Worker        loop_outputs = list(loop_body.outputs())
5627*da0073e9SAndroid Build Coastguard Worker
5628*da0073e9SAndroid Build Coastguard Worker        if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
5629*da0073e9SAndroid Build Coastguard Worker            # TODO: simplify this test as it's very sensitive
5630*da0073e9SAndroid Build Coastguard Worker            # the optimized graph will have 3 loops
5631*da0073e9SAndroid Build Coastguard Worker            # the original loop is peeled
5632*da0073e9SAndroid Build Coastguard Worker            # peeled loop also gets unrolled
5633*da0073e9SAndroid Build Coastguard Worker            index_of_x_in_peeled_unrolled_loop = -2
5634*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(loop_inputs[index_of_x_in_peeled_unrolled_loop].requires_grad())
5635*da0073e9SAndroid Build Coastguard Worker            bailouts_in_outer_block = graph.findAllNodes("prim::BailOut", False)
5636*da0073e9SAndroid Build Coastguard Worker            last_bailout_index_on_loops_output = -1
5637*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(bailouts_in_outer_block[last_bailout_index_on_loops_output].output().requires_grad())
5638*da0073e9SAndroid Build Coastguard Worker        else:
5639*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(loop_inputs[1].requires_grad())
5640*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(loop.output().requires_grad())
5641*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(loop_outputs[1].requires_grad())
5642*da0073e9SAndroid Build Coastguard Worker
5643*da0073e9SAndroid Build Coastguard Worker    def test_view_shape_prop(self):
5644*da0073e9SAndroid Build Coastguard Worker        cu = torch.jit.CompilationUnit('''
5645*da0073e9SAndroid Build Coastguard Worker        def test_view_shape_prop(a):
5646*da0073e9SAndroid Build Coastguard Worker            return a.view(size=[-1])
5647*da0073e9SAndroid Build Coastguard Worker        ''')
5648*da0073e9SAndroid Build Coastguard Worker        inputs = [torch.zeros(10, 10)]
5649*da0073e9SAndroid Build Coastguard Worker        outputs = torch.zeros(100)
5650*da0073e9SAndroid Build Coastguard Worker
5651*da0073e9SAndroid Build Coastguard Worker        real_outs = cu.test_view_shape_prop(*inputs)
5652*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(real_outs, outputs)
5653*da0073e9SAndroid Build Coastguard Worker
5654*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
5655*da0073e9SAndroid Build Coastguard Worker    def test_view_listconstruct_shape_prop(self):
5656*da0073e9SAndroid Build Coastguard Worker        def fn(x):
5657*da0073e9SAndroid Build Coastguard Worker            B = x.size(0)
5658*da0073e9SAndroid Build Coastguard Worker            C = x.size(1)
5659*da0073e9SAndroid Build Coastguard Worker            T = x.size(2)
5660*da0073e9SAndroid Build Coastguard Worker            return x.view(T, B, C)
5661*da0073e9SAndroid Build Coastguard Worker
5662*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 1, 5, requires_grad=True)
5663*da0073e9SAndroid Build Coastguard Worker        fn = torch.jit.script(fn)
5664*da0073e9SAndroid Build Coastguard Worker        graph = _propagate_shapes(fn.graph, (x,), False)
5665*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(next(graph.outputs()).type().scalarType() == 'Float')
5666*da0073e9SAndroid Build Coastguard Worker
5667*da0073e9SAndroid Build Coastguard Worker    def test_shape_prop_promotion(self):
5668*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
5669*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
5670*da0073e9SAndroid Build Coastguard Worker            return x + y
5671*da0073e9SAndroid Build Coastguard Worker
5672*da0073e9SAndroid Build Coastguard Worker        x, y = torch.rand(3, 4, dtype=torch.float), torch.rand(3, 4, dtype=torch.double)
5673*da0073e9SAndroid Build Coastguard Worker        graph = _propagate_shapes(fn.graph, (x, y), False)
5674*da0073e9SAndroid Build Coastguard Worker        FileCheck().check('Double(*, *, device=cpu) = aten::add').run(graph)
5675*da0073e9SAndroid Build Coastguard Worker
5676*da0073e9SAndroid Build Coastguard Worker    def test_shape_prop_promote_scalar_arg(self):
5677*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
5678*da0073e9SAndroid Build Coastguard Worker        def fn(x):
5679*da0073e9SAndroid Build Coastguard Worker            return math.pi + x
5680*da0073e9SAndroid Build Coastguard Worker
5681*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros(3, 4, dtype=torch.long)
5682*da0073e9SAndroid Build Coastguard Worker        graph = _propagate_shapes(fn.graph, (x,), False)
5683*da0073e9SAndroid Build Coastguard Worker        default = torch.get_default_dtype()
5684*da0073e9SAndroid Build Coastguard Worker        if default == torch.float:
5685*da0073e9SAndroid Build Coastguard Worker            FileCheck().check('Float(*, *, requires_grad=0, device=cpu) = aten::add').run(graph)
5686*da0073e9SAndroid Build Coastguard Worker        else:
5687*da0073e9SAndroid Build Coastguard Worker            FileCheck().check('Double(*, *, requires_grad=0, device=cpu) = aten::add').run(graph)
5688*da0073e9SAndroid Build Coastguard Worker
5689*da0073e9SAndroid Build Coastguard Worker    def test_integral_shape_inference(self):
5690*da0073e9SAndroid Build Coastguard Worker        cu = torch.jit.CompilationUnit('''
5691*da0073e9SAndroid Build Coastguard Worker        def test_integral_shape_inference(a):
5692*da0073e9SAndroid Build Coastguard Worker            return a * a
5693*da0073e9SAndroid Build Coastguard Worker        ''')
5694*da0073e9SAndroid Build Coastguard Worker        inputs = [torch.ones(10, 10, dtype=torch.long)]
5695*da0073e9SAndroid Build Coastguard Worker        outputs = torch.ones(10, 10, dtype=torch.long)
5696*da0073e9SAndroid Build Coastguard Worker
5697*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cu.test_integral_shape_inference(*inputs), outputs)
5698*da0073e9SAndroid Build Coastguard Worker
5699*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser')
5700*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle")
5701*da0073e9SAndroid Build Coastguard Worker    @enable_cpu_fuser
5702*da0073e9SAndroid Build Coastguard Worker    def test_batchnorm_fuser_cpu(self):
5703*da0073e9SAndroid Build Coastguard Worker        code = '''
5704*da0073e9SAndroid Build Coastguard Worker            graph(%3 : Tensor,
5705*da0073e9SAndroid Build Coastguard Worker                  %7 : Tensor,
5706*da0073e9SAndroid Build Coastguard Worker                  %12 : Float(*, *),
5707*da0073e9SAndroid Build Coastguard Worker                  %13 : Tensor,
5708*da0073e9SAndroid Build Coastguard Worker                  %25 : Tensor):
5709*da0073e9SAndroid Build Coastguard Worker                %23 : int = prim::Constant[value=1]()
5710*da0073e9SAndroid Build Coastguard Worker                %22 : float = prim::Constant[value=1e-05]()
5711*da0073e9SAndroid Build Coastguard Worker                %26 : Tensor = aten::sqrt(%25)
5712*da0073e9SAndroid Build Coastguard Worker                %24 : Tensor = aten::add(%26, %22, %23)
5713*da0073e9SAndroid Build Coastguard Worker                %20 : Tensor = aten::reciprocal(%24)
5714*da0073e9SAndroid Build Coastguard Worker                %norm_invstd : Tensor = aten::mul(%20, %23)
5715*da0073e9SAndroid Build Coastguard Worker                %15 : Tensor = aten::sub(%12, %13, %23)
5716*da0073e9SAndroid Build Coastguard Worker                %11 : Tensor = aten::mul(%15, %norm_invstd)
5717*da0073e9SAndroid Build Coastguard Worker                %8 : Tensor = aten::mul(%11, %7)
5718*da0073e9SAndroid Build Coastguard Worker                %5 : Tensor = aten::add(%8, %3, %23)
5719*da0073e9SAndroid Build Coastguard Worker                %1 : Float(*, *) = aten::relu(%5)
5720*da0073e9SAndroid Build Coastguard Worker                return (%1)
5721*da0073e9SAndroid Build Coastguard Worker        '''
5722*da0073e9SAndroid Build Coastguard Worker
5723*da0073e9SAndroid Build Coastguard Worker        graph = parse_ir(code)
5724*da0073e9SAndroid Build Coastguard Worker        inputs = 5 * [torch.rand(26, 2048, dtype=torch.float)]
5725*da0073e9SAndroid Build Coastguard Worker        code = torch._C._jit_fuser_get_fused_kernel_code(graph, inputs)
5726*da0073e9SAndroid Build Coastguard Worker        FileCheck().check('sqrtf').run(code)
5727*da0073e9SAndroid Build Coastguard Worker
5728*da0073e9SAndroid Build Coastguard Worker    @slowTest
5729*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser')
5730*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle")
5731*da0073e9SAndroid Build Coastguard Worker    @enable_cpu_fuser
5732*da0073e9SAndroid Build Coastguard Worker    def test_fuser_double_float_codegen(self):
5733*da0073e9SAndroid Build Coastguard Worker        fns = ['log', 'log10', 'log1p', 'log2', 'lgamma', 'exp', 'expm1', 'erf',
5734*da0073e9SAndroid Build Coastguard Worker               'erfc', 'cos', 'acos', 'cosh', 'sin', 'asin', 'sinh', 'tan',
5735*da0073e9SAndroid Build Coastguard Worker               'atan', 'tanh', 'sqrt', 'ceil', 'floor', 'round', 'trunc',
5736*da0073e9SAndroid Build Coastguard Worker               'frac']
5737*da0073e9SAndroid Build Coastguard Worker
5738*da0073e9SAndroid Build Coastguard Worker        def lookup_c_equivalent_fn(aten_fn):
5739*da0073e9SAndroid Build Coastguard Worker            return aten_fn
5740*da0073e9SAndroid Build Coastguard Worker
5741*da0073e9SAndroid Build Coastguard Worker        def test_dispatch(op, expects, dtype, binary=False):
5742*da0073e9SAndroid Build Coastguard Worker            if dtype == torch.double:
5743*da0073e9SAndroid Build Coastguard Worker                dtype_str = 'Double'
5744*da0073e9SAndroid Build Coastguard Worker            elif dtype == torch.float:
5745*da0073e9SAndroid Build Coastguard Worker                dtype_str = 'Float'
5746*da0073e9SAndroid Build Coastguard Worker            else:
5747*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError('Unknown dtype')
5748*da0073e9SAndroid Build Coastguard Worker
5749*da0073e9SAndroid Build Coastguard Worker            if binary:
5750*da0073e9SAndroid Build Coastguard Worker                code = f'''
5751*da0073e9SAndroid Build Coastguard Worker                    graph(%3 : Tensor, %4 : Tensor):
5752*da0073e9SAndroid Build Coastguard Worker                        %2 : {dtype_str}(*, *) = aten::{op}(%3, %4)
5753*da0073e9SAndroid Build Coastguard Worker                        %1 : {dtype_str}(*, *) = aten::relu(%2)
5754*da0073e9SAndroid Build Coastguard Worker                        return (%1)
5755*da0073e9SAndroid Build Coastguard Worker                '''
5756*da0073e9SAndroid Build Coastguard Worker            else:
5757*da0073e9SAndroid Build Coastguard Worker                code = f'''
5758*da0073e9SAndroid Build Coastguard Worker                    graph(%3 : Tensor):
5759*da0073e9SAndroid Build Coastguard Worker                        %2 : {dtype_str}(*, *) = aten::{op}(%3)
5760*da0073e9SAndroid Build Coastguard Worker                        %1 : {dtype_str}(*, *) = aten::relu(%2)
5761*da0073e9SAndroid Build Coastguard Worker                        return (%1)
5762*da0073e9SAndroid Build Coastguard Worker                '''
5763*da0073e9SAndroid Build Coastguard Worker
5764*da0073e9SAndroid Build Coastguard Worker            graph = parse_ir(code)
5765*da0073e9SAndroid Build Coastguard Worker            inputs = (2 if binary else 1) * [torch.rand(26, 2048, dtype=dtype)]
5766*da0073e9SAndroid Build Coastguard Worker            code = torch._C._jit_fuser_get_fused_kernel_code(graph, inputs)
5767*da0073e9SAndroid Build Coastguard Worker            FileCheck().check(expects).run(code)
5768*da0073e9SAndroid Build Coastguard Worker
5769*da0073e9SAndroid Build Coastguard Worker        for fn in fns:
5770*da0073e9SAndroid Build Coastguard Worker            test_dispatch(fn, lookup_c_equivalent_fn(fn) + '(', torch.double)
5771*da0073e9SAndroid Build Coastguard Worker            test_dispatch(fn, lookup_c_equivalent_fn(fn) + 'f(', torch.float)
5772*da0073e9SAndroid Build Coastguard Worker
5773*da0073e9SAndroid Build Coastguard Worker        # 'min', 'max' were previously tested but are now replaced with ternary expressions
5774*da0073e9SAndroid Build Coastguard Worker        # instead of fmin() and fmax()
5775*da0073e9SAndroid Build Coastguard Worker        binary_fns = ['pow']
5776*da0073e9SAndroid Build Coastguard Worker        for fn in binary_fns:
5777*da0073e9SAndroid Build Coastguard Worker            test_dispatch(fn, lookup_c_equivalent_fn(fn) + '(', torch.double, binary=True)
5778*da0073e9SAndroid Build Coastguard Worker            test_dispatch(fn, lookup_c_equivalent_fn(fn) + 'f(', torch.float, binary=True)
5779*da0073e9SAndroid Build Coastguard Worker
5780*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser')
5781*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle")
5782*da0073e9SAndroid Build Coastguard Worker    @enable_cpu_fuser
5783*da0073e9SAndroid Build Coastguard Worker    def test_fuser_double_literal_precision(self):
5784*da0073e9SAndroid Build Coastguard Worker        code = '''
5785*da0073e9SAndroid Build Coastguard Worker        graph(%2 : Float(*, *)):
5786*da0073e9SAndroid Build Coastguard Worker            %4 : int = prim::Constant[value=1]()
5787*da0073e9SAndroid Build Coastguard Worker            %3 : float = prim::Constant[value=1.282549830161864]()
5788*da0073e9SAndroid Build Coastguard Worker            %5 : Float(*, *) = aten::add(%2, %3, %4)
5789*da0073e9SAndroid Build Coastguard Worker            %1 : Float(*, *) = aten::relu(%5)
5790*da0073e9SAndroid Build Coastguard Worker            return (%1)
5791*da0073e9SAndroid Build Coastguard Worker        '''
5792*da0073e9SAndroid Build Coastguard Worker
5793*da0073e9SAndroid Build Coastguard Worker        graph = parse_ir(code)
5794*da0073e9SAndroid Build Coastguard Worker        code = torch._C._jit_fuser_get_fused_kernel_code(graph, [torch.rand(3, 4)])
5795*da0073e9SAndroid Build Coastguard Worker        FileCheck().check('1.282549830161864').run(code)
5796*da0073e9SAndroid Build Coastguard Worker
5797*da0073e9SAndroid Build Coastguard Worker    def test_fuser_multiple_blocks(self):
5798*da0073e9SAndroid Build Coastguard Worker        cu = torch.jit.CompilationUnit('''
5799*da0073e9SAndroid Build Coastguard Worker        def test_fuser_multiple_blocks(this, that, theother, meme):
5800*da0073e9SAndroid Build Coastguard Worker            i = 0
5801*da0073e9SAndroid Build Coastguard Worker            while i < 20:
5802*da0073e9SAndroid Build Coastguard Worker                this = torch.cat([this, meme], dim=0)
5803*da0073e9SAndroid Build Coastguard Worker                that = torch.cat([that, meme], dim=0)
5804*da0073e9SAndroid Build Coastguard Worker                theother = torch.cat([theother, meme], dim=0)
5805*da0073e9SAndroid Build Coastguard Worker                i = i + 1
5806*da0073e9SAndroid Build Coastguard Worker            return this, that, theother
5807*da0073e9SAndroid Build Coastguard Worker        ''')
5808*da0073e9SAndroid Build Coastguard Worker
5809*da0073e9SAndroid Build Coastguard Worker        inputs = [torch.ones(0, 10, 10)] * 3
5810*da0073e9SAndroid Build Coastguard Worker        inputs += [torch.ones(1, 10, 10)]
5811*da0073e9SAndroid Build Coastguard Worker        outputs = [torch.ones(20, 10, 10)] * 3
5812*da0073e9SAndroid Build Coastguard Worker
5813*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cu.test_fuser_multiple_blocks(*inputs), outputs)
5814*da0073e9SAndroid Build Coastguard Worker
5815*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("RuntimeError: VariableType::ID() not implemented")
5816*da0073e9SAndroid Build Coastguard Worker    def test_cast(self):
5817*da0073e9SAndroid Build Coastguard Worker        script = '''
5818*da0073e9SAndroid Build Coastguard Worker        def to_int(x):
5819*da0073e9SAndroid Build Coastguard Worker            return int(x)
5820*da0073e9SAndroid Build Coastguard Worker        '''
5821*da0073e9SAndroid Build Coastguard Worker        x = Variable(torch.FloatTensor([1.1, 2.3]), requires_grad=True)
5822*da0073e9SAndroid Build Coastguard Worker        out = Variable(torch.IntTensor([1, 2]), requires_grad=True)
5823*da0073e9SAndroid Build Coastguard Worker        self.checkScript(script, [x], optimize=True, outputs=[out], func='to_int')
5824*da0073e9SAndroid Build Coastguard Worker
5825*da0073e9SAndroid Build Coastguard Worker    def test_str_cast(self):
5826*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
5827*da0073e9SAndroid Build Coastguard Worker        def to_str(x):
5828*da0073e9SAndroid Build Coastguard Worker            # type: (int) -> str
5829*da0073e9SAndroid Build Coastguard Worker            return str((x, x))
5830*da0073e9SAndroid Build Coastguard Worker
5831*da0073e9SAndroid Build Coastguard Worker        self.assertEqual("(1, 1)", to_str(1))
5832*da0073e9SAndroid Build Coastguard Worker
5833*da0073e9SAndroid Build Coastguard Worker    def test_int_cast(self):
5834*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
5835*da0073e9SAndroid Build Coastguard Worker        def to_int(x):
5836*da0073e9SAndroid Build Coastguard Worker            # type: (str) -> int
5837*da0073e9SAndroid Build Coastguard Worker            return int(x)
5838*da0073e9SAndroid Build Coastguard Worker
5839*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(5, to_int('5'))
5840*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(-5, to_int('-5'))
5841*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(2147483647, to_int('2147483647'))
5842*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(-2147483648, to_int('-2147483648'))
5843*da0073e9SAndroid Build Coastguard Worker
5844*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "invalid literal for int()"):
5845*da0073e9SAndroid Build Coastguard Worker            to_int('0x20')
5846*da0073e9SAndroid Build Coastguard Worker
5847*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "invalid literal for int()"):
5848*da0073e9SAndroid Build Coastguard Worker            to_int('0b0001')
5849*da0073e9SAndroid Build Coastguard Worker
5850*da0073e9SAndroid Build Coastguard Worker    def test_python_frontend(self):
5851*da0073e9SAndroid Build Coastguard Worker        def fn(x, y, z):
5852*da0073e9SAndroid Build Coastguard Worker            q = None
5853*da0073e9SAndroid Build Coastguard Worker            q = x + y - z.sigmoid()
5854*da0073e9SAndroid Build Coastguard Worker            print(q)
5855*da0073e9SAndroid Build Coastguard Worker            w = -z
5856*da0073e9SAndroid Build Coastguard Worker            if not x and not y and z:
5857*da0073e9SAndroid Build Coastguard Worker                m = x if not z else y
5858*da0073e9SAndroid Build Coastguard Worker            while x < y > z:
5859*da0073e9SAndroid Build Coastguard Worker                q = x
5860*da0073e9SAndroid Build Coastguard Worker            assert 1 == 1, "hello"
5861*da0073e9SAndroid Build Coastguard Worker            return x
5862*da0073e9SAndroid Build Coastguard Worker
5863*da0073e9SAndroid Build Coastguard Worker        ast = torch.jit.frontend.get_jit_def(fn, fn.__name__)
5864*da0073e9SAndroid Build Coastguard Worker        self.assertExpected(str(ast))
5865*da0073e9SAndroid Build Coastguard Worker
5866*da0073e9SAndroid Build Coastguard Worker    def test_python_frontend_source_range(self):
5867*da0073e9SAndroid Build Coastguard Worker        def fn():
5868*da0073e9SAndroid Build Coastguard Worker            raise Exception("hello")  # noqa: TRY002
5869*da0073e9SAndroid Build Coastguard Worker        ast = torch.jit.frontend.get_jit_def(fn, fn.__name__)
5870*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("SourceRange at:") \
5871*da0073e9SAndroid Build Coastguard Worker                   .check("def fn():") \
5872*da0073e9SAndroid Build Coastguard Worker                   .check("~~~~~~~~~") \
5873*da0073e9SAndroid Build Coastguard Worker                   .check('raise Exception("hello")') \
5874*da0073e9SAndroid Build Coastguard Worker                   .check('~~~~~~~~~~~~~~~~~ <--- HERE') \
5875*da0073e9SAndroid Build Coastguard Worker                   .run(str(ast.range()))
5876*da0073e9SAndroid Build Coastguard Worker
5877*da0073e9SAndroid Build Coastguard Worker    def test_python_frontend_py3(self):
5878*da0073e9SAndroid Build Coastguard Worker        def fn():
5879*da0073e9SAndroid Build Coastguard Worker            raise Exception("hello")  # noqa: TRY002
5880*da0073e9SAndroid Build Coastguard Worker        ast = torch.jit.frontend.get_jit_def(fn, fn.__name__)
5881*da0073e9SAndroid Build Coastguard Worker        self.assertExpected(str(ast))
5882*da0073e9SAndroid Build Coastguard Worker
5883*da0073e9SAndroid Build Coastguard Worker    def _make_scalar_vars(self, arr, dtype):
5884*da0073e9SAndroid Build Coastguard Worker        return [torch.tensor(val, dtype=dtype) for val in arr]
5885*da0073e9SAndroid Build Coastguard Worker
5886*da0073e9SAndroid Build Coastguard Worker
5887*da0073e9SAndroid Build Coastguard Worker    def test_string_print(self):
5888*da0073e9SAndroid Build Coastguard Worker        def func(a):
5889*da0073e9SAndroid Build Coastguard Worker            print(a, "a" 'b' '''c''' """d""", 2, 1.5)
5890*da0073e9SAndroid Build Coastguard Worker            return a
5891*da0073e9SAndroid Build Coastguard Worker
5892*da0073e9SAndroid Build Coastguard Worker        inputs = self._make_scalar_vars([1], torch.int64)
5893*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, inputs, capture_output=True)
5894*da0073e9SAndroid Build Coastguard Worker
5895*da0073e9SAndroid Build Coastguard Worker    def test_while(self):
5896*da0073e9SAndroid Build Coastguard Worker        def func(a, b, max):
5897*da0073e9SAndroid Build Coastguard Worker            while bool(a < max):
5898*da0073e9SAndroid Build Coastguard Worker                a = a + 1
5899*da0073e9SAndroid Build Coastguard Worker                b = b + 1
5900*da0073e9SAndroid Build Coastguard Worker            c = a + b
5901*da0073e9SAndroid Build Coastguard Worker            return c
5902*da0073e9SAndroid Build Coastguard Worker
5903*da0073e9SAndroid Build Coastguard Worker        inputs = self._make_scalar_vars([1, 1, 10], torch.int64)
5904*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, inputs, optimize=True)
5905*da0073e9SAndroid Build Coastguard Worker
5906*da0073e9SAndroid Build Coastguard Worker    def test_fibb(self):
5907*da0073e9SAndroid Build Coastguard Worker        def func(lim):
5908*da0073e9SAndroid Build Coastguard Worker            first = 1
5909*da0073e9SAndroid Build Coastguard Worker            second = 1
5910*da0073e9SAndroid Build Coastguard Worker            i = 1
5911*da0073e9SAndroid Build Coastguard Worker            somenum = 5
5912*da0073e9SAndroid Build Coastguard Worker            dontmutateme = 3
5913*da0073e9SAndroid Build Coastguard Worker            third = 0
5914*da0073e9SAndroid Build Coastguard Worker            while bool(i < lim):
5915*da0073e9SAndroid Build Coastguard Worker                third = first + second
5916*da0073e9SAndroid Build Coastguard Worker                first = second
5917*da0073e9SAndroid Build Coastguard Worker                second = third
5918*da0073e9SAndroid Build Coastguard Worker                j = 0
5919*da0073e9SAndroid Build Coastguard Worker                while j < 10:
5920*da0073e9SAndroid Build Coastguard Worker                    somenum = somenum * 2
5921*da0073e9SAndroid Build Coastguard Worker                    j = j + 1
5922*da0073e9SAndroid Build Coastguard Worker                i = i + j
5923*da0073e9SAndroid Build Coastguard Worker                i = i + dontmutateme
5924*da0073e9SAndroid Build Coastguard Worker
5925*da0073e9SAndroid Build Coastguard Worker            st = second + third
5926*da0073e9SAndroid Build Coastguard Worker            fs = first + second
5927*da0073e9SAndroid Build Coastguard Worker            return third, st, fs
5928*da0073e9SAndroid Build Coastguard Worker
5929*da0073e9SAndroid Build Coastguard Worker        inputs = self._make_scalar_vars([10], torch.int64)
5930*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, inputs, optimize=True)
5931*da0073e9SAndroid Build Coastguard Worker
5932*da0073e9SAndroid Build Coastguard Worker    def test_fibb_totally_better(self):
5933*da0073e9SAndroid Build Coastguard Worker        def fib(x):
5934*da0073e9SAndroid Build Coastguard Worker            # type: (int) -> int
5935*da0073e9SAndroid Build Coastguard Worker            prev = 1
5936*da0073e9SAndroid Build Coastguard Worker            v = 1
5937*da0073e9SAndroid Build Coastguard Worker            for i in range(0, x):
5938*da0073e9SAndroid Build Coastguard Worker                save = v
5939*da0073e9SAndroid Build Coastguard Worker                v = v + prev
5940*da0073e9SAndroid Build Coastguard Worker                prev = save
5941*da0073e9SAndroid Build Coastguard Worker            return v
5942*da0073e9SAndroid Build Coastguard Worker
5943*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fib, (10,))
5944*da0073e9SAndroid Build Coastguard Worker
5945*da0073e9SAndroid Build Coastguard Worker    def test_if(self):
5946*da0073e9SAndroid Build Coastguard Worker        def func(a, b):
5947*da0073e9SAndroid Build Coastguard Worker            # type: (int, int) -> int
5948*da0073e9SAndroid Build Coastguard Worker            d = 3
5949*da0073e9SAndroid Build Coastguard Worker            if bool(a > 10):
5950*da0073e9SAndroid Build Coastguard Worker                a = 3 + d
5951*da0073e9SAndroid Build Coastguard Worker            else:
5952*da0073e9SAndroid Build Coastguard Worker                b = 3 + d
5953*da0073e9SAndroid Build Coastguard Worker                d = 4
5954*da0073e9SAndroid Build Coastguard Worker            c = a + b
5955*da0073e9SAndroid Build Coastguard Worker            return c
5956*da0073e9SAndroid Build Coastguard Worker
5957*da0073e9SAndroid Build Coastguard Worker        inputs = self._make_scalar_vars([1, -1], torch.int64)
5958*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, inputs, optimize=True)
5959*da0073e9SAndroid Build Coastguard Worker
5960*da0073e9SAndroid Build Coastguard Worker    def test_if_for_in_range(self):
5961*da0073e9SAndroid Build Coastguard Worker        def func(a, b):
5962*da0073e9SAndroid Build Coastguard Worker            # type: (int, int) -> int
5963*da0073e9SAndroid Build Coastguard Worker            d = 3
5964*da0073e9SAndroid Build Coastguard Worker            for _ in range(20):
5965*da0073e9SAndroid Build Coastguard Worker                if bool(a > 10):
5966*da0073e9SAndroid Build Coastguard Worker                    a = 3 + d
5967*da0073e9SAndroid Build Coastguard Worker                else:
5968*da0073e9SAndroid Build Coastguard Worker                    b = 3 + d
5969*da0073e9SAndroid Build Coastguard Worker                    d = 4
5970*da0073e9SAndroid Build Coastguard Worker                c = a + b
5971*da0073e9SAndroid Build Coastguard Worker            return d
5972*da0073e9SAndroid Build Coastguard Worker        inputs = self._make_scalar_vars([1, -1], torch.int64)
5973*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, inputs, optimize=True)
5974*da0073e9SAndroid Build Coastguard Worker
5975*da0073e9SAndroid Build Coastguard Worker    def test_if_noelse(self):
5976*da0073e9SAndroid Build Coastguard Worker        def func(a, b):
5977*da0073e9SAndroid Build Coastguard Worker            if bool(a > 10):
5978*da0073e9SAndroid Build Coastguard Worker                a = 3 + b
5979*da0073e9SAndroid Build Coastguard Worker            c = a + b
5980*da0073e9SAndroid Build Coastguard Worker            return c
5981*da0073e9SAndroid Build Coastguard Worker
5982*da0073e9SAndroid Build Coastguard Worker        inputs = self._make_scalar_vars([-1, 1], torch.int64)
5983*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, inputs, optimize=True)
5984*da0073e9SAndroid Build Coastguard Worker
5985*da0073e9SAndroid Build Coastguard Worker    def test_if_is_none_dispatch(self):
5986*da0073e9SAndroid Build Coastguard Worker
5987*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
5988*da0073e9SAndroid Build Coastguard Worker        def test_lhs_none_rhs_none():
5989*da0073e9SAndroid Build Coastguard Worker            # LHS, RHS both alwaysNone, dispatch always_none_branch
5990*da0073e9SAndroid Build Coastguard Worker            # only emit one prim::Constant
5991*da0073e9SAndroid Build Coastguard Worker            if None is None:
5992*da0073e9SAndroid Build Coastguard Worker                return 1
5993*da0073e9SAndroid Build Coastguard Worker            elif None is not None:
5994*da0073e9SAndroid Build Coastguard Worker                return 2
5995*da0073e9SAndroid Build Coastguard Worker            else:
5996*da0073e9SAndroid Build Coastguard Worker                return 3
5997*da0073e9SAndroid Build Coastguard Worker
5998*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(str(test_lhs_none_rhs_none.graph).count(': int = prim::Constant') == 1)
5999*da0073e9SAndroid Build Coastguard Worker
6000*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6001*da0073e9SAndroid Build Coastguard Worker        def test_lhs_opt_rhs_none(lhs=None):
6002*da0073e9SAndroid Build Coastguard Worker            # type: (Optional[Tensor]) -> int
6003*da0073e9SAndroid Build Coastguard Worker            # LHS maybeNone: emit normal if stmt that contains 3 constants
6004*da0073e9SAndroid Build Coastguard Worker            if lhs is not None:
6005*da0073e9SAndroid Build Coastguard Worker                return 2
6006*da0073e9SAndroid Build Coastguard Worker            elif lhs is None:
6007*da0073e9SAndroid Build Coastguard Worker                return 1
6008*da0073e9SAndroid Build Coastguard Worker            else:
6009*da0073e9SAndroid Build Coastguard Worker                return 3
6010*da0073e9SAndroid Build Coastguard Worker
6011*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(str(test_lhs_opt_rhs_none.graph).count(': int = prim::Constant') == 3)
6012*da0073e9SAndroid Build Coastguard Worker
6013*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6014*da0073e9SAndroid Build Coastguard Worker        def test_lhs_none_rhs_opt(rhs=None):
6015*da0073e9SAndroid Build Coastguard Worker            # type: (Optional[Tensor]) -> int
6016*da0073e9SAndroid Build Coastguard Worker            # RHS maybeNone, emit normal if stmt that contains 3 constants
6017*da0073e9SAndroid Build Coastguard Worker            if None is rhs:
6018*da0073e9SAndroid Build Coastguard Worker                return 1
6019*da0073e9SAndroid Build Coastguard Worker            elif None is not rhs:
6020*da0073e9SAndroid Build Coastguard Worker                return 2
6021*da0073e9SAndroid Build Coastguard Worker            else:
6022*da0073e9SAndroid Build Coastguard Worker                return 3
6023*da0073e9SAndroid Build Coastguard Worker
6024*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(str(test_lhs_opt_rhs_none.graph).count(': int = prim::Constant') == 3)
6025*da0073e9SAndroid Build Coastguard Worker
6026*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6027*da0073e9SAndroid Build Coastguard Worker        def test_lhs_never_rhs_none(lhs):
6028*da0073e9SAndroid Build Coastguard Worker            # LHS neverNone, RHS alwaysNone dispatch never_none_branch
6029*da0073e9SAndroid Build Coastguard Worker            # only emit one prim::Constant
6030*da0073e9SAndroid Build Coastguard Worker            if lhs is None:
6031*da0073e9SAndroid Build Coastguard Worker                return 1
6032*da0073e9SAndroid Build Coastguard Worker            elif lhs is not None:
6033*da0073e9SAndroid Build Coastguard Worker                return 2
6034*da0073e9SAndroid Build Coastguard Worker            else:
6035*da0073e9SAndroid Build Coastguard Worker                return 3
6036*da0073e9SAndroid Build Coastguard Worker
6037*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(str(test_lhs_never_rhs_none.graph).count(': int = prim::Constant') == 1)
6038*da0073e9SAndroid Build Coastguard Worker
6039*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6040*da0073e9SAndroid Build Coastguard Worker        def test_lhs_none_rhs_never(rhs):
6041*da0073e9SAndroid Build Coastguard Worker            # LHS alwaysNone, RHS neverNone dispatch never_none_branch
6042*da0073e9SAndroid Build Coastguard Worker            # only emit one prim::Constant
6043*da0073e9SAndroid Build Coastguard Worker            if None is rhs:
6044*da0073e9SAndroid Build Coastguard Worker                return 1
6045*da0073e9SAndroid Build Coastguard Worker            elif None is not rhs:
6046*da0073e9SAndroid Build Coastguard Worker                return 2
6047*da0073e9SAndroid Build Coastguard Worker            else:
6048*da0073e9SAndroid Build Coastguard Worker                return 3
6049*da0073e9SAndroid Build Coastguard Worker
6050*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(str(test_lhs_none_rhs_never.graph).count(': int = prim::Constant') == 1)
6051*da0073e9SAndroid Build Coastguard Worker
6052*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6053*da0073e9SAndroid Build Coastguard Worker        def test_bool_arith_and(lhs):
6054*da0073e9SAndroid Build Coastguard Worker            if lhs is None and lhs is not None:
6055*da0073e9SAndroid Build Coastguard Worker                return 1
6056*da0073e9SAndroid Build Coastguard Worker            else:
6057*da0073e9SAndroid Build Coastguard Worker                return 2
6058*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(test_bool_arith_and(torch.zeros(3)), 2)
6059*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(str(test_bool_arith_and.graph).count('if') == 0)
6060*da0073e9SAndroid Build Coastguard Worker
6061*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6062*da0073e9SAndroid Build Coastguard Worker        def test_bool_arith_or(lhs):
6063*da0073e9SAndroid Build Coastguard Worker            if lhs is None or lhs is not None:
6064*da0073e9SAndroid Build Coastguard Worker                return 1
6065*da0073e9SAndroid Build Coastguard Worker            else:
6066*da0073e9SAndroid Build Coastguard Worker                return 2
6067*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(test_bool_arith_or(torch.zeros(3)), 1)
6068*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(str(test_bool_arith_or.graph).count('if') == 0)
6069*da0073e9SAndroid Build Coastguard Worker
6070*da0073e9SAndroid Build Coastguard Worker
6071*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6072*da0073e9SAndroid Build Coastguard Worker        def test_bool_arith_not(lhs):
6073*da0073e9SAndroid Build Coastguard Worker            if lhs is not None:
6074*da0073e9SAndroid Build Coastguard Worker                return 1
6075*da0073e9SAndroid Build Coastguard Worker            else:
6076*da0073e9SAndroid Build Coastguard Worker                return 2
6077*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(test_bool_arith_not(torch.zeros(3)), 1)
6078*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(str(test_bool_arith_not.graph).count('if') == 0)
6079*da0073e9SAndroid Build Coastguard Worker
6080*da0073e9SAndroid Build Coastguard Worker    def test_conditional_casting(self):
6081*da0073e9SAndroid Build Coastguard Worker        def test_bool_cast_tensor(x):
6082*da0073e9SAndroid Build Coastguard Worker            if x:
6083*da0073e9SAndroid Build Coastguard Worker                return 1
6084*da0073e9SAndroid Build Coastguard Worker            else:
6085*da0073e9SAndroid Build Coastguard Worker                return 0
6086*da0073e9SAndroid Build Coastguard Worker
6087*da0073e9SAndroid Build Coastguard Worker        for make_one_dim in [True, False]:
6088*da0073e9SAndroid Build Coastguard Worker            for inp_val in [0.1, 0.0, -0.0, -0.1, -1, 0, 1]:
6089*da0073e9SAndroid Build Coastguard Worker                inp_val = [inp_val] if make_one_dim else inp_val
6090*da0073e9SAndroid Build Coastguard Worker                self.checkScript(test_bool_cast_tensor, (torch.tensor(inp_val),))
6091*da0073e9SAndroid Build Coastguard Worker
6092*da0073e9SAndroid Build Coastguard Worker        self.checkScriptRaisesRegex(test_bool_cast_tensor, (torch.tensor([1, 1]),), Exception,
6093*da0073e9SAndroid Build Coastguard Worker                                    "Boolean value of Tensor with more than one value")
6094*da0073e9SAndroid Build Coastguard Worker
6095*da0073e9SAndroid Build Coastguard Worker        def test_not_cast(x):
6096*da0073e9SAndroid Build Coastguard Worker            if not x:
6097*da0073e9SAndroid Build Coastguard Worker                return 1
6098*da0073e9SAndroid Build Coastguard Worker            else:
6099*da0073e9SAndroid Build Coastguard Worker                return 0
6100*da0073e9SAndroid Build Coastguard Worker
6101*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_not_cast, (torch.tensor(1),))
6102*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_not_cast, (torch.tensor(0),))
6103*da0073e9SAndroid Build Coastguard Worker
6104*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"Could not cast value of type Tuple\[Tensor, Tensor\]"):  # noqa: W605
6105*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
6106*da0073e9SAndroid Build Coastguard Worker            def test_mult(x, y):
6107*da0073e9SAndroid Build Coastguard Worker                return not (x, y)
6108*da0073e9SAndroid Build Coastguard Worker
6109*da0073e9SAndroid Build Coastguard Worker        def test_cast_int(x):
6110*da0073e9SAndroid Build Coastguard Worker            # type: (int) -> int
6111*da0073e9SAndroid Build Coastguard Worker            if x:
6112*da0073e9SAndroid Build Coastguard Worker                return 1
6113*da0073e9SAndroid Build Coastguard Worker            else:
6114*da0073e9SAndroid Build Coastguard Worker                return 0
6115*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_cast_int, (1,))
6116*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_cast_int, (0,))
6117*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_cast_int, (-1,))
6118*da0073e9SAndroid Build Coastguard Worker
6119*da0073e9SAndroid Build Coastguard Worker        def test_cast_float(x):
6120*da0073e9SAndroid Build Coastguard Worker            # type: (float) -> int
6121*da0073e9SAndroid Build Coastguard Worker            if x:
6122*da0073e9SAndroid Build Coastguard Worker                return 1
6123*da0073e9SAndroid Build Coastguard Worker            else:
6124*da0073e9SAndroid Build Coastguard Worker                return 0
6125*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_cast_float, (1.,))
6126*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_cast_float, (0.,))
6127*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_cast_float, (-1.,))
6128*da0073e9SAndroid Build Coastguard Worker
6129*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"Could not cast value of type Tuple\[int, int\] to bool"):  # noqa: W605
6130*da0073e9SAndroid Build Coastguard Worker
6131*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
6132*da0073e9SAndroid Build Coastguard Worker            def test_bad_conditional(x):
6133*da0073e9SAndroid Build Coastguard Worker                if (1, 2):  # noqa: F634
6134*da0073e9SAndroid Build Coastguard Worker                    return
6135*da0073e9SAndroid Build Coastguard Worker                else:
6136*da0073e9SAndroid Build Coastguard Worker                    return 0
6137*da0073e9SAndroid Build Coastguard Worker
6138*da0073e9SAndroid Build Coastguard Worker    def test_while_nonexistent_value(self):
6139*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "undefined value x"):
6140*da0073e9SAndroid Build Coastguard Worker            torch.jit.CompilationUnit('''
6141*da0073e9SAndroid Build Coastguard Worker            def test_while(a, b):
6142*da0073e9SAndroid Build Coastguard Worker                while bool(a < 10):
6143*da0073e9SAndroid Build Coastguard Worker                    a = a + x
6144*da0073e9SAndroid Build Coastguard Worker                    b = b + 1
6145*da0073e9SAndroid Build Coastguard Worker                return a + b
6146*da0073e9SAndroid Build Coastguard Worker            ''')
6147*da0073e9SAndroid Build Coastguard Worker
6148*da0073e9SAndroid Build Coastguard Worker    def test_while_nonexistent_cond_value(self):
6149*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "undefined value x"):
6150*da0073e9SAndroid Build Coastguard Worker            torch.jit.CompilationUnit('''
6151*da0073e9SAndroid Build Coastguard Worker            def test_while(a, b):
6152*da0073e9SAndroid Build Coastguard Worker                while a < x:
6153*da0073e9SAndroid Build Coastguard Worker                    a = a + 1
6154*da0073e9SAndroid Build Coastguard Worker                    b = b + 1
6155*da0073e9SAndroid Build Coastguard Worker                return a + b
6156*da0073e9SAndroid Build Coastguard Worker            ''')
6157*da0073e9SAndroid Build Coastguard Worker
6158*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6159*da0073e9SAndroid Build Coastguard Worker        def test_ternary(x):
6160*da0073e9SAndroid Build Coastguard Worker            # type: (Optional[int]) -> int
6161*da0073e9SAndroid Build Coastguard Worker            x = x if x is not None else 2
6162*da0073e9SAndroid Build Coastguard Worker            return x
6163*da0073e9SAndroid Build Coastguard Worker
6164*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6165*da0073e9SAndroid Build Coastguard Worker        def test_not_none(x):
6166*da0073e9SAndroid Build Coastguard Worker            # type: (Optional[int]) -> None
6167*da0073e9SAndroid Build Coastguard Worker            if x is not None:
6168*da0073e9SAndroid Build Coastguard Worker                print(x + 1)
6169*da0073e9SAndroid Build Coastguard Worker
6170*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6171*da0073e9SAndroid Build Coastguard Worker        def test_and(x, y):
6172*da0073e9SAndroid Build Coastguard Worker            # type: (Optional[int], Optional[int]) -> None
6173*da0073e9SAndroid Build Coastguard Worker            if x is not None and y is not None:
6174*da0073e9SAndroid Build Coastguard Worker                print(x + y)
6175*da0073e9SAndroid Build Coastguard Worker
6176*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6177*da0073e9SAndroid Build Coastguard Worker        def test_not(x, y):
6178*da0073e9SAndroid Build Coastguard Worker            # type: (Optional[int], Optional[int]) -> None
6179*da0073e9SAndroid Build Coastguard Worker            if not (x is not None and y is not None):
6180*da0073e9SAndroid Build Coastguard Worker                pass
6181*da0073e9SAndroid Build Coastguard Worker            else:
6182*da0073e9SAndroid Build Coastguard Worker                print(x + y)
6183*da0073e9SAndroid Build Coastguard Worker
6184*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6185*da0073e9SAndroid Build Coastguard Worker        def test_bool_expression(x):
6186*da0073e9SAndroid Build Coastguard Worker            # type: (Optional[int]) -> None
6187*da0073e9SAndroid Build Coastguard Worker            if x is not None and x < 2:
6188*da0073e9SAndroid Build Coastguard Worker                print(x + 1)
6189*da0073e9SAndroid Build Coastguard Worker
6190*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6191*da0073e9SAndroid Build Coastguard Worker        def test_nested_bool_expression(x, y):
6192*da0073e9SAndroid Build Coastguard Worker            # type: (Optional[int], Optional[int]) -> int
6193*da0073e9SAndroid Build Coastguard Worker            if x is not None and x < 2 and y is not None:
6194*da0073e9SAndroid Build Coastguard Worker                x = x + y
6195*da0073e9SAndroid Build Coastguard Worker            else:
6196*da0073e9SAndroid Build Coastguard Worker                x = 5
6197*da0073e9SAndroid Build Coastguard Worker            return x + 2
6198*da0073e9SAndroid Build Coastguard Worker
6199*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6200*da0073e9SAndroid Build Coastguard Worker        def test_or(x, y):
6201*da0073e9SAndroid Build Coastguard Worker            # type: (Optional[int], Optional[int]) -> None
6202*da0073e9SAndroid Build Coastguard Worker            if y is None or x is None:
6203*da0073e9SAndroid Build Coastguard Worker                pass
6204*da0073e9SAndroid Build Coastguard Worker            else:
6205*da0073e9SAndroid Build Coastguard Worker                print(x + y)
6206*da0073e9SAndroid Build Coastguard Worker
6207*da0073e9SAndroid Build Coastguard Worker        # backwards compatibility
6208*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6209*da0073e9SAndroid Build Coastguard Worker        def test_manual_unwrap_opt(x):
6210*da0073e9SAndroid Build Coastguard Worker            # type: (Optional[int]) -> int
6211*da0073e9SAndroid Build Coastguard Worker            if x is None:
6212*da0073e9SAndroid Build Coastguard Worker                x = 1
6213*da0073e9SAndroid Build Coastguard Worker            else:
6214*da0073e9SAndroid Build Coastguard Worker                x = torch.jit._unwrap_optional(x)
6215*da0073e9SAndroid Build Coastguard Worker            return x  # noqa: T484
6216*da0073e9SAndroid Build Coastguard Worker
6217*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
6218*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
6219*da0073e9SAndroid Build Coastguard Worker            def or_error(x, y):
6220*da0073e9SAndroid Build Coastguard Worker                # type: (Optional[int], Optional[int]) -> None
6221*da0073e9SAndroid Build Coastguard Worker                if x is None or y is None:
6222*da0073e9SAndroid Build Coastguard Worker                    print(x + y)  # noqa: T484
6223*da0073e9SAndroid Build Coastguard Worker
6224*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
6225*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
6226*da0073e9SAndroid Build Coastguard Worker            def and_error(x, y):
6227*da0073e9SAndroid Build Coastguard Worker                # type: (Optional[int], Optional[int]) -> None
6228*da0073e9SAndroid Build Coastguard Worker                if x is None and y is None:
6229*da0073e9SAndroid Build Coastguard Worker                    pass
6230*da0073e9SAndroid Build Coastguard Worker                else:
6231*da0073e9SAndroid Build Coastguard Worker                    print(x + y)  # noqa: T484
6232*da0073e9SAndroid Build Coastguard Worker
6233*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
6234*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
6235*da0073e9SAndroid Build Coastguard Worker            def named_var(x):
6236*da0073e9SAndroid Build Coastguard Worker                # type: (Optional[int]) -> None
6237*da0073e9SAndroid Build Coastguard Worker                x_none = x is not None
6238*da0073e9SAndroid Build Coastguard Worker                if x_none:
6239*da0073e9SAndroid Build Coastguard Worker                    print(x + 1)  # noqa: T484
6240*da0073e9SAndroid Build Coastguard Worker
6241*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
6242*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
6243*da0073e9SAndroid Build Coastguard Worker            def named_var_and(x, y):
6244*da0073e9SAndroid Build Coastguard Worker                # type: (Optional[int], Optional[int]) -> None
6245*da0073e9SAndroid Build Coastguard Worker                x_none = x is not None
6246*da0073e9SAndroid Build Coastguard Worker                if y is not None and x_none:
6247*da0073e9SAndroid Build Coastguard Worker                    print(x + y)  # noqa: T484
6248*da0073e9SAndroid Build Coastguard Worker
6249*da0073e9SAndroid Build Coastguard Worker    def test_assertion_optional_refinement(self):
6250*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6251*da0073e9SAndroid Build Coastguard Worker        def test(x, y):
6252*da0073e9SAndroid Build Coastguard Worker            # type: (Optional[int], Optional[int]) -> int
6253*da0073e9SAndroid Build Coastguard Worker            assert x is not None and y is not None
6254*da0073e9SAndroid Build Coastguard Worker            return x + y
6255*da0073e9SAndroid Build Coastguard Worker
6256*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(test(2, 2), 4)
6257*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(Exception, ""):
6258*da0073e9SAndroid Build Coastguard Worker            test(1, None)
6259*da0073e9SAndroid Build Coastguard Worker
6260*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "the current version of Profiler doesn't profile/specialize Optionals")
6261*da0073e9SAndroid Build Coastguard Worker    def test_optional_tensor(self):
6262*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6263*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
6264*da0073e9SAndroid Build Coastguard Worker            # type: (Optional[Tensor], int) -> int
6265*da0073e9SAndroid Build Coastguard Worker            if x is None:
6266*da0073e9SAndroid Build Coastguard Worker                return y
6267*da0073e9SAndroid Build Coastguard Worker            else:
6268*da0073e9SAndroid Build Coastguard Worker                return 0
6269*da0073e9SAndroid Build Coastguard Worker
6270*da0073e9SAndroid Build Coastguard Worker        res = fn(None, 1)
6271*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, 1)
6272*da0073e9SAndroid Build Coastguard Worker        g = torch.jit.last_executed_optimized_graph()
6273*da0073e9SAndroid Build Coastguard Worker        first_input = next(g.inputs())
6274*da0073e9SAndroid Build Coastguard Worker        # check if input is disconnected
6275*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(first_input.type().kind(), 'OptionalType')
6276*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(first_input.uses(), [])
6277*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(1)
6278*da0073e9SAndroid Build Coastguard Worker        res = fn(t, 1)
6279*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, 0)
6280*da0073e9SAndroid Build Coastguard Worker        g = torch.jit.last_executed_optimized_graph()
6281*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(next(g.inputs()).type().kind(), 'TensorType')
6282*da0073e9SAndroid Build Coastguard Worker
6283*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6284*da0073e9SAndroid Build Coastguard Worker        def fn(x, y, b):
6285*da0073e9SAndroid Build Coastguard Worker            # type: (Optional[Tensor], Tensor, bool) -> Tensor
6286*da0073e9SAndroid Build Coastguard Worker            if b:
6287*da0073e9SAndroid Build Coastguard Worker                res = y
6288*da0073e9SAndroid Build Coastguard Worker            else:
6289*da0073e9SAndroid Build Coastguard Worker                res = torch.jit._unwrap_optional(x)
6290*da0073e9SAndroid Build Coastguard Worker            return res
6291*da0073e9SAndroid Build Coastguard Worker
6292*da0073e9SAndroid Build Coastguard Worker        t2 = torch.zeros(1)
6293*da0073e9SAndroid Build Coastguard Worker        res = fn(t, t2, True)
6294*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, t2)
6295*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Unwrapping null optional"):
6296*da0073e9SAndroid Build Coastguard Worker            res = fn(None, t2, False)
6297*da0073e9SAndroid Build Coastguard Worker        res = fn(None, t2, True)
6298*da0073e9SAndroid Build Coastguard Worker        g = torch.jit.last_executed_optimized_graph()
6299*da0073e9SAndroid Build Coastguard Worker        self.assertIn(next(g.outputs()).type().str(), ("Tensor", "Tensor(requires_grad=1)"))
6300*da0073e9SAndroid Build Coastguard Worker
6301*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "the current version of Profiler doesn't profile/specialize Optionals")
6302*da0073e9SAndroid Build Coastguard Worker    def test_optional_list(self):
6303*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6304*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
6305*da0073e9SAndroid Build Coastguard Worker            # type: (Optional[List[int]], int) -> int
6306*da0073e9SAndroid Build Coastguard Worker            if x is None:
6307*da0073e9SAndroid Build Coastguard Worker                return y
6308*da0073e9SAndroid Build Coastguard Worker            else:
6309*da0073e9SAndroid Build Coastguard Worker                res = 0
6310*da0073e9SAndroid Build Coastguard Worker                for d in x:
6311*da0073e9SAndroid Build Coastguard Worker                    res += d
6312*da0073e9SAndroid Build Coastguard Worker                return res
6313*da0073e9SAndroid Build Coastguard Worker
6314*da0073e9SAndroid Build Coastguard Worker        res = fn(None, 1)
6315*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, 1)
6316*da0073e9SAndroid Build Coastguard Worker        g = torch.jit.last_executed_optimized_graph()
6317*da0073e9SAndroid Build Coastguard Worker        first_input = next(g.inputs())
6318*da0073e9SAndroid Build Coastguard Worker        # check if input is disconnected
6319*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(first_input.type().kind(), 'OptionalType')
6320*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(first_input.uses(), [])
6321*da0073e9SAndroid Build Coastguard Worker        l = [2, 3]
6322*da0073e9SAndroid Build Coastguard Worker        res = fn(l, 1)
6323*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, 5)
6324*da0073e9SAndroid Build Coastguard Worker        g = torch.jit.last_executed_optimized_graph()
6325*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(next(g.inputs()).type().kind(), 'ListType')
6326*da0073e9SAndroid Build Coastguard Worker
6327*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6328*da0073e9SAndroid Build Coastguard Worker        def fn(x, y, b):
6329*da0073e9SAndroid Build Coastguard Worker            # type: (Optional[List[int]], List[int], bool) -> List[int]
6330*da0073e9SAndroid Build Coastguard Worker            if b:
6331*da0073e9SAndroid Build Coastguard Worker                l = torch.jit._unwrap_optional(x)
6332*da0073e9SAndroid Build Coastguard Worker            else:
6333*da0073e9SAndroid Build Coastguard Worker                l = y
6334*da0073e9SAndroid Build Coastguard Worker            return l
6335*da0073e9SAndroid Build Coastguard Worker
6336*da0073e9SAndroid Build Coastguard Worker        l2 = [0, 1]
6337*da0073e9SAndroid Build Coastguard Worker        res = fn(l, l2, True)
6338*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, l)
6339*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Unwrapping null optional"):
6340*da0073e9SAndroid Build Coastguard Worker            res = fn(None, l2, True)
6341*da0073e9SAndroid Build Coastguard Worker        res = fn(None, l2, False)
6342*da0073e9SAndroid Build Coastguard Worker        g = torch.jit.last_executed_optimized_graph()
6343*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(next(g.outputs()).type().str(), "int[]")
6344*da0073e9SAndroid Build Coastguard Worker
6345*da0073e9SAndroid Build Coastguard Worker    def test_alias_covariant_type_containers(self):
6346*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6347*da0073e9SAndroid Build Coastguard Worker        def foo(x):
6348*da0073e9SAndroid Build Coastguard Worker            # type: (bool)
6349*da0073e9SAndroid Build Coastguard Worker            if x:
6350*da0073e9SAndroid Build Coastguard Worker                a = (None,)
6351*da0073e9SAndroid Build Coastguard Worker            else:
6352*da0073e9SAndroid Build Coastguard Worker                a = ([],)
6353*da0073e9SAndroid Build Coastguard Worker            return a
6354*da0073e9SAndroid Build Coastguard Worker
6355*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6356*da0073e9SAndroid Build Coastguard Worker        def foo2(x, li):
6357*da0073e9SAndroid Build Coastguard Worker            # type: (bool, Tuple[Optional[List[Tensor]]])
6358*da0073e9SAndroid Build Coastguard Worker            if x:
6359*da0073e9SAndroid Build Coastguard Worker                li = (None,)
6360*da0073e9SAndroid Build Coastguard Worker            return li
6361*da0073e9SAndroid Build Coastguard Worker
6362*da0073e9SAndroid Build Coastguard Worker    def test_while_write_outer_then_read(self):
6363*da0073e9SAndroid Build Coastguard Worker        def func(a, b):
6364*da0073e9SAndroid Build Coastguard Worker            while bool(a < 10):
6365*da0073e9SAndroid Build Coastguard Worker                a = a + 1
6366*da0073e9SAndroid Build Coastguard Worker                b = a + 1
6367*da0073e9SAndroid Build Coastguard Worker            return a + b
6368*da0073e9SAndroid Build Coastguard Worker
6369*da0073e9SAndroid Build Coastguard Worker        inputs = self._make_scalar_vars([42, 1337], torch.int64)
6370*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, inputs, optimize=True)
6371*da0073e9SAndroid Build Coastguard Worker
6372*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
6373*da0073e9SAndroid Build Coastguard Worker    def test_while_nest_if(self):
6374*da0073e9SAndroid Build Coastguard Worker        def func(a, b):
6375*da0073e9SAndroid Build Coastguard Worker            # type: (int, int) -> int
6376*da0073e9SAndroid Build Coastguard Worker            c = 0
6377*da0073e9SAndroid Build Coastguard Worker            while a < 10:
6378*da0073e9SAndroid Build Coastguard Worker                a = a + 1
6379*da0073e9SAndroid Build Coastguard Worker                b = b + 1
6380*da0073e9SAndroid Build Coastguard Worker                if a > b:
6381*da0073e9SAndroid Build Coastguard Worker                    c = -a
6382*da0073e9SAndroid Build Coastguard Worker                else:
6383*da0073e9SAndroid Build Coastguard Worker                    c = -b
6384*da0073e9SAndroid Build Coastguard Worker            return c + 1
6385*da0073e9SAndroid Build Coastguard Worker
6386*da0073e9SAndroid Build Coastguard Worker        inputs = self._make_scalar_vars([-1234, 4321], torch.int64)
6387*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, inputs, optimize=True)
6388*da0073e9SAndroid Build Coastguard Worker
6389*da0073e9SAndroid Build Coastguard Worker    def test_divmod(self):
6390*da0073e9SAndroid Build Coastguard Worker        def func_int(a, b):
6391*da0073e9SAndroid Build Coastguard Worker            # type: (int, int) -> Tuple[int, int]
6392*da0073e9SAndroid Build Coastguard Worker            return divmod(a, b)
6393*da0073e9SAndroid Build Coastguard Worker
6394*da0073e9SAndroid Build Coastguard Worker        def func_float(a, b):
6395*da0073e9SAndroid Build Coastguard Worker            # type: (float, float) -> Tuple[float, float]
6396*da0073e9SAndroid Build Coastguard Worker            return divmod(a, b)
6397*da0073e9SAndroid Build Coastguard Worker
6398*da0073e9SAndroid Build Coastguard Worker        def func_int_float(a, b):
6399*da0073e9SAndroid Build Coastguard Worker            # type: (int, float) -> Tuple[float, float]
6400*da0073e9SAndroid Build Coastguard Worker            return divmod(a, b)
6401*da0073e9SAndroid Build Coastguard Worker
6402*da0073e9SAndroid Build Coastguard Worker        def func_float_int(a, b):
6403*da0073e9SAndroid Build Coastguard Worker            # type: (float, int) -> Tuple[float, float]
6404*da0073e9SAndroid Build Coastguard Worker            return divmod(a, b)
6405*da0073e9SAndroid Build Coastguard Worker
6406*da0073e9SAndroid Build Coastguard Worker        def divmod_test_iterator(func, num, den):
6407*da0073e9SAndroid Build Coastguard Worker            for i in num:
6408*da0073e9SAndroid Build Coastguard Worker                for j in den:
6409*da0073e9SAndroid Build Coastguard Worker                    self.checkScript(func, (i, j), frames_up=2)
6410*da0073e9SAndroid Build Coastguard Worker
6411*da0073e9SAndroid Build Coastguard Worker        num_int = [1024, -1024]
6412*da0073e9SAndroid Build Coastguard Worker        den_int = [10, -10]
6413*da0073e9SAndroid Build Coastguard Worker        num_float = [5.3, -5.3]
6414*da0073e9SAndroid Build Coastguard Worker        den_float = [2.0, -2.0]
6415*da0073e9SAndroid Build Coastguard Worker        divmod_test_iterator(func_int, num_int, den_int)
6416*da0073e9SAndroid Build Coastguard Worker        divmod_test_iterator(func_float, num_float, den_float)
6417*da0073e9SAndroid Build Coastguard Worker        divmod_test_iterator(func_int_float, num_int, den_float)
6418*da0073e9SAndroid Build Coastguard Worker        divmod_test_iterator(func_float_int, num_float, den_int)
6419*da0073e9SAndroid Build Coastguard Worker
6420*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError: integer division or modulo by zero"):
6421*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit(dedent(inspect.getsource(func_int)))
6422*da0073e9SAndroid Build Coastguard Worker            cu.func_int(1024, 0)
6423*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError: float divmod()"):
6424*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit(dedent(inspect.getsource(func_float)))
6425*da0073e9SAndroid Build Coastguard Worker            cu.func_float(5.3, 0.0)
6426*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError: float divmod()"):
6427*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit(dedent(inspect.getsource(func_int_float)))
6428*da0073e9SAndroid Build Coastguard Worker            cu.func_int_float(1024, 0.0)
6429*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError: float divmod()"):
6430*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit(dedent(inspect.getsource(func_float_int)))
6431*da0073e9SAndroid Build Coastguard Worker            cu.func_float_int(5.3, 0)
6432*da0073e9SAndroid Build Coastguard Worker
6433*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
6434*da0073e9SAndroid Build Coastguard Worker    def test_math_ops(self):
6435*da0073e9SAndroid Build Coastguard Worker        def checkMathWrap(func_name, num_args=1, is_float=True, **args):
6436*da0073e9SAndroid Build Coastguard Worker            if is_float:
6437*da0073e9SAndroid Build Coastguard Worker                checkMath(func_name, num_args, True, **args)
6438*da0073e9SAndroid Build Coastguard Worker                checkMath(func_name, num_args, False, **args)
6439*da0073e9SAndroid Build Coastguard Worker            else:
6440*da0073e9SAndroid Build Coastguard Worker                checkMath(func_name, num_args, is_float, **args)
6441*da0073e9SAndroid Build Coastguard Worker
6442*da0073e9SAndroid Build Coastguard Worker        inf = float("inf")
6443*da0073e9SAndroid Build Coastguard Worker        NaN = float("nan")
6444*da0073e9SAndroid Build Coastguard Worker        mx_int = 2**31 - 1
6445*da0073e9SAndroid Build Coastguard Worker        mn_int = -2**31
6446*da0073e9SAndroid Build Coastguard Worker        float_vals = ([inf, NaN, 0.0, 1.0, 2.2, -1.0, -0.0, -2.2, -inf, 1, 0, 2] +
6447*da0073e9SAndroid Build Coastguard Worker                      [10.0 ** i for i in range(5)] + [-(10.0 ** i) for i in range(5)])
6448*da0073e9SAndroid Build Coastguard Worker        int_vals = list(range(-5, 5, 1)) + [mx_int + 5, mx_int * 2, mn_int - 5, mn_int * 2]
6449*da0073e9SAndroid Build Coastguard Worker
6450*da0073e9SAndroid Build Coastguard Worker        def checkMath(func_name, num_args, is_float=True, ret_type="float", debug=False, vals=None, args_type=None):
6451*da0073e9SAndroid Build Coastguard Worker            funcs_template = dedent('''
6452*da0073e9SAndroid Build Coastguard Worker            def func(a, b):
6453*da0073e9SAndroid Build Coastguard Worker                # type: {args_type} -> {ret_type}
6454*da0073e9SAndroid Build Coastguard Worker                return math.{func}({args})
6455*da0073e9SAndroid Build Coastguard Worker            ''')
6456*da0073e9SAndroid Build Coastguard Worker            if num_args == 1:
6457*da0073e9SAndroid Build Coastguard Worker                args = "a"
6458*da0073e9SAndroid Build Coastguard Worker            elif num_args == 2:
6459*da0073e9SAndroid Build Coastguard Worker                args = "a, b"
6460*da0073e9SAndroid Build Coastguard Worker            else:
6461*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError("Test doesn't support more than 2 arguments")
6462*da0073e9SAndroid Build Coastguard Worker            if args_type is None:
6463*da0073e9SAndroid Build Coastguard Worker                args_type = "(float, float)" if is_float else "(int, int)"
6464*da0073e9SAndroid Build Coastguard Worker            funcs_str = funcs_template.format(func=func_name, args=args, args_type=args_type, ret_type=ret_type)
6465*da0073e9SAndroid Build Coastguard Worker            scope = {}
6466*da0073e9SAndroid Build Coastguard Worker            execWrapper(funcs_str, globals(), scope)
6467*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit(funcs_str)
6468*da0073e9SAndroid Build Coastguard Worker            f_script = cu.func
6469*da0073e9SAndroid Build Coastguard Worker            f = scope['func']
6470*da0073e9SAndroid Build Coastguard Worker
6471*da0073e9SAndroid Build Coastguard Worker            if vals is None:
6472*da0073e9SAndroid Build Coastguard Worker                vals = float_vals if is_float else int_vals
6473*da0073e9SAndroid Build Coastguard Worker                vals = [(i, j) for i in vals for j in vals]
6474*da0073e9SAndroid Build Coastguard Worker
6475*da0073e9SAndroid Build Coastguard Worker            for a, b in vals:
6476*da0073e9SAndroid Build Coastguard Worker                res_python = None
6477*da0073e9SAndroid Build Coastguard Worker                res_script = None
6478*da0073e9SAndroid Build Coastguard Worker                try:
6479*da0073e9SAndroid Build Coastguard Worker                    res_python = f(a, b)
6480*da0073e9SAndroid Build Coastguard Worker                except Exception as e:
6481*da0073e9SAndroid Build Coastguard Worker                    res_python = e
6482*da0073e9SAndroid Build Coastguard Worker                try:
6483*da0073e9SAndroid Build Coastguard Worker                    res_script = f_script(a, b)
6484*da0073e9SAndroid Build Coastguard Worker                except Exception as e:
6485*da0073e9SAndroid Build Coastguard Worker                    res_script = e
6486*da0073e9SAndroid Build Coastguard Worker                if debug:
6487*da0073e9SAndroid Build Coastguard Worker                    print("in: ", a, b)
6488*da0073e9SAndroid Build Coastguard Worker                    print("out: ", res_python, res_script)
6489*da0073e9SAndroid Build Coastguard Worker                # We can't use assertEqual because of a couple of differences:
6490*da0073e9SAndroid Build Coastguard Worker                # 1. nan == nan should return true
6491*da0073e9SAndroid Build Coastguard Worker                # 2. When python functions throw an exception, we usually want to silently ignore them.
6492*da0073e9SAndroid Build Coastguard Worker                # (ie: We want to return `nan` for math.sqrt(-5))
6493*da0073e9SAndroid Build Coastguard Worker                if res_python != res_script:
6494*da0073e9SAndroid Build Coastguard Worker                    if isinstance(res_python, Exception):
6495*da0073e9SAndroid Build Coastguard Worker                        continue
6496*da0073e9SAndroid Build Coastguard Worker
6497*da0073e9SAndroid Build Coastguard Worker                    if type(res_python) == type(res_script):
6498*da0073e9SAndroid Build Coastguard Worker                        if isinstance(res_python, tuple) and (math.isnan(res_python[0]) == math.isnan(res_script[0])):
6499*da0073e9SAndroid Build Coastguard Worker                            continue
6500*da0073e9SAndroid Build Coastguard Worker                        if isinstance(res_python, float) and math.isnan(res_python) and math.isnan(res_script):
6501*da0073e9SAndroid Build Coastguard Worker                            continue
6502*da0073e9SAndroid Build Coastguard Worker                    msg = (f"Failed on {func_name} with inputs {a} {b}. Python: {res_python}, Script: {res_script}")
6503*da0073e9SAndroid Build Coastguard Worker                    # math.pow() behavior has changed in 3.11, see https://docs.python.org/3/library/math.html#math.pow
6504*da0073e9SAndroid Build Coastguard Worker                    if sys.version_info >= (3, 11) and func_name == "pow" and a == 0.0 and b == -math.inf:
6505*da0073e9SAndroid Build Coastguard Worker                        self.assertTrue(res_python == math.inf and type(res_script) is RuntimeError)
6506*da0073e9SAndroid Build Coastguard Worker                    else:
6507*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(res_python, res_script, msg=msg, atol=(1e-4) * max(abs(res_python), res_script), rtol=0)
6508*da0073e9SAndroid Build Coastguard Worker
6509*da0073e9SAndroid Build Coastguard Worker        unary_float_ops = ["log", "log1p", "log10", "exp", "sqrt", "gamma", "lgamma", "erf",
6510*da0073e9SAndroid Build Coastguard Worker                           "erfc", "expm1", "fabs", "acos", "asin", "atan", "cos", "sin", "tan",
6511*da0073e9SAndroid Build Coastguard Worker                           "asinh", "atanh", "acosh", "sinh", "cosh", "tanh", "degrees", "radians"]
6512*da0073e9SAndroid Build Coastguard Worker        binary_float_ops = ["atan2", "fmod", "copysign"]
6513*da0073e9SAndroid Build Coastguard Worker        for op in unary_float_ops:
6514*da0073e9SAndroid Build Coastguard Worker            checkMathWrap(op, 1)
6515*da0073e9SAndroid Build Coastguard Worker        for op in binary_float_ops:
6516*da0073e9SAndroid Build Coastguard Worker            checkMathWrap(op, 2)
6517*da0073e9SAndroid Build Coastguard Worker
6518*da0073e9SAndroid Build Coastguard Worker        checkMath("modf", 1, ret_type="Tuple[float, float]")
6519*da0073e9SAndroid Build Coastguard Worker        checkMath("frexp", 1, ret_type="Tuple[float, int]")
6520*da0073e9SAndroid Build Coastguard Worker        checkMath("isnan", 1, ret_type="bool")
6521*da0073e9SAndroid Build Coastguard Worker        checkMath("isinf", 1, ret_type="bool")
6522*da0073e9SAndroid Build Coastguard Worker        checkMath("ldexp", 2, is_float=False, ret_type="float", args_type="(float, int)",
6523*da0073e9SAndroid Build Coastguard Worker                  vals=[(i, j) for i in float_vals for j in range(-10, 10)])
6524*da0073e9SAndroid Build Coastguard Worker        checkMath("pow", 2, is_float=False, ret_type="float")
6525*da0073e9SAndroid Build Coastguard Worker        checkMath("pow", 2, is_float=True, ret_type="float")
6526*da0073e9SAndroid Build Coastguard Worker        checkMathWrap("floor", ret_type="int")
6527*da0073e9SAndroid Build Coastguard Worker        checkMathWrap("ceil", ret_type="int")
6528*da0073e9SAndroid Build Coastguard Worker        checkMathWrap("gcd", 2, is_float=False, ret_type="int")
6529*da0073e9SAndroid Build Coastguard Worker        checkMath("isfinite", 1, ret_type="bool")
6530*da0073e9SAndroid Build Coastguard Worker        checkMathWrap("remainder", 2)
6531*da0073e9SAndroid Build Coastguard Worker        checkMathWrap("factorial", 1, is_float=False, ret_type="int", vals=[(i, 0) for i in range(-2, 10)])
6532*da0073e9SAndroid Build Coastguard Worker
6533*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
6534*da0073e9SAndroid Build Coastguard Worker    def test_if_nest_while(self):
6535*da0073e9SAndroid Build Coastguard Worker        def func(a, b):
6536*da0073e9SAndroid Build Coastguard Worker            # type: (int, int) -> int
6537*da0073e9SAndroid Build Coastguard Worker            c = 0
6538*da0073e9SAndroid Build Coastguard Worker            if a > b:
6539*da0073e9SAndroid Build Coastguard Worker                while a > b:
6540*da0073e9SAndroid Build Coastguard Worker                    b = b + 1
6541*da0073e9SAndroid Build Coastguard Worker                    c = -b
6542*da0073e9SAndroid Build Coastguard Worker            return c
6543*da0073e9SAndroid Build Coastguard Worker
6544*da0073e9SAndroid Build Coastguard Worker        inputs = self._make_scalar_vars([4321, 1234], torch.int64)
6545*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, inputs)
6546*da0073e9SAndroid Build Coastguard Worker
6547*da0073e9SAndroid Build Coastguard Worker    def test_script_optional_none(self):
6548*da0073e9SAndroid Build Coastguard Worker        def none_stmt(x):
6549*da0073e9SAndroid Build Coastguard Worker            output = None
6550*da0073e9SAndroid Build Coastguard Worker            output = x
6551*da0073e9SAndroid Build Coastguard Worker            return output
6552*da0073e9SAndroid Build Coastguard Worker
6553*da0073e9SAndroid Build Coastguard Worker        def none_args(x):
6554*da0073e9SAndroid Build Coastguard Worker            # type: (Optional[Tensor]) -> Optional[Tensor]
6555*da0073e9SAndroid Build Coastguard Worker            return None
6556*da0073e9SAndroid Build Coastguard Worker
6557*da0073e9SAndroid Build Coastguard Worker        self.checkScript(none_stmt, [torch.arange(0, 2)], optimize=True)
6558*da0073e9SAndroid Build Coastguard Worker        self.checkScript(none_args, [None], optimize=True)
6559*da0073e9SAndroid Build Coastguard Worker
6560*da0073e9SAndroid Build Coastguard Worker        # test undefined tensor None as default param
6561*da0073e9SAndroid Build Coastguard Worker        def test_script_optional_tensor_none(x=None):
6562*da0073e9SAndroid Build Coastguard Worker            # type: (Optional[Tensor]) -> Tensor
6563*da0073e9SAndroid Build Coastguard Worker            res = torch.zeros(1, dtype=torch.int8)
6564*da0073e9SAndroid Build Coastguard Worker            if x is None:
6565*da0073e9SAndroid Build Coastguard Worker                res = res + 1
6566*da0073e9SAndroid Build Coastguard Worker            else:
6567*da0073e9SAndroid Build Coastguard Worker                res = x
6568*da0073e9SAndroid Build Coastguard Worker            return res
6569*da0073e9SAndroid Build Coastguard Worker
6570*da0073e9SAndroid Build Coastguard Worker        fn = test_script_optional_tensor_none
6571*da0073e9SAndroid Build Coastguard Worker        scripted_fn = torch.jit.script(fn)
6572*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(), scripted_fn())
6573*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(torch.zeros(1)), scripted_fn(torch.zeros(1)))
6574*da0073e9SAndroid Build Coastguard Worker
6575*da0073e9SAndroid Build Coastguard Worker        # test typical None as default param
6576*da0073e9SAndroid Build Coastguard Worker        def test_script_optional_other_none(x=None):
6577*da0073e9SAndroid Build Coastguard Worker            # type: (Optional[float]) -> float
6578*da0073e9SAndroid Build Coastguard Worker            res = 2.0
6579*da0073e9SAndroid Build Coastguard Worker            if x is None:
6580*da0073e9SAndroid Build Coastguard Worker                res = res + 1.0
6581*da0073e9SAndroid Build Coastguard Worker            else:
6582*da0073e9SAndroid Build Coastguard Worker                res = x
6583*da0073e9SAndroid Build Coastguard Worker            return res
6584*da0073e9SAndroid Build Coastguard Worker
6585*da0073e9SAndroid Build Coastguard Worker        fn = test_script_optional_other_none
6586*da0073e9SAndroid Build Coastguard Worker        scripted_fn = torch.jit.script(fn)
6587*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(), scripted_fn())
6588*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(1.0), scripted_fn(1.0))
6589*da0073e9SAndroid Build Coastguard Worker
6590*da0073e9SAndroid Build Coastguard Worker    def test_script_clamp_none(self):
6591*da0073e9SAndroid Build Coastguard Worker        def test_script_clamp_max_none(x):
6592*da0073e9SAndroid Build Coastguard Worker            return torch.clamp(x, min=2, max=None)
6593*da0073e9SAndroid Build Coastguard Worker
6594*da0073e9SAndroid Build Coastguard Worker        def test_script_clamp_max(x):
6595*da0073e9SAndroid Build Coastguard Worker            return torch.clamp(x, max=2)
6596*da0073e9SAndroid Build Coastguard Worker
6597*da0073e9SAndroid Build Coastguard Worker        def test_script_clamp_min_none(x):
6598*da0073e9SAndroid Build Coastguard Worker            return torch.clamp(x, min=None, max=2)
6599*da0073e9SAndroid Build Coastguard Worker
6600*da0073e9SAndroid Build Coastguard Worker        def test_script_clamp_min(x):
6601*da0073e9SAndroid Build Coastguard Worker            return torch.clamp(x, min=2)
6602*da0073e9SAndroid Build Coastguard Worker
6603*da0073e9SAndroid Build Coastguard Worker        input = [torch.arange(0, 3)]
6604*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_script_clamp_max_none, input, optimize=True)
6605*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_script_clamp_max, input, optimize=True)
6606*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_script_clamp_min_none, input, optimize=True)
6607*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_script_clamp_min, input, optimize=True)
6608*da0073e9SAndroid Build Coastguard Worker
6609*da0073e9SAndroid Build Coastguard Worker    def test_script_bool_constant(self):
6610*da0073e9SAndroid Build Coastguard Worker        def test_script_bool_constant():
6611*da0073e9SAndroid Build Coastguard Worker            a = True
6612*da0073e9SAndroid Build Coastguard Worker            return a
6613*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_script_bool_constant, [])
6614*da0073e9SAndroid Build Coastguard Worker
6615*da0073e9SAndroid Build Coastguard Worker    def test_ternary(self):
6616*da0073e9SAndroid Build Coastguard Worker        def func(a, b):
6617*da0073e9SAndroid Build Coastguard Worker            c = 3
6618*da0073e9SAndroid Build Coastguard Worker            c = a + b if bool(a > 3) else b
6619*da0073e9SAndroid Build Coastguard Worker            return c
6620*da0073e9SAndroid Build Coastguard Worker
6621*da0073e9SAndroid Build Coastguard Worker        inputs_true = self._make_scalar_vars([5, 2], torch.int64)
6622*da0073e9SAndroid Build Coastguard Worker        inputs_false = self._make_scalar_vars([1, 0], torch.int64)
6623*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, inputs_true, optimize=True)
6624*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, inputs_false, optimize=True)
6625*da0073e9SAndroid Build Coastguard Worker
6626*da0073e9SAndroid Build Coastguard Worker    def test_ternary_module_type_hint(self):
6627*da0073e9SAndroid Build Coastguard Worker        class M1(torch.nn.Module):
6628*da0073e9SAndroid Build Coastguard Worker            def forward(self) -> Any:
6629*da0073e9SAndroid Build Coastguard Worker                return 'out' if self.training else {}
6630*da0073e9SAndroid Build Coastguard Worker
6631*da0073e9SAndroid Build Coastguard Worker        class M2(torch.nn.Module):
6632*da0073e9SAndroid Build Coastguard Worker            def forward(self) -> Any:
6633*da0073e9SAndroid Build Coastguard Worker                out: Any = 'out' if self.training else {}
6634*da0073e9SAndroid Build Coastguard Worker                return out
6635*da0073e9SAndroid Build Coastguard Worker
6636*da0073e9SAndroid Build Coastguard Worker        class M3(torch.nn.Module):
6637*da0073e9SAndroid Build Coastguard Worker            def forward(self) -> Optional[int]:
6638*da0073e9SAndroid Build Coastguard Worker                return None if self.training else 1
6639*da0073e9SAndroid Build Coastguard Worker
6640*da0073e9SAndroid Build Coastguard Worker        for module in [M1, M2, M3]:
6641*da0073e9SAndroid Build Coastguard Worker            self.checkModule(module().train(), ())
6642*da0073e9SAndroid Build Coastguard Worker            self.checkModule(module().eval(), ())
6643*da0073e9SAndroid Build Coastguard Worker
6644*da0073e9SAndroid Build Coastguard Worker    def test_ternary_static_if(self):
6645*da0073e9SAndroid Build Coastguard Worker        # Test for True branch when condition variable
6646*da0073e9SAndroid Build Coastguard Worker        # is annotated as Final
6647*da0073e9SAndroid Build Coastguard Worker        class M1(torch.nn.Module):
6648*da0073e9SAndroid Build Coastguard Worker            flag: torch.jit.Final[bool]
6649*da0073e9SAndroid Build Coastguard Worker
6650*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
6651*da0073e9SAndroid Build Coastguard Worker                super().__init__()
6652*da0073e9SAndroid Build Coastguard Worker                self.flag = True
6653*da0073e9SAndroid Build Coastguard Worker
6654*da0073e9SAndroid Build Coastguard Worker            def forward(self) -> torch.Tensor:
6655*da0073e9SAndroid Build Coastguard Worker                return torch.ones(3) if self.flag else {}
6656*da0073e9SAndroid Build Coastguard Worker
6657*da0073e9SAndroid Build Coastguard Worker        # Test for True branch when condition variable
6658*da0073e9SAndroid Build Coastguard Worker        # is annotated as Final
6659*da0073e9SAndroid Build Coastguard Worker        class M2(torch.nn.Module):
6660*da0073e9SAndroid Build Coastguard Worker            flag: torch.jit.Final[bool]
6661*da0073e9SAndroid Build Coastguard Worker
6662*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
6663*da0073e9SAndroid Build Coastguard Worker                super().__init__()
6664*da0073e9SAndroid Build Coastguard Worker                self.flag = False
6665*da0073e9SAndroid Build Coastguard Worker
6666*da0073e9SAndroid Build Coastguard Worker            def forward(self) -> torch.Tensor:
6667*da0073e9SAndroid Build Coastguard Worker                return {} if self.flag else torch.ones(3)
6668*da0073e9SAndroid Build Coastguard Worker
6669*da0073e9SAndroid Build Coastguard Worker        model1 = M1()
6670*da0073e9SAndroid Build Coastguard Worker        model2 = M2()
6671*da0073e9SAndroid Build Coastguard Worker        script_model_1 = torch.jit.script(model1)
6672*da0073e9SAndroid Build Coastguard Worker        script_model_2 = torch.jit.script(model2)
6673*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(model1.forward(), script_model_1.forward())
6674*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(model2.forward(), script_model_2.forward())
6675*da0073e9SAndroid Build Coastguard Worker
6676*da0073e9SAndroid Build Coastguard Worker    def test_ternary_right_associative(self):
6677*da0073e9SAndroid Build Coastguard Worker        def plus_123(x: int):
6678*da0073e9SAndroid Build Coastguard Worker            return x + 1 if x == 1 else x + 2 if x == 2 else x + 3
6679*da0073e9SAndroid Build Coastguard Worker        self.checkScript(plus_123, (1,))
6680*da0073e9SAndroid Build Coastguard Worker        self.checkScript(plus_123, (2,))
6681*da0073e9SAndroid Build Coastguard Worker        self.checkScript(plus_123, (3,))
6682*da0073e9SAndroid Build Coastguard Worker
6683*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
6684*da0073e9SAndroid Build Coastguard Worker    def test_print(self):
6685*da0073e9SAndroid Build Coastguard Worker        def func(x, y):
6686*da0073e9SAndroid Build Coastguard Worker            q = (x + y).sigmoid()
6687*da0073e9SAndroid Build Coastguard Worker            print(q, 1, 2, [1, 2], [1.0, 2.0])
6688*da0073e9SAndroid Build Coastguard Worker            w = -q
6689*da0073e9SAndroid Build Coastguard Worker            return w * w
6690*da0073e9SAndroid Build Coastguard Worker
6691*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(4., requires_grad=True)
6692*da0073e9SAndroid Build Coastguard Worker        y = torch.arange(0., 8, 2, requires_grad=True)
6693*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, [x, y], optimize=True, capture_output=True)
6694*da0073e9SAndroid Build Coastguard Worker
6695*da0073e9SAndroid Build Coastguard Worker    def test_format(self):
6696*da0073e9SAndroid Build Coastguard Worker        def func(x):
6697*da0073e9SAndroid Build Coastguard Worker            print("{}, I'm a {}".format("Hello", "test"))
6698*da0073e9SAndroid Build Coastguard Worker            print("format blank".format())
6699*da0073e9SAndroid Build Coastguard Worker            print("stuff before {}".format("hi"))
6700*da0073e9SAndroid Build Coastguard Worker            print("{} stuff after".format("hi"))
6701*da0073e9SAndroid Build Coastguard Worker            return x + 1
6702*da0073e9SAndroid Build Coastguard Worker
6703*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(4., requires_grad=True)
6704*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, [x], optimize=True, capture_output=True)
6705*da0073e9SAndroid Build Coastguard Worker
6706*da0073e9SAndroid Build Coastguard Worker    def test_logical_short_circuit(self):
6707*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6708*da0073e9SAndroid Build Coastguard Worker        def testNoThrows(t):
6709*da0073e9SAndroid Build Coastguard Worker            c1 = 1
6710*da0073e9SAndroid Build Coastguard Worker            if (False and bool(t[1])) or (True or bool(t[1])):
6711*da0073e9SAndroid Build Coastguard Worker                c1 = 0
6712*da0073e9SAndroid Build Coastguard Worker            return c1
6713*da0073e9SAndroid Build Coastguard Worker
6714*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("prim::If").run(testNoThrows.graph)
6715*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, testNoThrows(torch.randn(0)))
6716*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, testNoThrows(torch.randn([2, 3])))
6717*da0073e9SAndroid Build Coastguard Worker
6718*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6719*da0073e9SAndroid Build Coastguard Worker        def throwsOr(t):
6720*da0073e9SAndroid Build Coastguard Worker            c0 = False or bool(t[1])
6721*da0073e9SAndroid Build Coastguard Worker            print(c0)
6722*da0073e9SAndroid Build Coastguard Worker
6723*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6724*da0073e9SAndroid Build Coastguard Worker        def throwsAnd(t):
6725*da0073e9SAndroid Build Coastguard Worker            c0 = True and bool(t[1])
6726*da0073e9SAndroid Build Coastguard Worker            print(c0)
6727*da0073e9SAndroid Build Coastguard Worker
6728*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(0)
6729*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "index 1 out of range for tensor of size"):
6730*da0073e9SAndroid Build Coastguard Worker            throwsOr(t)
6731*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "index 1 out of range for tensor of size"):
6732*da0073e9SAndroid Build Coastguard Worker            throwsAnd(t)
6733*da0073e9SAndroid Build Coastguard Worker
6734*da0073e9SAndroid Build Coastguard Worker    def test_type_cast(self):
6735*da0073e9SAndroid Build Coastguard Worker        template = dedent('''
6736*da0073e9SAndroid Build Coastguard Worker        def func(v):
6737*da0073e9SAndroid Build Coastguard Worker            # type: ({from_type}) -> {to_type}
6738*da0073e9SAndroid Build Coastguard Worker            return {to_type}(v)
6739*da0073e9SAndroid Build Coastguard Worker        ''')
6740*da0073e9SAndroid Build Coastguard Worker
6741*da0073e9SAndroid Build Coastguard Worker        def check_cast(from_type, to_type, value, raises=False):
6742*da0073e9SAndroid Build Coastguard Worker            code = template.format(from_type=from_type, to_type=to_type)
6743*da0073e9SAndroid Build Coastguard Worker            self.checkScript(code, (value,))
6744*da0073e9SAndroid Build Coastguard Worker
6745*da0073e9SAndroid Build Coastguard Worker        check_cast('int', 'float', 1)
6746*da0073e9SAndroid Build Coastguard Worker        check_cast('int', 'bool', 1)
6747*da0073e9SAndroid Build Coastguard Worker        check_cast('int', 'bool', 0)
6748*da0073e9SAndroid Build Coastguard Worker
6749*da0073e9SAndroid Build Coastguard Worker        check_cast('float', 'int', 1.)
6750*da0073e9SAndroid Build Coastguard Worker        check_cast('float', 'bool', 1.)
6751*da0073e9SAndroid Build Coastguard Worker        check_cast('float', 'bool', 0.)
6752*da0073e9SAndroid Build Coastguard Worker
6753*da0073e9SAndroid Build Coastguard Worker        check_cast('bool', 'int', True)
6754*da0073e9SAndroid Build Coastguard Worker        check_cast('bool', 'float', True)
6755*da0073e9SAndroid Build Coastguard Worker
6756*da0073e9SAndroid Build Coastguard Worker    def test_multiple_assignment(self):
6757*da0073e9SAndroid Build Coastguard Worker        def outer_func(x):
6758*da0073e9SAndroid Build Coastguard Worker            return x * 2, x + 2
6759*da0073e9SAndroid Build Coastguard Worker
6760*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6761*da0073e9SAndroid Build Coastguard Worker        def func(x):
6762*da0073e9SAndroid Build Coastguard Worker            y, z = outer_func(x)
6763*da0073e9SAndroid Build Coastguard Worker            return y + z
6764*da0073e9SAndroid Build Coastguard Worker
6765*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(4)
6766*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(func(x), x * 2 + x + 2)
6767*da0073e9SAndroid Build Coastguard Worker
6768*da0073e9SAndroid Build Coastguard Worker    def test_literals(self):
6769*da0073e9SAndroid Build Coastguard Worker        def func(a):
6770*da0073e9SAndroid Build Coastguard Worker            return a.view(size=[1, 2, 3])
6771*da0073e9SAndroid Build Coastguard Worker
6772*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(6)
6773*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, [a], optimize=True)
6774*da0073e9SAndroid Build Coastguard Worker
6775*da0073e9SAndroid Build Coastguard Worker    def test_return(self):
6776*da0073e9SAndroid Build Coastguard Worker        def no_return(a):
6777*da0073e9SAndroid Build Coastguard Worker            a + 1
6778*da0073e9SAndroid Build Coastguard Worker
6779*da0073e9SAndroid Build Coastguard Worker        def void_return(a):
6780*da0073e9SAndroid Build Coastguard Worker            return
6781*da0073e9SAndroid Build Coastguard Worker
6782*da0073e9SAndroid Build Coastguard Worker        def one_return(a):
6783*da0073e9SAndroid Build Coastguard Worker            return a + 1.
6784*da0073e9SAndroid Build Coastguard Worker
6785*da0073e9SAndroid Build Coastguard Worker        def multiple_returns(a):
6786*da0073e9SAndroid Build Coastguard Worker            return a * 1., a * 2., a * 3.
6787*da0073e9SAndroid Build Coastguard Worker
6788*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(1, dtype=torch.float)
6789*da0073e9SAndroid Build Coastguard Worker        self.checkScript(no_return, [a], optimize=True)
6790*da0073e9SAndroid Build Coastguard Worker        self.checkScript(void_return, [a], optimize=True)
6791*da0073e9SAndroid Build Coastguard Worker        self.checkScript(one_return, [a], optimize=True)
6792*da0073e9SAndroid Build Coastguard Worker        self.checkScript(multiple_returns, [a], optimize=True)
6793*da0073e9SAndroid Build Coastguard Worker
6794*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "does not return along all paths"):
6795*da0073e9SAndroid Build Coastguard Worker            torch.jit.CompilationUnit('''
6796*da0073e9SAndroid Build Coastguard Worker            def no_return_bad_annotation(a):
6797*da0073e9SAndroid Build Coastguard Worker                # type: (Tensor) -> Tensor
6798*da0073e9SAndroid Build Coastguard Worker                a + 1
6799*da0073e9SAndroid Build Coastguard Worker            ''')
6800*da0073e9SAndroid Build Coastguard Worker
6801*da0073e9SAndroid Build Coastguard Worker    def test_error(self):
6802*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6803*da0073e9SAndroid Build Coastguard Worker        def foo(a):
6804*da0073e9SAndroid Build Coastguard Worker            return a.t()
6805*da0073e9SAndroid Build Coastguard Worker        s = Variable(torch.rand(5, 5, 5))
6806*da0073e9SAndroid Build Coastguard Worker        # XXX: this should stay quiet in stay propagation and only fail in the interpreter
6807*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "failed in the TorchScript interpreter"):
6808*da0073e9SAndroid Build Coastguard Worker            foo(s)
6809*da0073e9SAndroid Build Coastguard Worker
6810*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6811*da0073e9SAndroid Build Coastguard Worker        def bar(c, b):
6812*da0073e9SAndroid Build Coastguard Worker            return c + b
6813*da0073e9SAndroid Build Coastguard Worker
6814*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "failed in the TorchScript interpreter"):
6815*da0073e9SAndroid Build Coastguard Worker            bar(Variable(torch.rand(10), requires_grad=True), Variable(torch.rand(9), requires_grad=True))
6816*da0073e9SAndroid Build Coastguard Worker
6817*da0073e9SAndroid Build Coastguard Worker    def test_error_stacktrace(self):
6818*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6819*da0073e9SAndroid Build Coastguard Worker        def baz(c, b):
6820*da0073e9SAndroid Build Coastguard Worker            return c + b
6821*da0073e9SAndroid Build Coastguard Worker
6822*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6823*da0073e9SAndroid Build Coastguard Worker        def foo(c, b):
6824*da0073e9SAndroid Build Coastguard Worker            return baz(c, b)
6825*da0073e9SAndroid Build Coastguard Worker
6826*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6827*da0073e9SAndroid Build Coastguard Worker        def bar(c, b):
6828*da0073e9SAndroid Build Coastguard Worker            return foo(c, b)
6829*da0073e9SAndroid Build Coastguard Worker
6830*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError) as cm:
6831*da0073e9SAndroid Build Coastguard Worker            bar(torch.rand(10), torch.rand(9))
6832*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("The following operation failed in the TorchScript interpreter") \
6833*da0073e9SAndroid Build Coastguard Worker                   .check("Traceback") \
6834*da0073e9SAndroid Build Coastguard Worker                   .check("in foo").check("in baz").run(str(cm.exception))
6835*da0073e9SAndroid Build Coastguard Worker
6836*da0073e9SAndroid Build Coastguard Worker    def test_error_stacktrace_interface(self):
6837*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6838*da0073e9SAndroid Build Coastguard Worker        def baz(c, b):
6839*da0073e9SAndroid Build Coastguard Worker            return c + b
6840*da0073e9SAndroid Build Coastguard Worker
6841*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6842*da0073e9SAndroid Build Coastguard Worker        def foo(c, b):
6843*da0073e9SAndroid Build Coastguard Worker            return baz(c, b)
6844*da0073e9SAndroid Build Coastguard Worker
6845*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6846*da0073e9SAndroid Build Coastguard Worker        def bar(c, b):
6847*da0073e9SAndroid Build Coastguard Worker            return foo(c, b)
6848*da0073e9SAndroid Build Coastguard Worker
6849*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6850*da0073e9SAndroid Build Coastguard Worker        class Bar:
6851*da0073e9SAndroid Build Coastguard Worker            def one(self, x, y):
6852*da0073e9SAndroid Build Coastguard Worker                return bar(x, y)
6853*da0073e9SAndroid Build Coastguard Worker
6854*da0073e9SAndroid Build Coastguard Worker        @torch.jit.interface
6855*da0073e9SAndroid Build Coastguard Worker        class IFace:
6856*da0073e9SAndroid Build Coastguard Worker            def one(self, x, y):
6857*da0073e9SAndroid Build Coastguard Worker                # type: (Tensor, Tensor) -> Tensor
6858*da0073e9SAndroid Build Coastguard Worker                pass
6859*da0073e9SAndroid Build Coastguard Worker
6860*da0073e9SAndroid Build Coastguard Worker        make_global(IFace)
6861*da0073e9SAndroid Build Coastguard Worker
6862*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6863*da0073e9SAndroid Build Coastguard Worker        def as_interface(x):
6864*da0073e9SAndroid Build Coastguard Worker            # type: (IFace) -> IFace
6865*da0073e9SAndroid Build Coastguard Worker            return x
6866*da0073e9SAndroid Build Coastguard Worker
6867*da0073e9SAndroid Build Coastguard Worker        f = as_interface(Bar())
6868*da0073e9SAndroid Build Coastguard Worker
6869*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError) as cm:
6870*da0073e9SAndroid Build Coastguard Worker            x = f.one(torch.rand(10), torch.rand(9))
6871*da0073e9SAndroid Build Coastguard Worker            bar(torch.rand(10), torch.rand(9))
6872*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("The following operation failed in the TorchScript interpreter") \
6873*da0073e9SAndroid Build Coastguard Worker                   .check("Traceback") \
6874*da0073e9SAndroid Build Coastguard Worker                   .check("in foo").check("in baz").run(str(cm.exception))
6875*da0073e9SAndroid Build Coastguard Worker
6876*da0073e9SAndroid Build Coastguard Worker    def test_operator_precedence(self):
6877*da0073e9SAndroid Build Coastguard Worker        def double(x):
6878*da0073e9SAndroid Build Coastguard Worker            # type: (int) -> int
6879*da0073e9SAndroid Build Coastguard Worker            return 2 * x
6880*da0073e9SAndroid Build Coastguard Worker
6881*da0073e9SAndroid Build Coastguard Worker        def complicated_arithmetic_operation():
6882*da0073e9SAndroid Build Coastguard Worker            # TODO we need to test exponent operator '**' and bitwise not
6883*da0073e9SAndroid Build Coastguard Worker            # operator '~' once they are properly supported.
6884*da0073e9SAndroid Build Coastguard Worker            list = [0, 1, 2, 3]
6885*da0073e9SAndroid Build Coastguard Worker            result = list[1:3][0] + double(4) + (-3 + 8) * 6 // 2 % 4 << 2 + 1 >> 1 | 23 & 16 + 3 ^ 4
6886*da0073e9SAndroid Build Coastguard Worker            return result
6887*da0073e9SAndroid Build Coastguard Worker
6888*da0073e9SAndroid Build Coastguard Worker        self.checkScript(complicated_arithmetic_operation, ())
6889*da0073e9SAndroid Build Coastguard Worker
6890*da0073e9SAndroid Build Coastguard Worker    def test_in_operator_with_two_strings(self):
6891*da0073e9SAndroid Build Coastguard Worker        def fn() -> bool:
6892*da0073e9SAndroid Build Coastguard Worker            return "a" in "abcd"
6893*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ())
6894*da0073e9SAndroid Build Coastguard Worker
6895*da0073e9SAndroid Build Coastguard Worker    def test_bitwise_ops(self):
6896*da0073e9SAndroid Build Coastguard Worker
6897*da0073e9SAndroid Build Coastguard Worker        def int_test():
6898*da0073e9SAndroid Build Coastguard Worker            return 2 & 3, 2 ^ 3, 2 | 3, 2 << 3, 2 >> 3
6899*da0073e9SAndroid Build Coastguard Worker
6900*da0073e9SAndroid Build Coastguard Worker        self.checkScript(int_test, ())
6901*da0073e9SAndroid Build Coastguard Worker
6902*da0073e9SAndroid Build Coastguard Worker        def bool_test(x, y):
6903*da0073e9SAndroid Build Coastguard Worker            # type: (bool, bool) -> Tuple[bool, bool, bool]
6904*da0073e9SAndroid Build Coastguard Worker            return x & y, x ^ y, x | y
6905*da0073e9SAndroid Build Coastguard Worker
6906*da0073e9SAndroid Build Coastguard Worker        self.checkScript(bool_test, (True, False))
6907*da0073e9SAndroid Build Coastguard Worker        self.checkScript(bool_test, (True, True))
6908*da0073e9SAndroid Build Coastguard Worker
6909*da0073e9SAndroid Build Coastguard Worker        def tensor_test(x, y):
6910*da0073e9SAndroid Build Coastguard Worker            return x & y, x ^ y, x | y
6911*da0073e9SAndroid Build Coastguard Worker
6912*da0073e9SAndroid Build Coastguard Worker        def tensor_with_int_test(x, y):
6913*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor, int) -> Tuple[Tensor, Tensor]
6914*da0073e9SAndroid Build Coastguard Worker            return x << y, x >> y
6915*da0073e9SAndroid Build Coastguard Worker
6916*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(2)
6917*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor(3)
6918*da0073e9SAndroid Build Coastguard Worker
6919*da0073e9SAndroid Build Coastguard Worker        self.checkScript(tensor_test, (x, y))
6920*da0073e9SAndroid Build Coastguard Worker        self.checkScript(tensor_with_int_test, (x, 2))
6921*da0073e9SAndroid Build Coastguard Worker
6922*da0073e9SAndroid Build Coastguard Worker        def not_test(x):
6923*da0073e9SAndroid Build Coastguard Worker            return ~x
6924*da0073e9SAndroid Build Coastguard Worker
6925*da0073e9SAndroid Build Coastguard Worker        self.checkScript(not_test, (torch.tensor([2, 4]), ))
6926*da0073e9SAndroid Build Coastguard Worker
6927*da0073e9SAndroid Build Coastguard Worker    def test_all(self):
6928*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6929*da0073e9SAndroid Build Coastguard Worker        def test_all_tensor(x):
6930*da0073e9SAndroid Build Coastguard Worker            return all(x)
6931*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(test_all_tensor(torch.tensor([1, 0, 3], dtype=torch.uint8)))
6932*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(test_all_tensor(torch.tensor([3.14, 3, 99], dtype=torch.uint8)))
6933*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(test_all_tensor(torch.tensor([True, True], dtype=torch.uint8)))
6934*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(test_all_tensor(torch.tensor([True, False], dtype=torch.uint8)))
6935*da0073e9SAndroid Build Coastguard Worker
6936*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6937*da0073e9SAndroid Build Coastguard Worker        def test_all_bool_list(x):
6938*da0073e9SAndroid Build Coastguard Worker            # type: (List[bool]) -> bool
6939*da0073e9SAndroid Build Coastguard Worker            return all(x)
6940*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(test_all_bool_list([True, True]))
6941*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(test_all_bool_list([True, 1]))
6942*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(test_all_bool_list([True, False]))
6943*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(test_all_bool_list([True, 0]))
6944*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(test_all_bool_list([False, 0]))
6945*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(test_all_bool_list([]))
6946*da0073e9SAndroid Build Coastguard Worker
6947*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6948*da0073e9SAndroid Build Coastguard Worker        def test_all_int_list(x):
6949*da0073e9SAndroid Build Coastguard Worker            # type: (List[int]) -> bool
6950*da0073e9SAndroid Build Coastguard Worker            return all(x)
6951*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(test_all_int_list([3, 6]))
6952*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(test_all_int_list([2, 0]))
6953*da0073e9SAndroid Build Coastguard Worker
6954*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
6955*da0073e9SAndroid Build Coastguard Worker        def test_all_float_list(x):
6956*da0073e9SAndroid Build Coastguard Worker            # type: (List[float]) -> bool
6957*da0073e9SAndroid Build Coastguard Worker            return all(x)
6958*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(test_all_float_list([3.14, 8.1]))
6959*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(test_all_float_list([3.14, 0, 8.9]))
6960*da0073e9SAndroid Build Coastguard Worker
6961*da0073e9SAndroid Build Coastguard Worker
6962*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
6963*da0073e9SAndroid Build Coastguard Worker    def test_number_math(self):
6964*da0073e9SAndroid Build Coastguard Worker        ops_template = dedent('''
6965*da0073e9SAndroid Build Coastguard Worker        def func():
6966*da0073e9SAndroid Build Coastguard Worker            return {scalar1} {op} {scalar2}
6967*da0073e9SAndroid Build Coastguard Worker        ''')
6968*da0073e9SAndroid Build Coastguard Worker        ops = ['+', '-', '*', '%', '<', '<=', '>', '>=', '==', '!=', '//']
6969*da0073e9SAndroid Build Coastguard Worker        funcs_template = dedent('''
6970*da0073e9SAndroid Build Coastguard Worker        def func():
6971*da0073e9SAndroid Build Coastguard Worker            return {func}({scalar1}, {scalar2})
6972*da0073e9SAndroid Build Coastguard Worker        ''')
6973*da0073e9SAndroid Build Coastguard Worker        funcs = ['min', 'max']
6974*da0073e9SAndroid Build Coastguard Worker        scalars = ['7', '2', '3', '-3', '3.14', '0.125', '-0.5', '2.0', '-2.0']
6975*da0073e9SAndroid Build Coastguard Worker        scalar_pairs = [(scalar1, scalar2) for scalar1 in scalars for scalar2 in scalars]
6976*da0073e9SAndroid Build Coastguard Worker
6977*da0073e9SAndroid Build Coastguard Worker        def run_test(code):
6978*da0073e9SAndroid Build Coastguard Worker            scope = {}
6979*da0073e9SAndroid Build Coastguard Worker            execWrapper(code, globals(), scope)
6980*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit(code)
6981*da0073e9SAndroid Build Coastguard Worker
6982*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cu.func(), scope['func']())
6983*da0073e9SAndroid Build Coastguard Worker
6984*da0073e9SAndroid Build Coastguard Worker        for scalar1, scalar2 in scalar_pairs:
6985*da0073e9SAndroid Build Coastguard Worker            for op in ops:
6986*da0073e9SAndroid Build Coastguard Worker                code = ops_template.format(op=op, scalar1=scalar1, scalar2=scalar2)
6987*da0073e9SAndroid Build Coastguard Worker                run_test(code)
6988*da0073e9SAndroid Build Coastguard Worker            for func in funcs:
6989*da0073e9SAndroid Build Coastguard Worker                code = funcs_template.format(func=func, scalar1=scalar1, scalar2=scalar2)
6990*da0073e9SAndroid Build Coastguard Worker                run_test(code)
6991*da0073e9SAndroid Build Coastguard Worker
6992*da0073e9SAndroid Build Coastguard Worker        # test Scalar overloads
6993*da0073e9SAndroid Build Coastguard Worker        for scalar1, scalar2 in scalar_pairs:
6994*da0073e9SAndroid Build Coastguard Worker            item1 = 'torch.tensor(' + scalar1 + ').item()'
6995*da0073e9SAndroid Build Coastguard Worker            item2 = 'torch.tensor(' + scalar2 + ').item()'
6996*da0073e9SAndroid Build Coastguard Worker            for op in ops:
6997*da0073e9SAndroid Build Coastguard Worker                code = ops_template.format(op=op, scalar1=item1, scalar2=scalar2)
6998*da0073e9SAndroid Build Coastguard Worker                run_test(code)
6999*da0073e9SAndroid Build Coastguard Worker                code = ops_template.format(op=op, scalar1=scalar1, scalar2=item2)
7000*da0073e9SAndroid Build Coastguard Worker                run_test(code)
7001*da0073e9SAndroid Build Coastguard Worker                code = ops_template.format(op=op, scalar1=item1, scalar2=item2)
7002*da0073e9SAndroid Build Coastguard Worker                run_test(code)
7003*da0073e9SAndroid Build Coastguard Worker            for func in funcs:
7004*da0073e9SAndroid Build Coastguard Worker                code = funcs_template.format(func=func, scalar1=item1, scalar2=scalar2)
7005*da0073e9SAndroid Build Coastguard Worker                run_test(code)
7006*da0073e9SAndroid Build Coastguard Worker                code = funcs_template.format(func=func, scalar1=scalar1, scalar2=item2)
7007*da0073e9SAndroid Build Coastguard Worker                run_test(code)
7008*da0073e9SAndroid Build Coastguard Worker                code = funcs_template.format(func=func, scalar1=item1, scalar2=item2)
7009*da0073e9SAndroid Build Coastguard Worker                run_test(code)
7010*da0073e9SAndroid Build Coastguard Worker
7011*da0073e9SAndroid Build Coastguard Worker    def test_number_abs(self):
7012*da0073e9SAndroid Build Coastguard Worker        def func1(x):
7013*da0073e9SAndroid Build Coastguard Worker            # type: (float) -> float
7014*da0073e9SAndroid Build Coastguard Worker            return abs(x)
7015*da0073e9SAndroid Build Coastguard Worker
7016*da0073e9SAndroid Build Coastguard Worker        def func2(x):
7017*da0073e9SAndroid Build Coastguard Worker            # type: (int) -> int
7018*da0073e9SAndroid Build Coastguard Worker            return abs(x)
7019*da0073e9SAndroid Build Coastguard Worker
7020*da0073e9SAndroid Build Coastguard Worker        def func3(x):
7021*da0073e9SAndroid Build Coastguard Worker            return abs(x)
7022*da0073e9SAndroid Build Coastguard Worker
7023*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func1, (-3.14,))
7024*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func1, (3.14,))
7025*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func2, (-10,))
7026*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func2, (10,))
7027*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func3, (torch.tensor([-5, -10, -20]),))
7028*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func3, (torch.tensor([5, 10, 20]),))
7029*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func3, (torch.tensor([-5, 10, -20]),))
7030*da0073e9SAndroid Build Coastguard Worker
7031*da0073e9SAndroid Build Coastguard Worker    def test_number_div(self):
7032*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(div_int_future(), torch.jit.script(div_int_future)())
7033*da0073e9SAndroid Build Coastguard Worker        self.checkScript(div_float_future, ())
7034*da0073e9SAndroid Build Coastguard Worker
7035*da0073e9SAndroid Build Coastguard Worker        self.checkScript(div_int_nofuture, ())
7036*da0073e9SAndroid Build Coastguard Worker        self.checkScript(div_float_nofuture, ())
7037*da0073e9SAndroid Build Coastguard Worker
7038*da0073e9SAndroid Build Coastguard Worker    # Testing bitwise shorthand aug assignment
7039*da0073e9SAndroid Build Coastguard Worker    def test_bool_augassign_bitwise_or(self):
7040*da0073e9SAndroid Build Coastguard Worker        def func(a: bool, b: bool) -> bool:
7041*da0073e9SAndroid Build Coastguard Worker            a |= b
7042*da0073e9SAndroid Build Coastguard Worker            return a
7043*da0073e9SAndroid Build Coastguard Worker
7044*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, (True, False), optimize=True)
7045*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, (True, True), optimize=True)
7046*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, (False, False), optimize=True)
7047*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, (False, True), optimize=True)
7048*da0073e9SAndroid Build Coastguard Worker
7049*da0073e9SAndroid Build Coastguard Worker    def test_bool_augassign_bitwise_and(self):
7050*da0073e9SAndroid Build Coastguard Worker        def func(a: bool, b: bool) -> bool:
7051*da0073e9SAndroid Build Coastguard Worker            a &= b
7052*da0073e9SAndroid Build Coastguard Worker            return a
7053*da0073e9SAndroid Build Coastguard Worker
7054*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, (True, False), optimize=True)
7055*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, (True, True), optimize=True)
7056*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, (False, False), optimize=True)
7057*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, (False, True), optimize=True)
7058*da0073e9SAndroid Build Coastguard Worker
7059*da0073e9SAndroid Build Coastguard Worker    def test_bool_augassign_bitwise_xor(self):
7060*da0073e9SAndroid Build Coastguard Worker        def func(a: bool, b: bool) -> bool:
7061*da0073e9SAndroid Build Coastguard Worker            a ^= b
7062*da0073e9SAndroid Build Coastguard Worker            return a
7063*da0073e9SAndroid Build Coastguard Worker
7064*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, (True, False), optimize=True)
7065*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, (True, True), optimize=True)
7066*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, (False, False), optimize=True)
7067*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, (False, True), optimize=True)
7068*da0073e9SAndroid Build Coastguard Worker
7069*da0073e9SAndroid Build Coastguard Worker    def test_number_augassign_bitwise_lshift(self):
7070*da0073e9SAndroid Build Coastguard Worker        def func() -> int:
7071*da0073e9SAndroid Build Coastguard Worker            z = 8
7072*da0073e9SAndroid Build Coastguard Worker            z <<= 2
7073*da0073e9SAndroid Build Coastguard Worker            return z
7074*da0073e9SAndroid Build Coastguard Worker
7075*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, (), optimize=True)
7076*da0073e9SAndroid Build Coastguard Worker
7077*da0073e9SAndroid Build Coastguard Worker    def test_number_augassign_bitwise_rshift(self):
7078*da0073e9SAndroid Build Coastguard Worker        def func() -> int:
7079*da0073e9SAndroid Build Coastguard Worker            z = 8
7080*da0073e9SAndroid Build Coastguard Worker            z >>= 2
7081*da0073e9SAndroid Build Coastguard Worker            return z
7082*da0073e9SAndroid Build Coastguard Worker
7083*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, (), optimize=True)
7084*da0073e9SAndroid Build Coastguard Worker
7085*da0073e9SAndroid Build Coastguard Worker    def test_number_augassign_bitwise_pow(self):
7086*da0073e9SAndroid Build Coastguard Worker        def func() -> float:
7087*da0073e9SAndroid Build Coastguard Worker            z = 8
7088*da0073e9SAndroid Build Coastguard Worker            z **= 2
7089*da0073e9SAndroid Build Coastguard Worker            return z
7090*da0073e9SAndroid Build Coastguard Worker
7091*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, (), optimize=True)
7092*da0073e9SAndroid Build Coastguard Worker
7093*da0073e9SAndroid Build Coastguard Worker    def test_number_augassign(self):
7094*da0073e9SAndroid Build Coastguard Worker        def func():
7095*da0073e9SAndroid Build Coastguard Worker            z = 1
7096*da0073e9SAndroid Build Coastguard Worker            z += 2
7097*da0073e9SAndroid Build Coastguard Worker            return z
7098*da0073e9SAndroid Build Coastguard Worker
7099*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func, (), optimize=True)
7100*da0073e9SAndroid Build Coastguard Worker
7101*da0073e9SAndroid Build Coastguard Worker    def test_nested_select_assign(self):
7102*da0073e9SAndroid Build Coastguard Worker        class SubSubModule(torch.nn.Module):
7103*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
7104*da0073e9SAndroid Build Coastguard Worker                super().__init__()
7105*da0073e9SAndroid Build Coastguard Worker                self.abc = 11
7106*da0073e9SAndroid Build Coastguard Worker
7107*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
7108*da0073e9SAndroid Build Coastguard Worker                return self.abc
7109*da0073e9SAndroid Build Coastguard Worker
7110*da0073e9SAndroid Build Coastguard Worker        class SubModule(torch.nn.Module):
7111*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
7112*da0073e9SAndroid Build Coastguard Worker                super().__init__()
7113*da0073e9SAndroid Build Coastguard Worker                self.a = 11
7114*da0073e9SAndroid Build Coastguard Worker                self.nested = SubSubModule()
7115*da0073e9SAndroid Build Coastguard Worker
7116*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
7117*da0073e9SAndroid Build Coastguard Worker                return self.a
7118*da0073e9SAndroid Build Coastguard Worker
7119*da0073e9SAndroid Build Coastguard Worker        class TestModule(torch.nn.Module):
7120*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
7121*da0073e9SAndroid Build Coastguard Worker                super().__init__()
7122*da0073e9SAndroid Build Coastguard Worker                self.sub = SubModule()
7123*da0073e9SAndroid Build Coastguard Worker                self.hi = 1
7124*da0073e9SAndroid Build Coastguard Worker
7125*da0073e9SAndroid Build Coastguard Worker            def forward(self):
7126*da0073e9SAndroid Build Coastguard Worker                self.hi = 5
7127*da0073e9SAndroid Build Coastguard Worker                self.sub.a = 1
7128*da0073e9SAndroid Build Coastguard Worker                self.sub.nested.abc = 5
7129*da0073e9SAndroid Build Coastguard Worker                return self.sub.a * 20 + self.sub.nested.abc * 3 + self.hi
7130*da0073e9SAndroid Build Coastguard Worker
7131*da0073e9SAndroid Build Coastguard Worker        self.checkModule(TestModule(), ())
7132*da0073e9SAndroid Build Coastguard Worker
7133*da0073e9SAndroid Build Coastguard Worker    def test_number_neg(self):
7134*da0073e9SAndroid Build Coastguard Worker        # int -> int
7135*da0073e9SAndroid Build Coastguard Worker        def func1():
7136*da0073e9SAndroid Build Coastguard Worker            return -8
7137*da0073e9SAndroid Build Coastguard Worker
7138*da0073e9SAndroid Build Coastguard Worker        # float -> float
7139*da0073e9SAndroid Build Coastguard Worker        def func2():
7140*da0073e9SAndroid Build Coastguard Worker            return -3.14
7141*da0073e9SAndroid Build Coastguard Worker
7142*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func1, (), optimize=True)
7143*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func2, (), optimize=True)
7144*da0073e9SAndroid Build Coastguard Worker
7145*da0073e9SAndroid Build Coastguard Worker    def test_compare_two_bool_inputs(self):
7146*da0073e9SAndroid Build Coastguard Worker        def compare_eq(a: bool, b: bool):
7147*da0073e9SAndroid Build Coastguard Worker            return a == b
7148*da0073e9SAndroid Build Coastguard Worker
7149*da0073e9SAndroid Build Coastguard Worker        def compare_ne(a: bool, b: bool):
7150*da0073e9SAndroid Build Coastguard Worker            return a != b
7151*da0073e9SAndroid Build Coastguard Worker
7152*da0073e9SAndroid Build Coastguard Worker        scripted_fn_eq = torch.jit.script(compare_eq)
7153*da0073e9SAndroid Build Coastguard Worker        scripted_fn_ne = torch.jit.script(compare_ne)
7154*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_fn_eq(True, False), compare_eq(True, False))
7155*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_fn_eq(False, True), compare_eq(False, True))
7156*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_fn_eq(True, True), compare_eq(True, True))
7157*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_fn_eq(False, False), compare_eq(False, False))
7158*da0073e9SAndroid Build Coastguard Worker
7159*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_fn_ne(True, False), compare_ne(True, False))
7160*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_fn_ne(False, True), compare_ne(False, True))
7161*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_fn_ne(True, True), compare_ne(True, True))
7162*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_fn_ne(False, False), compare_ne(False, False))
7163*da0073e9SAndroid Build Coastguard Worker
7164*da0073e9SAndroid Build Coastguard Worker
7165*da0073e9SAndroid Build Coastguard Worker    def _test_tensor_number_math(self, device='cpu'):
7166*da0073e9SAndroid Build Coastguard Worker        template = dedent('''
7167*da0073e9SAndroid Build Coastguard Worker        def func(t):
7168*da0073e9SAndroid Build Coastguard Worker            return {lhs} {op} {rhs}
7169*da0073e9SAndroid Build Coastguard Worker        ''')
7170*da0073e9SAndroid Build Coastguard Worker
7171*da0073e9SAndroid Build Coastguard Worker        def test(op, tensor, const, swap_args, template=template):
7172*da0073e9SAndroid Build Coastguard Worker            args = ('t', const)
7173*da0073e9SAndroid Build Coastguard Worker            if swap_args:
7174*da0073e9SAndroid Build Coastguard Worker                args = (const, 't')
7175*da0073e9SAndroid Build Coastguard Worker
7176*da0073e9SAndroid Build Coastguard Worker            code = template.format(lhs=args[0], rhs=args[1], op=op)
7177*da0073e9SAndroid Build Coastguard Worker            scope = {}
7178*da0073e9SAndroid Build Coastguard Worker            execWrapper(code, globals(), scope)
7179*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit(code)
7180*da0073e9SAndroid Build Coastguard Worker            message = f'with code `{args[0]} {op} {args[1]}` and t={tensor}'
7181*da0073e9SAndroid Build Coastguard Worker            res1 = cu.func(tensor)
7182*da0073e9SAndroid Build Coastguard Worker            res2 = scope['func'](tensor)
7183*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res1, res2, msg=message + "\nres1=" + str(res1) + "\nres2=" + str(res2))
7184*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res1.dtype, res2.dtype, msg=message + "\nres1=" + str(res1) + "\nres2=" + str(res2))
7185*da0073e9SAndroid Build Coastguard Worker
7186*da0073e9SAndroid Build Coastguard Worker        var_int = [2, -2]
7187*da0073e9SAndroid Build Coastguard Worker        var_float = [1.4321, -1.2]
7188*da0073e9SAndroid Build Coastguard Worker
7189*da0073e9SAndroid Build Coastguard Worker        ops = ['+', '-', '*', '%', '<', '<=', '>', '>=', '==', '!=', '/']
7190*da0073e9SAndroid Build Coastguard Worker
7191*da0073e9SAndroid Build Coastguard Worker        float_tensor = torch.randn(5, 5, device=device)
7192*da0073e9SAndroid Build Coastguard Worker        double_tensor = torch.randn(5, 5, dtype=torch.double, device=device)
7193*da0073e9SAndroid Build Coastguard Worker        long_tensor = torch.randint(-5, 5, (5, 5), dtype=torch.long, device=device)
7194*da0073e9SAndroid Build Coastguard Worker        long_tensor[long_tensor == 0] = 2
7195*da0073e9SAndroid Build Coastguard Worker
7196*da0073e9SAndroid Build Coastguard Worker        tensors = [float_tensor, double_tensor, long_tensor]
7197*da0073e9SAndroid Build Coastguard Worker        consts = var_int + var_float
7198*da0073e9SAndroid Build Coastguard Worker
7199*da0073e9SAndroid Build Coastguard Worker        for op, tensor, const, swap_args in product(ops, tensors, consts, [True, False]):
7200*da0073e9SAndroid Build Coastguard Worker            # FIXME: things like 2 / long_tensor are not implemented correctly
7201*da0073e9SAndroid Build Coastguard Worker            # Look in torch/_tensor.py to see how pytorch implements it.
7202*da0073e9SAndroid Build Coastguard Worker            if op == '/' and tensor.data_ptr() == long_tensor.data_ptr():
7203*da0073e9SAndroid Build Coastguard Worker                continue
7204*da0073e9SAndroid Build Coastguard Worker
7205*da0073e9SAndroid Build Coastguard Worker            # % operator does not take: const % tensor
7206*da0073e9SAndroid Build Coastguard Worker            if op == '%' and swap_args is True:
7207*da0073e9SAndroid Build Coastguard Worker                continue
7208*da0073e9SAndroid Build Coastguard Worker
7209*da0073e9SAndroid Build Coastguard Worker            test(op, tensor, const, swap_args)
7210*da0073e9SAndroid Build Coastguard Worker
7211*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
7212*da0073e9SAndroid Build Coastguard Worker    def test_tensor_number_math(self):
7213*da0073e9SAndroid Build Coastguard Worker        self._test_tensor_number_math()
7214*da0073e9SAndroid Build Coastguard Worker
7215*da0073e9SAndroid Build Coastguard Worker    def test_torch_tensor_bad_input(self):
7216*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "must be of ints, floats, "
7217*da0073e9SAndroid Build Coastguard Worker                                    "or bools, got None"):
7218*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
7219*da0073e9SAndroid Build Coastguard Worker            def test():
7220*da0073e9SAndroid Build Coastguard Worker                return torch.tensor([None])
7221*da0073e9SAndroid Build Coastguard Worker            test()
7222*da0073e9SAndroid Build Coastguard Worker
7223*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"Empty lists default to List\[Tensor\]"):
7224*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
7225*da0073e9SAndroid Build Coastguard Worker            def tmp():
7226*da0073e9SAndroid Build Coastguard Worker                return torch.tensor([])
7227*da0073e9SAndroid Build Coastguard Worker            tmp()
7228*da0073e9SAndroid Build Coastguard Worker
7229*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
7230*da0073e9SAndroid Build Coastguard Worker        def foo():
7231*da0073e9SAndroid Build Coastguard Worker            return torch.tensor([[2, 2], [1]])
7232*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Expected sequence of length"):
7233*da0073e9SAndroid Build Coastguard Worker            foo()
7234*da0073e9SAndroid Build Coastguard Worker
7235*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
7236*da0073e9SAndroid Build Coastguard Worker    def test_torch_tensor_as_tensor_empty_list(self):
7237*da0073e9SAndroid Build Coastguard Worker        tensor_template = dedent('''
7238*da0073e9SAndroid Build Coastguard Worker        def func():
7239*da0073e9SAndroid Build Coastguard Worker            empty_list = torch.jit.annotate(List[int], [])
7240*da0073e9SAndroid Build Coastguard Worker            ten1 = torch.{tensor_op}({input})
7241*da0073e9SAndroid Build Coastguard Worker            return ten1
7242*da0073e9SAndroid Build Coastguard Worker        ''')
7243*da0073e9SAndroid Build Coastguard Worker        ops = ['tensor', 'as_tensor']
7244*da0073e9SAndroid Build Coastguard Worker        inputs = ['empty_list', '[empty_list, empty_list]', '[[[empty_list]]]']
7245*da0073e9SAndroid Build Coastguard Worker
7246*da0073e9SAndroid Build Coastguard Worker        for op in ops:
7247*da0073e9SAndroid Build Coastguard Worker            for inp in inputs:
7248*da0073e9SAndroid Build Coastguard Worker                code = tensor_template.format(tensor_op=op, input=inp)
7249*da0073e9SAndroid Build Coastguard Worker                scope = {}
7250*da0073e9SAndroid Build Coastguard Worker                exec(code, globals(), scope)
7251*da0073e9SAndroid Build Coastguard Worker                cu = torch.jit.CompilationUnit(code)
7252*da0073e9SAndroid Build Coastguard Worker                t1 = cu.func()
7253*da0073e9SAndroid Build Coastguard Worker                t2 = scope['func']()
7254*da0073e9SAndroid Build Coastguard Worker                if inp == 'empty_list':
7255*da0073e9SAndroid Build Coastguard Worker                    # torchscript returns int tensor, python returns float tensor
7256*da0073e9SAndroid Build Coastguard Worker                    self.assertNotEqual(t1.dtype, t2.dtype)
7257*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(t1, t2, exact_dtype=False)
7258*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(t1.device, t2.device)
7259*da0073e9SAndroid Build Coastguard Worker
7260*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Simple Executor doesn't have any shapes to propagate")
7261*da0073e9SAndroid Build Coastguard Worker    def test_tensor_as_tensor_shape_prop(self):
7262*da0073e9SAndroid Build Coastguard Worker        tensor_template = dedent('''
7263*da0073e9SAndroid Build Coastguard Worker        def func():
7264*da0073e9SAndroid Build Coastguard Worker            return torch.{tensor_op}({input})
7265*da0073e9SAndroid Build Coastguard Worker        ''')
7266*da0073e9SAndroid Build Coastguard Worker        ops = ['tensor', 'as_tensor']
7267*da0073e9SAndroid Build Coastguard Worker        inputs = ['[1]', '[False]', '[2.5]', '0.5', '1', 'False', '[[1]]', 'torch.jit.annotate(List[List[int]], [])']
7268*da0073e9SAndroid Build Coastguard Worker        expected_shape = ["Long(*, device=cpu)", "Bool(*, device=cpu)",
7269*da0073e9SAndroid Build Coastguard Worker                          "Float(*, device=cpu)", "Float(device=cpu)",
7270*da0073e9SAndroid Build Coastguard Worker                          "Long(device=cpu)", "Bool(device=cpu)", "Long(*, *, device=cpu)"]
7271*da0073e9SAndroid Build Coastguard Worker
7272*da0073e9SAndroid Build Coastguard Worker        for op in ops:
7273*da0073e9SAndroid Build Coastguard Worker            for inp, expect in zip(inputs, expected_shape):
7274*da0073e9SAndroid Build Coastguard Worker                code = tensor_template.format(tensor_op=op, input=inp)
7275*da0073e9SAndroid Build Coastguard Worker                scope = {}
7276*da0073e9SAndroid Build Coastguard Worker                exec(code, globals(), scope)
7277*da0073e9SAndroid Build Coastguard Worker                cu = torch.jit.CompilationUnit(code)
7278*da0073e9SAndroid Build Coastguard Worker                torch._C._jit_pass_complete_shape_analysis(cu.func.graph, (), False)
7279*da0073e9SAndroid Build Coastguard Worker                FileCheck().check(expect).check(f"aten::{op}").run(cu.func.graph)
7280*da0073e9SAndroid Build Coastguard Worker
7281*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
7282*da0073e9SAndroid Build Coastguard Worker        def test_dtype(inp_dtype: torch.dtype):
7283*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor(1.0, dtype=torch.float, requires_grad=True)
7284*da0073e9SAndroid Build Coastguard Worker            return a, torch.tensor(1.0, dtype=inp_dtype)
7285*da0073e9SAndroid Build Coastguard Worker
7286*da0073e9SAndroid Build Coastguard Worker        if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
7287*da0073e9SAndroid Build Coastguard Worker            g = test_dtype.graph_for(5, profile_and_replay=True)
7288*da0073e9SAndroid Build Coastguard Worker            # both should have completed shapes
7289*da0073e9SAndroid Build Coastguard Worker            FileCheck().check("Tensor = aten::tensor").check("Float(device=cpu) = prim::BailOut") \
7290*da0073e9SAndroid Build Coastguard Worker                       .check("Tensor = aten::tensor").check("Half(device=cpu) = prim::BailOut").run(g)
7291*da0073e9SAndroid Build Coastguard Worker        else:
7292*da0073e9SAndroid Build Coastguard Worker            g = test_dtype.graph_for(5)
7293*da0073e9SAndroid Build Coastguard Worker            # first should have type set second should not
7294*da0073e9SAndroid Build Coastguard Worker            FileCheck().check("Float(requires_grad=1, device=cpu) = aten::tensor") \
7295*da0073e9SAndroid Build Coastguard Worker                       .check("Tensor(requires_grad=0) = aten::tensor").run(g)
7296*da0073e9SAndroid Build Coastguard Worker
7297*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
7298*da0073e9SAndroid Build Coastguard Worker        def test_as_tensor_tensor_input(input):
7299*da0073e9SAndroid Build Coastguard Worker            a = torch.as_tensor(input, dtype=input.dtype)
7300*da0073e9SAndroid Build Coastguard Worker            return a, torch.as_tensor(input, dtype=torch.float)
7301*da0073e9SAndroid Build Coastguard Worker
7302*da0073e9SAndroid Build Coastguard Worker        if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
7303*da0073e9SAndroid Build Coastguard Worker            g = test_as_tensor_tensor_input.graph_for(torch.ones(3, 4), profile_and_replay=True)
7304*da0073e9SAndroid Build Coastguard Worker            FileCheck().check("Tensor = aten::as_tensor").check("Float(3, 4) = prim::BailOut") \
7305*da0073e9SAndroid Build Coastguard Worker                       .check("Tensor = aten::as_tensor").check("Float(3, 4) = prim::BailOut").run(g)
7306*da0073e9SAndroid Build Coastguard Worker        else:
7307*da0073e9SAndroid Build Coastguard Worker            g = test_as_tensor_tensor_input.graph_for(torch.ones(3, 4))
7308*da0073e9SAndroid Build Coastguard Worker            FileCheck().check("Tensor = aten::as_tensor").check("Float(*, *, requires_grad=0, device=cpu) = aten::as_tensor").run(g)
7309*da0073e9SAndroid Build Coastguard Worker
7310*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "testing legacy behavior")
7311*da0073e9SAndroid Build Coastguard Worker    def test_tensor_requires_grad(self):
7312*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
7313*da0073e9SAndroid Build Coastguard Worker        def test(b):
7314*da0073e9SAndroid Build Coastguard Worker            # type: (bool) -> Tuple[Tensor, Tensor, Tensor]
7315*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor(1., requires_grad=b)
7316*da0073e9SAndroid Build Coastguard Worker            b = torch.tensor(1., requires_grad=True)
7317*da0073e9SAndroid Build Coastguard Worker            c = torch.tensor(1., requires_grad=False)
7318*da0073e9SAndroid Build Coastguard Worker            return a, b, c
7319*da0073e9SAndroid Build Coastguard Worker
7320*da0073e9SAndroid Build Coastguard Worker        g = test.graph_for(True)
7321*da0073e9SAndroid Build Coastguard Worker        out = next(g.outputs())
7322*da0073e9SAndroid Build Coastguard Worker        out_inp = list(out.node().inputs())
7323*da0073e9SAndroid Build Coastguard Worker
7324*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(out_inp[0].requires_grad())
7325*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(out_inp[1].requires_grad())
7326*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(out_inp[2].requires_grad())
7327*da0073e9SAndroid Build Coastguard Worker
7328*da0073e9SAndroid Build Coastguard Worker    def test_grad_from_script(self):
7329*da0073e9SAndroid Build Coastguard Worker        def test():
7330*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor(2.5, requires_grad=True)
7331*da0073e9SAndroid Build Coastguard Worker            b = a * 2
7332*da0073e9SAndroid Build Coastguard Worker            return a, b
7333*da0073e9SAndroid Build Coastguard Worker
7334*da0073e9SAndroid Build Coastguard Worker        a, b = test()
7335*da0073e9SAndroid Build Coastguard Worker        b.backward()
7336*da0073e9SAndroid Build Coastguard Worker
7337*da0073e9SAndroid Build Coastguard Worker        a_script, b_script = torch.jit.script(test)()
7338*da0073e9SAndroid Build Coastguard Worker        b_script.backward()
7339*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.grad, a_script.grad)
7340*da0073e9SAndroid Build Coastguard Worker
7341*da0073e9SAndroid Build Coastguard Worker    def test_torch_tensor_as_tensor(self):
7342*da0073e9SAndroid Build Coastguard Worker        tensor_template = dedent('''
7343*da0073e9SAndroid Build Coastguard Worker        def func():
7344*da0073e9SAndroid Build Coastguard Worker            li = {list_create}
7345*da0073e9SAndroid Build Coastguard Worker            ten1 = torch.{tensor_op}(li {options})
7346*da0073e9SAndroid Build Coastguard Worker            return ten1
7347*da0073e9SAndroid Build Coastguard Worker        ''')
7348*da0073e9SAndroid Build Coastguard Worker
7349*da0073e9SAndroid Build Coastguard Worker        lists = ["2.5", "4", "True", "False", "[2]", "[-.5]", "[False, True, False]", "[2, 2]", "(1, 1)",
7350*da0073e9SAndroid Build Coastguard Worker                 "torch.jit.annotate(List[List[int]], [])",
7351*da0073e9SAndroid Build Coastguard Worker                 "torch.jit.annotate(List[int], [])", "[2.5, 2.5]", "[[2], [2]]", "[[-.5], [2.2]]", "[[False], [True]]"]
7352*da0073e9SAndroid Build Coastguard Worker
7353*da0073e9SAndroid Build Coastguard Worker        dtypes = ["", ", dtype=torch.float", ", dtype=torch.double", ", dtype=torch.half",
7354*da0073e9SAndroid Build Coastguard Worker                  ", dtype=torch.uint8", ", dtype=torch.int8", ", dtype=torch.short",
7355*da0073e9SAndroid Build Coastguard Worker                  ", dtype=torch.int", ", dtype=torch.long", ", dtype=torch.cfloat",
7356*da0073e9SAndroid Build Coastguard Worker                  ", dtype=torch.cdouble"]
7357*da0073e9SAndroid Build Coastguard Worker
7358*da0073e9SAndroid Build Coastguard Worker        ops = ['tensor', 'as_tensor']
7359*da0073e9SAndroid Build Coastguard Worker        devices = ['', ", device='cpu'"]
7360*da0073e9SAndroid Build Coastguard Worker        if RUN_CUDA:
7361*da0073e9SAndroid Build Coastguard Worker            devices.append(", device='cuda'")
7362*da0073e9SAndroid Build Coastguard Worker
7363*da0073e9SAndroid Build Coastguard Worker        option_pairs = [dtype + device for dtype in dtypes for device in devices]
7364*da0073e9SAndroid Build Coastguard Worker        for op in ops:
7365*da0073e9SAndroid Build Coastguard Worker            for li in lists:
7366*da0073e9SAndroid Build Coastguard Worker                for option in option_pairs:
7367*da0073e9SAndroid Build Coastguard Worker                    # tensor from empty list is type float in python and annotated type in torchscript
7368*da0073e9SAndroid Build Coastguard Worker                    if "annotate" in li and "dtype" not in option:
7369*da0073e9SAndroid Build Coastguard Worker                        continue
7370*da0073e9SAndroid Build Coastguard Worker                    # Skip unsigned tensor initializaton for signed values on 3.10
7371*da0073e9SAndroid Build Coastguard Worker                    if sys.version_info[:2] >= (3, 10) and "torch.uint8" in option and "-" in li:
7372*da0073e9SAndroid Build Coastguard Worker                        continue
7373*da0073e9SAndroid Build Coastguard Worker                    code = tensor_template.format(list_create=li, tensor_op=op, options=option)
7374*da0073e9SAndroid Build Coastguard Worker                    scope = {}
7375*da0073e9SAndroid Build Coastguard Worker                    exec(code, globals(), scope)
7376*da0073e9SAndroid Build Coastguard Worker                    cu = torch.jit.CompilationUnit(code)
7377*da0073e9SAndroid Build Coastguard Worker                    t1 = cu.func()
7378*da0073e9SAndroid Build Coastguard Worker                    t2 = scope['func']()
7379*da0073e9SAndroid Build Coastguard Worker                    if t1.dtype == torch.float16:  # equality NYI for half tensor
7380*da0073e9SAndroid Build Coastguard Worker                        self.assertTrue(str(t1) == str(t2))
7381*da0073e9SAndroid Build Coastguard Worker                    else:
7382*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(t1, t2)
7383*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(t1.dtype, t2.dtype)
7384*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(t1.device, t2.device)
7385*da0073e9SAndroid Build Coastguard Worker
7386*da0073e9SAndroid Build Coastguard Worker        def test_as_tensor_tensor_input(input):
7387*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor) -> Tuple[Tensor, Tensor, Tensor]
7388*da0073e9SAndroid Build Coastguard Worker            return torch.as_tensor(input, dtype=torch.cfloat), torch.as_tensor(input, dtype=torch.float), \
7389*da0073e9SAndroid Build Coastguard Worker                torch.as_tensor(input, dtype=torch.int32)
7390*da0073e9SAndroid Build Coastguard Worker
7391*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn(3, 4, dtype=torch.cfloat)
7392*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_as_tensor_tensor_input, (inp,))
7393*da0073e9SAndroid Build Coastguard Worker
7394*da0073e9SAndroid Build Coastguard Worker    def test_torch_tensor_dtype(self):
7395*da0073e9SAndroid Build Coastguard Worker        def foo(s: float):
7396*da0073e9SAndroid Build Coastguard Worker            return torch.tensor(s), torch.tensor([s, s])
7397*da0073e9SAndroid Build Coastguard Worker
7398*da0073e9SAndroid Build Coastguard Worker        # need to clear function cache so we re run shape analysis
7399*da0073e9SAndroid Build Coastguard Worker        with set_default_dtype(torch.double):
7400*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.jit.script(foo)(1.), foo(1.), exact_dtype=True)
7401*da0073e9SAndroid Build Coastguard Worker            if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
7402*da0073e9SAndroid Build Coastguard Worker                FileCheck().check("Double").check_same("aten::tensor").run(torch.jit.last_executed_optimized_graph())
7403*da0073e9SAndroid Build Coastguard Worker        with set_default_dtype(torch.float):
7404*da0073e9SAndroid Build Coastguard Worker            del torch.jit._state._jit_caching_layer[foo]
7405*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.jit.script(foo)(1.), foo(1.), exact_dtype=True)
7406*da0073e9SAndroid Build Coastguard Worker            if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
7407*da0073e9SAndroid Build Coastguard Worker                FileCheck().check("Float").check_same("aten::tensor").run(torch.jit.last_executed_optimized_graph())
7408*da0073e9SAndroid Build Coastguard Worker        with set_default_dtype(torch.half):
7409*da0073e9SAndroid Build Coastguard Worker            del torch.jit._state._jit_caching_layer[foo]
7410*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.jit.script(foo)(1.), foo(1.), exact_dtype=True)
7411*da0073e9SAndroid Build Coastguard Worker            if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
7412*da0073e9SAndroid Build Coastguard Worker                FileCheck().check("Half").check_same("aten::tensor").run(torch.jit.last_executed_optimized_graph())
7413*da0073e9SAndroid Build Coastguard Worker
7414*da0073e9SAndroid Build Coastguard Worker    def test_shape_analysis_grad_property(self):
7415*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
7416*da0073e9SAndroid Build Coastguard Worker        def foo(x):
7417*da0073e9SAndroid Build Coastguard Worker            return torch.sub(x, torch.tanh(x))
7418*da0073e9SAndroid Build Coastguard Worker
7419*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_complete_shape_analysis(foo.graph, (torch.tensor([0.39]),), False)
7420*da0073e9SAndroid Build Coastguard Worker
7421*da0073e9SAndroid Build Coastguard Worker        # requires_grad property shouldn't be accidentally set by shape analysis
7422*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(foo.graph.findNode("aten::sub").output().requiresGrad() is None)
7423*da0073e9SAndroid Build Coastguard Worker
7424*da0073e9SAndroid Build Coastguard Worker    def test_empty_like_memory_format_bc(self):
7425*da0073e9SAndroid Build Coastguard Worker        def f(x):
7426*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor) -> Tensor
7427*da0073e9SAndroid Build Coastguard Worker            return torch.zeros_like(x, memory_format=None)
7428*da0073e9SAndroid Build Coastguard Worker
7429*da0073e9SAndroid Build Coastguard Worker        scripted_f = torch.jit.script(f)
7430*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
7431*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_f(x), f(x))
7432*da0073e9SAndroid Build Coastguard Worker
7433*da0073e9SAndroid Build Coastguard Worker    def test_multiline_string_dedents(self):
7434*da0073e9SAndroid Build Coastguard Worker        def foo() -> None:
7435*da0073e9SAndroid Build Coastguard Worker            multiline_string_dedent_1 = """
7436*da0073e9SAndroid Build Coastguard WorkerThis is a string dedent """
7437*da0073e9SAndroid Build Coastguard Worker            multiline_string_dedent_2 = """ This is a
7438*da0073e9SAndroid Build Coastguard Worker  string dedent """
7439*da0073e9SAndroid Build Coastguard Worker            multiline_string_dedent_3 = """
7440*da0073e9SAndroid Build Coastguard Worker            This is a string
7441*da0073e9SAndroid Build Coastguard Workerdedent """
7442*da0073e9SAndroid Build Coastguard Worker            multiline_string_dedent_4 = """ This is a string dedent """
7443*da0073e9SAndroid Build Coastguard Worker
7444*da0073e9SAndroid Build Coastguard Worker        scripted_foo = torch.jit.script(foo)
7445*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_foo(), foo())
7446*da0073e9SAndroid Build Coastguard Worker
7447*da0073e9SAndroid Build Coastguard Worker    def test_class_with_comment_at_lower_indentation(self):
7448*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
7449*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
7450*da0073e9SAndroid Build Coastguard Worker                x = torch.neg(x)
7451*da0073e9SAndroid Build Coastguard Worker        # This comment is at the wrong indent
7452*da0073e9SAndroid Build Coastguard Worker                return x
7453*da0073e9SAndroid Build Coastguard Worker
7454*da0073e9SAndroid Build Coastguard Worker        torch.jit.script(Foo())
7455*da0073e9SAndroid Build Coastguard Worker
7456*da0073e9SAndroid Build Coastguard Worker    # adapted from test in test_torch
7457*da0073e9SAndroid Build Coastguard Worker    def test_tensor_to(self):
7458*da0073e9SAndroid Build Coastguard Worker        template = dedent('''
7459*da0073e9SAndroid Build Coastguard Worker        def func(t):
7460*da0073e9SAndroid Build Coastguard Worker            cuda = "{cuda}"
7461*da0073e9SAndroid Build Coastguard Worker            device = "{device}"
7462*da0073e9SAndroid Build Coastguard Worker            non_blocking = {non_blocking}
7463*da0073e9SAndroid Build Coastguard Worker            return {to_str}
7464*da0073e9SAndroid Build Coastguard Worker        ''')
7465*da0073e9SAndroid Build Coastguard Worker
7466*da0073e9SAndroid Build Coastguard Worker        def s(t, to_str, non_blocking=None, device=None, cuda=None):
7467*da0073e9SAndroid Build Coastguard Worker            device = device if device is not None else str(t.device)
7468*da0073e9SAndroid Build Coastguard Worker            non_blocking = non_blocking if non_blocking is not None else False
7469*da0073e9SAndroid Build Coastguard Worker            cuda = "cuda" if cuda is None else cuda
7470*da0073e9SAndroid Build Coastguard Worker            code = template.format(to_str=to_str, device=device, non_blocking=non_blocking, cuda=cuda)
7471*da0073e9SAndroid Build Coastguard Worker            scope = {}
7472*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit(code)
7473*da0073e9SAndroid Build Coastguard Worker            return cu.func(t, profile_and_replay=True)
7474*da0073e9SAndroid Build Coastguard Worker
7475*da0073e9SAndroid Build Coastguard Worker        def test_copy_behavior(t, non_blocking=False):
7476*da0073e9SAndroid Build Coastguard Worker            self.assertIs(t, s(t, 't.to(t, non_blocking=non_blocking)', non_blocking))
7477*da0073e9SAndroid Build Coastguard Worker            self.assertIs(t, s(t, 't.to(t.dtype, non_blocking=non_blocking)', non_blocking))
7478*da0073e9SAndroid Build Coastguard Worker            self.assertIs(t, s(t, 't.to(torch.empty_like(t), non_blocking=non_blocking)', non_blocking))
7479*da0073e9SAndroid Build Coastguard Worker            self.assertIsNot(t, s(t, 't.to(t, non_blocking=non_blocking, copy=True)', non_blocking))
7480*da0073e9SAndroid Build Coastguard Worker            self.assertIsNot(t, s(t, 't.to(t.dtype, non_blocking=non_blocking, copy=True)', non_blocking))
7481*da0073e9SAndroid Build Coastguard Worker            self.assertIsNot(t, s(t, 't.to(torch.empty_like(t), non_blocking=non_blocking, copy=True)', non_blocking))
7482*da0073e9SAndroid Build Coastguard Worker
7483*da0073e9SAndroid Build Coastguard Worker            devices = [t.device]
7484*da0073e9SAndroid Build Coastguard Worker            if t.device.type == 'cuda':
7485*da0073e9SAndroid Build Coastguard Worker                if t.device.index == -1:
7486*da0073e9SAndroid Build Coastguard Worker                    devices.append(f'cuda:{torch.cuda.current_device()}')
7487*da0073e9SAndroid Build Coastguard Worker                elif t.device.index == torch.cuda.current_device():
7488*da0073e9SAndroid Build Coastguard Worker                    devices.append('cuda')
7489*da0073e9SAndroid Build Coastguard Worker            for device in devices:
7490*da0073e9SAndroid Build Coastguard Worker                self.assertIs(t, s(t, 't.to(device, non_blocking=non_blocking)', non_blocking, device))
7491*da0073e9SAndroid Build Coastguard Worker                self.assertIs(t, s(t, 't.to(device, t.dtype, non_blocking=non_blocking)', non_blocking, device))
7492*da0073e9SAndroid Build Coastguard Worker                self.assertIsNot(t, s(t, 't.to(device, non_blocking=non_blocking, copy=True)', non_blocking, device))
7493*da0073e9SAndroid Build Coastguard Worker                self.assertIsNot(t, s(t, 't.to(device, t.dtype, non_blocking=non_blocking, copy=True)',
7494*da0073e9SAndroid Build Coastguard Worker                                      non_blocking, device))
7495*da0073e9SAndroid Build Coastguard Worker
7496*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor(5)
7497*da0073e9SAndroid Build Coastguard Worker        test_copy_behavior(t)
7498*da0073e9SAndroid Build Coastguard Worker
7499*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.device, s(t, "t.to('cpu')").device)
7500*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.device, s(t, "t.to('cpu', dtype=torch.float32)").device)
7501*da0073e9SAndroid Build Coastguard Worker        self.assertIs(torch.float32, s(t, "t.to('cpu', dtype=torch.float32)").dtype)
7502*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.device, s(t, "t.to(torch.float32)").device)
7503*da0073e9SAndroid Build Coastguard Worker        self.assertIs(torch.float32, s(t, "t.to(dtype=torch.float32)").dtype)
7504*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.data_ptr(), s(t, "t.to('cpu')").data_ptr())
7505*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.data_ptr(), s(t, "t.to(dtype=t.dtype, device=t.device, copy=False)").data_ptr())
7506*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.data_ptr(), s(t, "t.to('cpu', copy=False)").data_ptr())
7507*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(t.data_ptr(), s(t, "t.to('cpu', copy=True)").data_ptr())
7508*da0073e9SAndroid Build Coastguard Worker
7509*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(5)
7510*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
7511*da0073e9SAndroid Build Coastguard Worker            for non_blocking in [True, False]:
7512*da0073e9SAndroid Build Coastguard Worker                for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
7513*da0073e9SAndroid Build Coastguard Worker                    b = torch.tensor(5., device=cuda)
7514*da0073e9SAndroid Build Coastguard Worker                    test_copy_behavior(b, non_blocking)
7515*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(b.device, s(b, "t.to(cuda, non_blocking=non_blocking).device", cuda=cuda))
7516*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(a.device, s(b, "t.to('cpu', non_blocking=non_blocking).device"))
7517*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(b.device, s(b, "t.to(cuda, non_blocking=non_blocking).device", cuda=cuda))
7518*da0073e9SAndroid Build Coastguard Worker                    self.assertIs(torch.int32, s(b, "t.to('cpu', dtype=torch.int32, non_blocking=non_blocking)").dtype)
7519*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(a.device, s(b, "t.to('cpu', dtype=torch.int32, non_blocking=non_blocking)").device)
7520*da0073e9SAndroid Build Coastguard Worker                    self.assertIs(torch.int32, s(b, "t.to(dtype=torch.int32)").dtype)
7521*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(b.device, s(b, "t.to(dtype=torch.int32)").device)
7522*da0073e9SAndroid Build Coastguard Worker
7523*da0073e9SAndroid Build Coastguard Worker        # Test AD: aten::to(Tensor self, int dtype, bool non_blocking, bool copy) -> Tensor
7524*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor(5).float().requires_grad_()
7525*da0073e9SAndroid Build Coastguard Worker        out_ref = t.to(torch.float32)
7526*da0073e9SAndroid Build Coastguard Worker        out = s(t, "t.to(torch.float32)")
7527*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ref, out)
7528*da0073e9SAndroid Build Coastguard Worker
7529*da0073e9SAndroid Build Coastguard Worker        grad_ref = torch.autograd.grad(out_ref.sum(), t)
7530*da0073e9SAndroid Build Coastguard Worker        grad = torch.autograd.grad(out.sum(), t)
7531*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(grad_ref, grad)
7532*da0073e9SAndroid Build Coastguard Worker
7533*da0073e9SAndroid Build Coastguard Worker        # Test AD: aten::to(Tensor self, Device? device, int? dtype, bool non_blocking, bool copy) -> Tensor
7534*da0073e9SAndroid Build Coastguard Worker        out_ref = t.to('cpu')
7535*da0073e9SAndroid Build Coastguard Worker        out = s(t, "t.to('cpu')")
7536*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_ref, out)
7537*da0073e9SAndroid Build Coastguard Worker
7538*da0073e9SAndroid Build Coastguard Worker        grad_ref = torch.autograd.grad(out_ref.sum(), t)
7539*da0073e9SAndroid Build Coastguard Worker        grad = torch.autograd.grad(out.sum(), t)
7540*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(grad_ref, grad)
7541*da0073e9SAndroid Build Coastguard Worker
7542*da0073e9SAndroid Build Coastguard Worker        # Test AD: aten::to(Tensor self, Tensor other, bool non_blocking, bool copy) -> Tensor
7543*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
7544*da0073e9SAndroid Build Coastguard Worker        def func2(t, t_ref):
7545*da0073e9SAndroid Build Coastguard Worker            return t.to(t_ref)
7546*da0073e9SAndroid Build Coastguard Worker
7547*da0073e9SAndroid Build Coastguard Worker        with disable_autodiff_subgraph_inlining():
7548*da0073e9SAndroid Build Coastguard Worker            t_ref = torch.tensor(4).double()
7549*da0073e9SAndroid Build Coastguard Worker            out_ref = t.to(t_ref)
7550*da0073e9SAndroid Build Coastguard Worker            out = func2(t, t_ref)
7551*da0073e9SAndroid Build Coastguard Worker            grad_ref = torch.autograd.grad(out_ref.sum(), t)
7552*da0073e9SAndroid Build Coastguard Worker            grad = torch.autograd.grad(out.sum(), t)
7553*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(grad_ref, grad)
7554*da0073e9SAndroid Build Coastguard Worker
7555*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "No CUDA")
7556*da0073e9SAndroid Build Coastguard Worker    def test_tensor_number_math_cuda(self):
7557*da0073e9SAndroid Build Coastguard Worker        self._test_tensor_number_math(device='cuda')
7558*da0073e9SAndroid Build Coastguard Worker
7559*da0073e9SAndroid Build Coastguard Worker    def test_not(self):
7560*da0073e9SAndroid Build Coastguard Worker        # test not operator in python
7561*da0073e9SAndroid Build Coastguard Worker        # TODO: add more tests when bool conversions ready
7562*da0073e9SAndroid Build Coastguard Worker        def test_not_op(a):
7563*da0073e9SAndroid Build Coastguard Worker            return not bool(a > 1)
7564*da0073e9SAndroid Build Coastguard Worker
7565*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_not_op, (torch.tensor(2), ), optimize=True)
7566*da0073e9SAndroid Build Coastguard Worker
7567*da0073e9SAndroid Build Coastguard Worker    def test_is_isnot(self):
7568*da0073e9SAndroid Build Coastguard Worker        # test is and is not operator in python
7569*da0073e9SAndroid Build Coastguard Worker        template = dedent('''
7570*da0073e9SAndroid Build Coastguard Worker        def func():
7571*da0073e9SAndroid Build Coastguard Worker            # type: () -> bool
7572*da0073e9SAndroid Build Coastguard Worker            return {lhs} {op} {rhs}
7573*da0073e9SAndroid Build Coastguard Worker        ''')
7574*da0073e9SAndroid Build Coastguard Worker
7575*da0073e9SAndroid Build Coastguard Worker        def test(op, args):
7576*da0073e9SAndroid Build Coastguard Worker            code = template.format(lhs=args[0], rhs=args[1], op=op)
7577*da0073e9SAndroid Build Coastguard Worker            scope = {}
7578*da0073e9SAndroid Build Coastguard Worker            execWrapper(code, globals(), scope)
7579*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit(code)
7580*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
7581*da0073e9SAndroid Build Coastguard Worker                cu.func(),
7582*da0073e9SAndroid Build Coastguard Worker                scope['func'](),
7583*da0073e9SAndroid Build Coastguard Worker                msg=f"Failed with op: {op}, lhs: {args[0]}, rhs: {args[1]}"
7584*da0073e9SAndroid Build Coastguard Worker            )
7585*da0073e9SAndroid Build Coastguard Worker
7586*da0073e9SAndroid Build Coastguard Worker        ops = ['is', 'is not']
7587*da0073e9SAndroid Build Coastguard Worker        type_literals = [True, False, None, [1, 1], 1, 2, .5, 1.5]
7588*da0073e9SAndroid Build Coastguard Worker
7589*da0073e9SAndroid Build Coastguard Worker        # do literals product to try any types combinations
7590*da0073e9SAndroid Build Coastguard Worker        for op, lhs, rhs in product(ops, type_literals, type_literals):
7591*da0073e9SAndroid Build Coastguard Worker            test(op, [lhs, rhs])
7592*da0073e9SAndroid Build Coastguard Worker
7593*da0073e9SAndroid Build Coastguard Worker    def test_isinstance_refinement(self):
7594*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
7595*da0073e9SAndroid Build Coastguard Worker        def foo(a):
7596*da0073e9SAndroid Build Coastguard Worker            # type: (Optional[int]) -> int
7597*da0073e9SAndroid Build Coastguard Worker            if isinstance(a, int):
7598*da0073e9SAndroid Build Coastguard Worker                return a + 3
7599*da0073e9SAndroid Build Coastguard Worker            else:
7600*da0073e9SAndroid Build Coastguard Worker                return 4
7601*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(4), 7)
7602*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(None), 4)
7603*da0073e9SAndroid Build Coastguard Worker
7604*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
7605*da0073e9SAndroid Build Coastguard Worker        def foo2(a, b):
7606*da0073e9SAndroid Build Coastguard Worker            # type: (Optional[int], Optional[int]) -> int
7607*da0073e9SAndroid Build Coastguard Worker            if not isinstance(a, int) or not isinstance(b, int):
7608*da0073e9SAndroid Build Coastguard Worker                return 0
7609*da0073e9SAndroid Build Coastguard Worker            else:
7610*da0073e9SAndroid Build Coastguard Worker                return a + b
7611*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo2(3, 4), 7)
7612*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo2(None, 4), 0)
7613*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo2(4, None), 0)
7614*da0073e9SAndroid Build Coastguard Worker
7615*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
7616*da0073e9SAndroid Build Coastguard Worker        def any_refinement(a, b):
7617*da0073e9SAndroid Build Coastguard Worker            # type: (Any, Any) -> int
7618*da0073e9SAndroid Build Coastguard Worker            if isinstance(a, int) and isinstance(b, int):
7619*da0073e9SAndroid Build Coastguard Worker                return a + b
7620*da0073e9SAndroid Build Coastguard Worker            return 0
7621*da0073e9SAndroid Build Coastguard Worker
7622*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(any_refinement(3, 4), 7)
7623*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(any_refinement(3, "hi"), 0)
7624*da0073e9SAndroid Build Coastguard Worker
7625*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
7626*da0073e9SAndroid Build Coastguard Worker        def any_refinement2(a):
7627*da0073e9SAndroid Build Coastguard Worker            # type: (Any) -> Tensor
7628*da0073e9SAndroid Build Coastguard Worker            if isinstance(a, Tensor):
7629*da0073e9SAndroid Build Coastguard Worker                return a
7630*da0073e9SAndroid Build Coastguard Worker            return torch.tensor(3)
7631*da0073e9SAndroid Build Coastguard Worker
7632*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(any_refinement2(3), torch.tensor(3))
7633*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(any_refinement2(torch.tensor(5)), torch.tensor(5))
7634*da0073e9SAndroid Build Coastguard Worker
7635*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "bug persists in deprecated executor")
7636*da0073e9SAndroid Build Coastguard Worker    def test_unspecialized_any_binding(self):
7637*da0073e9SAndroid Build Coastguard Worker        # any binding will infer the type, if it infers
7638*da0073e9SAndroid Build Coastguard Worker        # a specialized tensor type `x` Dict type will fail isinstance check
7639*da0073e9SAndroid Build Coastguard Worker
7640*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
7641*da0073e9SAndroid Build Coastguard Worker        def foo(x: Any):
7642*da0073e9SAndroid Build Coastguard Worker            assert isinstance(x, Dict[str, torch.Tensor])
7643*da0073e9SAndroid Build Coastguard Worker
7644*da0073e9SAndroid Build Coastguard Worker        foo({"1": torch.tensor(3)})
7645*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(Exception):
7646*da0073e9SAndroid Build Coastguard Worker            foo(2)
7647*da0073e9SAndroid Build Coastguard Worker
7648*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
7649*da0073e9SAndroid Build Coastguard Worker    def test_isinstance(self):
7650*da0073e9SAndroid Build Coastguard Worker        # test isinstance operator for static type checking
7651*da0073e9SAndroid Build Coastguard Worker        template = dedent('''
7652*da0073e9SAndroid Build Coastguard Worker        def func(x):
7653*da0073e9SAndroid Build Coastguard Worker            # type: ({type_hint}) -> bool
7654*da0073e9SAndroid Build Coastguard Worker            return isinstance(x, {typ})
7655*da0073e9SAndroid Build Coastguard Worker        ''')
7656*da0073e9SAndroid Build Coastguard Worker
7657*da0073e9SAndroid Build Coastguard Worker        def test(inp, typ, type_hint):
7658*da0073e9SAndroid Build Coastguard Worker            code = template.format(typ=typ, type_hint=type_hint)
7659*da0073e9SAndroid Build Coastguard Worker            scope = {}
7660*da0073e9SAndroid Build Coastguard Worker            execWrapper(code, globals(), scope)
7661*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit(code)
7662*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
7663*da0073e9SAndroid Build Coastguard Worker                cu.func(inp),
7664*da0073e9SAndroid Build Coastguard Worker                scope['func'](inp),
7665*da0073e9SAndroid Build Coastguard Worker                msg=f"Failed with typ: {typ}"
7666*da0073e9SAndroid Build Coastguard Worker            )
7667*da0073e9SAndroid Build Coastguard Worker
7668*da0073e9SAndroid Build Coastguard Worker        inputs = [True, 1, 1.0, torch.tensor(1), [1, 2], (1.0,), [1, 2], 1]
7669*da0073e9SAndroid Build Coastguard Worker        type_literals = ['bool', 'int', 'float', 'torch.Tensor', 'list', 'tuple',
7670*da0073e9SAndroid Build Coastguard Worker                         '(list, tuple)', '(int, float, bool)']
7671*da0073e9SAndroid Build Coastguard Worker        type_annotations = ['bool', 'int', 'float', 'Tensor', 'List[int]', 'Tuple[float]',
7672*da0073e9SAndroid Build Coastguard Worker                            'List[int]', 'int']
7673*da0073e9SAndroid Build Coastguard Worker
7674*da0073e9SAndroid Build Coastguard Worker        # do zipping to try different types
7675*da0073e9SAndroid Build Coastguard Worker        for inp, typ, type_hint in zip(inputs, type_literals, type_annotations):
7676*da0073e9SAndroid Build Coastguard Worker            test(inp, typ, type_hint)
7677*da0073e9SAndroid Build Coastguard Worker
7678*da0073e9SAndroid Build Coastguard Worker        # test optional isinstance check
7679*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
7680*da0073e9SAndroid Build Coastguard Worker        def opt_func(x):
7681*da0073e9SAndroid Build Coastguard Worker            # type: (Optional[int]) -> bool
7682*da0073e9SAndroid Build Coastguard Worker            return isinstance(x, int)
7683*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(opt_func(3))
7684*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(opt_func(None))
7685*da0073e9SAndroid Build Coastguard Worker
7686*da0073e9SAndroid Build Coastguard Worker    def test_dropout_eval(self):
7687*da0073e9SAndroid Build Coastguard Worker        class ScriptedConv2d(torch.jit.ScriptModule):
7688*da0073e9SAndroid Build Coastguard Worker            def __init__(self, in_channels, out_channels, **kwargs):
7689*da0073e9SAndroid Build Coastguard Worker                super().__init__()
7690*da0073e9SAndroid Build Coastguard Worker                self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
7691*da0073e9SAndroid Build Coastguard Worker                self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
7692*da0073e9SAndroid Build Coastguard Worker
7693*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
7694*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
7695*da0073e9SAndroid Build Coastguard Worker                x = self.conv(x)
7696*da0073e9SAndroid Build Coastguard Worker                x = self.bn(x)
7697*da0073e9SAndroid Build Coastguard Worker                return F.relu(x, inplace=True)
7698*da0073e9SAndroid Build Coastguard Worker
7699*da0073e9SAndroid Build Coastguard Worker        class ScriptMod(torch.jit.ScriptModule):
7700*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
7701*da0073e9SAndroid Build Coastguard Worker                super().__init__()
7702*da0073e9SAndroid Build Coastguard Worker                self.Conv2d_1a_3x3 = ScriptedConv2d(3, 32, kernel_size=3, stride=2)
7703*da0073e9SAndroid Build Coastguard Worker
7704*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
7705*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
7706*da0073e9SAndroid Build Coastguard Worker                x = self.Conv2d_1a_3x3(x)
7707*da0073e9SAndroid Build Coastguard Worker                return F.dropout(x, training=self.training)
7708*da0073e9SAndroid Build Coastguard Worker
7709*da0073e9SAndroid Build Coastguard Worker        class EagerConv2d(torch.nn.Module):
7710*da0073e9SAndroid Build Coastguard Worker            def __init__(self, in_channels, out_channels, **kwargs):
7711*da0073e9SAndroid Build Coastguard Worker                super().__init__()
7712*da0073e9SAndroid Build Coastguard Worker                self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
7713*da0073e9SAndroid Build Coastguard Worker                self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
7714*da0073e9SAndroid Build Coastguard Worker
7715*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
7716*da0073e9SAndroid Build Coastguard Worker                x = self.conv(x)
7717*da0073e9SAndroid Build Coastguard Worker                x = self.bn(x)
7718*da0073e9SAndroid Build Coastguard Worker                return F.relu(x, inplace=True)
7719*da0073e9SAndroid Build Coastguard Worker
7720*da0073e9SAndroid Build Coastguard Worker        class EagerMod(torch.nn.Module):
7721*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
7722*da0073e9SAndroid Build Coastguard Worker                super().__init__()
7723*da0073e9SAndroid Build Coastguard Worker                self.Conv2d_1a_3x3 = EagerConv2d(3, 32, kernel_size=3, stride=2)
7724*da0073e9SAndroid Build Coastguard Worker
7725*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
7726*da0073e9SAndroid Build Coastguard Worker                x = self.Conv2d_1a_3x3(x)
7727*da0073e9SAndroid Build Coastguard Worker                return F.dropout(x, training=self.training)
7728*da0073e9SAndroid Build Coastguard Worker
7729*da0073e9SAndroid Build Coastguard Worker        script_input = torch.rand(4, 3, 299, 299)
7730*da0073e9SAndroid Build Coastguard Worker        eager_input = script_input.clone()
7731*da0073e9SAndroid Build Coastguard Worker
7732*da0073e9SAndroid Build Coastguard Worker        with freeze_rng_state():
7733*da0073e9SAndroid Build Coastguard Worker            script_mod = ScriptMod()
7734*da0073e9SAndroid Build Coastguard Worker            script_mod.eval()
7735*da0073e9SAndroid Build Coastguard Worker            script_output = script_mod(script_input)
7736*da0073e9SAndroid Build Coastguard Worker
7737*da0073e9SAndroid Build Coastguard Worker        with freeze_rng_state():
7738*da0073e9SAndroid Build Coastguard Worker            eager_mod = EagerMod()
7739*da0073e9SAndroid Build Coastguard Worker            eager_mod.eval()
7740*da0073e9SAndroid Build Coastguard Worker            eager_output = eager_mod(eager_input)
7741*da0073e9SAndroid Build Coastguard Worker
7742*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(script_output, eager_output)
7743*da0073e9SAndroid Build Coastguard Worker
7744*da0073e9SAndroid Build Coastguard Worker        with freeze_rng_state():
7745*da0073e9SAndroid Build Coastguard Worker            script_mod = ScriptMod()
7746*da0073e9SAndroid Build Coastguard Worker            script_mod.train()
7747*da0073e9SAndroid Build Coastguard Worker            script_output = script_mod(script_input)
7748*da0073e9SAndroid Build Coastguard Worker
7749*da0073e9SAndroid Build Coastguard Worker        with freeze_rng_state():
7750*da0073e9SAndroid Build Coastguard Worker            eager_mod = EagerMod()
7751*da0073e9SAndroid Build Coastguard Worker            eager_mod.train()
7752*da0073e9SAndroid Build Coastguard Worker            eager_output = eager_mod(eager_input)
7753*da0073e9SAndroid Build Coastguard Worker
7754*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(script_output, eager_output)
7755*da0073e9SAndroid Build Coastguard Worker
7756*da0073e9SAndroid Build Coastguard Worker    def test_nested_breaks(self):
7757*da0073e9SAndroid Build Coastguard Worker        def no_bool_loop_outputs(g):
7758*da0073e9SAndroid Build Coastguard Worker            # testing that the "did exit" transform values are not loop block
7759*da0073e9SAndroid Build Coastguard Worker            # outputs (and thus not affecting one loop from another)
7760*da0073e9SAndroid Build Coastguard Worker            loops = g.findAllNodes("prim::Loop")
7761*da0073e9SAndroid Build Coastguard Worker            for loop in loops:
7762*da0073e9SAndroid Build Coastguard Worker                for out in loop.outputs():
7763*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(out.type() != BoolType.get())
7764*da0073e9SAndroid Build Coastguard Worker
7765*da0073e9SAndroid Build Coastguard Worker        def test(y):
7766*da0073e9SAndroid Build Coastguard Worker            # type: (int)
7767*da0073e9SAndroid Build Coastguard Worker            ret = 0
7768*da0073e9SAndroid Build Coastguard Worker            tensor = torch.tensor(0)
7769*da0073e9SAndroid Build Coastguard Worker            while int(tensor.add_(1)) < 4:
7770*da0073e9SAndroid Build Coastguard Worker                if y == 1:
7771*da0073e9SAndroid Build Coastguard Worker                    continue
7772*da0073e9SAndroid Build Coastguard Worker                for i in range(y):
7773*da0073e9SAndroid Build Coastguard Worker                    continue
7774*da0073e9SAndroid Build Coastguard Worker                    ret += 1
7775*da0073e9SAndroid Build Coastguard Worker                ret += 1
7776*da0073e9SAndroid Build Coastguard Worker            return ret, int(tensor)
7777*da0073e9SAndroid Build Coastguard Worker
7778*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.jit.script(test)(1), test(1))
7779*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.jit.script(test)(2), test(2))
7780*da0073e9SAndroid Build Coastguard Worker        no_bool_loop_outputs(torch.jit.script(test).graph)
7781*da0073e9SAndroid Build Coastguard Worker
7782*da0073e9SAndroid Build Coastguard Worker        def foo():
7783*da0073e9SAndroid Build Coastguard Worker            y = torch.tensor(0)
7784*da0073e9SAndroid Build Coastguard Worker            z = 0
7785*da0073e9SAndroid Build Coastguard Worker            while int(y.add_(1)) < 20:
7786*da0073e9SAndroid Build Coastguard Worker                if int(y) < 10:
7787*da0073e9SAndroid Build Coastguard Worker                    for i in range(6):
7788*da0073e9SAndroid Build Coastguard Worker                        if i == 3:
7789*da0073e9SAndroid Build Coastguard Worker                            continue
7790*da0073e9SAndroid Build Coastguard Worker                        else:
7791*da0073e9SAndroid Build Coastguard Worker                            if i > 3:
7792*da0073e9SAndroid Build Coastguard Worker                                break
7793*da0073e9SAndroid Build Coastguard Worker                        z += 2
7794*da0073e9SAndroid Build Coastguard Worker                if int(y) == 18:
7795*da0073e9SAndroid Build Coastguard Worker                    break
7796*da0073e9SAndroid Build Coastguard Worker                if int(y) == 15:
7797*da0073e9SAndroid Build Coastguard Worker                    continue
7798*da0073e9SAndroid Build Coastguard Worker                z += 1
7799*da0073e9SAndroid Build Coastguard Worker            return int(y), z
7800*da0073e9SAndroid Build Coastguard Worker
7801*da0073e9SAndroid Build Coastguard Worker        no_bool_loop_outputs(torch.jit.script(foo).graph)
7802*da0073e9SAndroid Build Coastguard Worker        self.checkScript(foo, ())
7803*da0073e9SAndroid Build Coastguard Worker
7804*da0073e9SAndroid Build Coastguard Worker        def test_nested_two():
7805*da0073e9SAndroid Build Coastguard Worker            i = 0
7806*da0073e9SAndroid Build Coastguard Worker            k = 0
7807*da0073e9SAndroid Build Coastguard Worker            while i < 5:
7808*da0073e9SAndroid Build Coastguard Worker                for j in range(5):
7809*da0073e9SAndroid Build Coastguard Worker                    k += 1
7810*da0073e9SAndroid Build Coastguard Worker                    if j == 3:
7811*da0073e9SAndroid Build Coastguard Worker                        continue
7812*da0073e9SAndroid Build Coastguard Worker                i += 1
7813*da0073e9SAndroid Build Coastguard Worker                k += 1
7814*da0073e9SAndroid Build Coastguard Worker                if i == 4:
7815*da0073e9SAndroid Build Coastguard Worker                    break
7816*da0073e9SAndroid Build Coastguard Worker            return i, k
7817*da0073e9SAndroid Build Coastguard Worker
7818*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_nested_two, ())
7819*da0073e9SAndroid Build Coastguard Worker        no_bool_loop_outputs(torch.jit.script(test_nested_two).graph)
7820*da0073e9SAndroid Build Coastguard Worker
7821*da0073e9SAndroid Build Coastguard Worker    def test_breaks_continues(self):
7822*da0073e9SAndroid Build Coastguard Worker        def foo_continue(cond):
7823*da0073e9SAndroid Build Coastguard Worker            # type: (int)
7824*da0073e9SAndroid Build Coastguard Worker            j = 1
7825*da0073e9SAndroid Build Coastguard Worker            for i in range(5):
7826*da0073e9SAndroid Build Coastguard Worker                if i == cond:
7827*da0073e9SAndroid Build Coastguard Worker                    continue
7828*da0073e9SAndroid Build Coastguard Worker                j += 1
7829*da0073e9SAndroid Build Coastguard Worker            return j
7830*da0073e9SAndroid Build Coastguard Worker
7831*da0073e9SAndroid Build Coastguard Worker        def foo_break(cond):
7832*da0073e9SAndroid Build Coastguard Worker            # type: (int)
7833*da0073e9SAndroid Build Coastguard Worker            j = 1
7834*da0073e9SAndroid Build Coastguard Worker            for i in range(5):
7835*da0073e9SAndroid Build Coastguard Worker                if i == cond:
7836*da0073e9SAndroid Build Coastguard Worker                    break
7837*da0073e9SAndroid Build Coastguard Worker                j += 1
7838*da0073e9SAndroid Build Coastguard Worker            return j
7839*da0073e9SAndroid Build Coastguard Worker
7840*da0073e9SAndroid Build Coastguard Worker        for i in range(1, 4):
7841*da0073e9SAndroid Build Coastguard Worker            self.checkScript(foo_continue, (i,))
7842*da0073e9SAndroid Build Coastguard Worker            self.checkScript(foo_break, (i,))
7843*da0073e9SAndroid Build Coastguard Worker
7844*da0073e9SAndroid Build Coastguard Worker        def test_refine_outside_loop():
7845*da0073e9SAndroid Build Coastguard Worker            if 1 == 1:
7846*da0073e9SAndroid Build Coastguard Worker                x = None
7847*da0073e9SAndroid Build Coastguard Worker            else:
7848*da0073e9SAndroid Build Coastguard Worker                x = 1
7849*da0073e9SAndroid Build Coastguard Worker            i = 0
7850*da0073e9SAndroid Build Coastguard Worker            j = 0
7851*da0073e9SAndroid Build Coastguard Worker            while (x is None or torch.jit._unwrap_optional(x) > 3):
7852*da0073e9SAndroid Build Coastguard Worker                if i < 3:
7853*da0073e9SAndroid Build Coastguard Worker                    if i < 3:
7854*da0073e9SAndroid Build Coastguard Worker                        x = torch.jit.annotate(Optional[int], None)
7855*da0073e9SAndroid Build Coastguard Worker                        i += 1
7856*da0073e9SAndroid Build Coastguard Worker                        continue
7857*da0073e9SAndroid Build Coastguard Worker                    x = 1
7858*da0073e9SAndroid Build Coastguard Worker                else:
7859*da0073e9SAndroid Build Coastguard Worker                    x = 1 if x is None else x
7860*da0073e9SAndroid Build Coastguard Worker                x = x + 1
7861*da0073e9SAndroid Build Coastguard Worker                j = x + x
7862*da0073e9SAndroid Build Coastguard Worker
7863*da0073e9SAndroid Build Coastguard Worker            return x, j
7864*da0073e9SAndroid Build Coastguard Worker
7865*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_refine_outside_loop, ())
7866*da0073e9SAndroid Build Coastguard Worker
7867*da0073e9SAndroid Build Coastguard Worker        def assign_after_break(y):
7868*da0073e9SAndroid Build Coastguard Worker            # type: (int)
7869*da0073e9SAndroid Build Coastguard Worker            x = 0
7870*da0073e9SAndroid Build Coastguard Worker            for i in range(y):
7871*da0073e9SAndroid Build Coastguard Worker                x = y * 2 + i
7872*da0073e9SAndroid Build Coastguard Worker                break
7873*da0073e9SAndroid Build Coastguard Worker                x = 4
7874*da0073e9SAndroid Build Coastguard Worker            return x
7875*da0073e9SAndroid Build Coastguard Worker
7876*da0073e9SAndroid Build Coastguard Worker        self.checkScript(assign_after_break, (1,))
7877*da0073e9SAndroid Build Coastguard Worker        self.checkScript(assign_after_break, (2,))
7878*da0073e9SAndroid Build Coastguard Worker        self.checkScript(assign_after_break, (3,))
7879*da0073e9SAndroid Build Coastguard Worker
7880*da0073e9SAndroid Build Coastguard Worker        def assign_after_break_nested(y):
7881*da0073e9SAndroid Build Coastguard Worker            # type: (int)
7882*da0073e9SAndroid Build Coastguard Worker            x = 0
7883*da0073e9SAndroid Build Coastguard Worker            for i in range(y):
7884*da0073e9SAndroid Build Coastguard Worker                if y == 1:
7885*da0073e9SAndroid Build Coastguard Worker                    x = 5
7886*da0073e9SAndroid Build Coastguard Worker                    break
7887*da0073e9SAndroid Build Coastguard Worker                    assert 1 == 2
7888*da0073e9SAndroid Build Coastguard Worker                else:
7889*da0073e9SAndroid Build Coastguard Worker                    x = x + 1
7890*da0073e9SAndroid Build Coastguard Worker                    break
7891*da0073e9SAndroid Build Coastguard Worker                    assert 1 == 2
7892*da0073e9SAndroid Build Coastguard Worker                x = -30
7893*da0073e9SAndroid Build Coastguard Worker                assert 1 == 2
7894*da0073e9SAndroid Build Coastguard Worker            return x
7895*da0073e9SAndroid Build Coastguard Worker
7896*da0073e9SAndroid Build Coastguard Worker        self.checkScript(assign_after_break_nested, (1,))
7897*da0073e9SAndroid Build Coastguard Worker        self.checkScript(assign_after_break_nested, (2,))
7898*da0073e9SAndroid Build Coastguard Worker        self.checkScript(assign_after_break_nested, (3,))
7899*da0073e9SAndroid Build Coastguard Worker
7900*da0073e9SAndroid Build Coastguard Worker        def may_break(y):
7901*da0073e9SAndroid Build Coastguard Worker            # type: (int)
7902*da0073e9SAndroid Build Coastguard Worker            x = 0
7903*da0073e9SAndroid Build Coastguard Worker            for i in range(y):
7904*da0073e9SAndroid Build Coastguard Worker                if y == 1:
7905*da0073e9SAndroid Build Coastguard Worker                    x = 5
7906*da0073e9SAndroid Build Coastguard Worker                else:
7907*da0073e9SAndroid Build Coastguard Worker                    x = x + 1
7908*da0073e9SAndroid Build Coastguard Worker                    break
7909*da0073e9SAndroid Build Coastguard Worker                x = -30
7910*da0073e9SAndroid Build Coastguard Worker            return x
7911*da0073e9SAndroid Build Coastguard Worker
7912*da0073e9SAndroid Build Coastguard Worker        self.checkScript(may_break, (1,))
7913*da0073e9SAndroid Build Coastguard Worker        self.checkScript(may_break, (2,))
7914*da0073e9SAndroid Build Coastguard Worker        self.checkScript(may_break, (3,))
7915*da0073e9SAndroid Build Coastguard Worker
7916*da0073e9SAndroid Build Coastguard Worker        def test(x, y):
7917*da0073e9SAndroid Build Coastguard Worker            # type: (int, int)
7918*da0073e9SAndroid Build Coastguard Worker            a = 1
7919*da0073e9SAndroid Build Coastguard Worker            while (x > 0):
7920*da0073e9SAndroid Build Coastguard Worker                if y == 3:
7921*da0073e9SAndroid Build Coastguard Worker                    for i in range(y):
7922*da0073e9SAndroid Build Coastguard Worker                        a += (1 % (i + 1))
7923*da0073e9SAndroid Build Coastguard Worker                        x -= 1
7924*da0073e9SAndroid Build Coastguard Worker                if x == 3:
7925*da0073e9SAndroid Build Coastguard Worker                    a = x * 3
7926*da0073e9SAndroid Build Coastguard Worker                    break
7927*da0073e9SAndroid Build Coastguard Worker                if x < 3:
7928*da0073e9SAndroid Build Coastguard Worker                    if x == 1:
7929*da0073e9SAndroid Build Coastguard Worker                        a -= 2
7930*da0073e9SAndroid Build Coastguard Worker                        x -= 1
7931*da0073e9SAndroid Build Coastguard Worker                        break
7932*da0073e9SAndroid Build Coastguard Worker                a -= 1
7933*da0073e9SAndroid Build Coastguard Worker                x -= 3
7934*da0073e9SAndroid Build Coastguard Worker            return a, x
7935*da0073e9SAndroid Build Coastguard Worker
7936*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test, (10, 3))
7937*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test, (10, 2))
7938*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test, (3, 2))
7939*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test, (5, 3))
7940*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test, (2, 3))
7941*da0073e9SAndroid Build Coastguard Worker
7942*da0073e9SAndroid Build Coastguard Worker        def test_delete_after_break(x):
7943*da0073e9SAndroid Build Coastguard Worker            # type: (int)
7944*da0073e9SAndroid Build Coastguard Worker            a = 1
7945*da0073e9SAndroid Build Coastguard Worker            b = 1
7946*da0073e9SAndroid Build Coastguard Worker            for i in range(x):
7947*da0073e9SAndroid Build Coastguard Worker                a = i * 3
7948*da0073e9SAndroid Build Coastguard Worker                break
7949*da0073e9SAndroid Build Coastguard Worker                b = i * 5
7950*da0073e9SAndroid Build Coastguard Worker            return a, b
7951*da0073e9SAndroid Build Coastguard Worker
7952*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_delete_after_break, (0,))
7953*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_delete_after_break, (1,))
7954*da0073e9SAndroid Build Coastguard Worker
7955*da0073e9SAndroid Build Coastguard Worker        def test_will_break_after_guard(x):
7956*da0073e9SAndroid Build Coastguard Worker            # type: (int)
7957*da0073e9SAndroid Build Coastguard Worker            a = 1
7958*da0073e9SAndroid Build Coastguard Worker            for i in range(x):
7959*da0073e9SAndroid Build Coastguard Worker                if i == 4:
7960*da0073e9SAndroid Build Coastguard Worker                    a = 3
7961*da0073e9SAndroid Build Coastguard Worker                    break
7962*da0073e9SAndroid Build Coastguard Worker                a -= 1
7963*da0073e9SAndroid Build Coastguard Worker                break
7964*da0073e9SAndroid Build Coastguard Worker                assert 1 == 2
7965*da0073e9SAndroid Build Coastguard Worker                a -= -100
7966*da0073e9SAndroid Build Coastguard Worker            return a
7967*da0073e9SAndroid Build Coastguard Worker
7968*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_will_break_after_guard, (0,))
7969*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_will_break_after_guard, (2,))
7970*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_will_break_after_guard, (4,))
7971*da0073e9SAndroid Build Coastguard Worker
7972*da0073e9SAndroid Build Coastguard Worker        def test_varexit(cond):
7973*da0073e9SAndroid Build Coastguard Worker            # type: (int)
7974*da0073e9SAndroid Build Coastguard Worker            m = 0
7975*da0073e9SAndroid Build Coastguard Worker            for i in range(3):
7976*da0073e9SAndroid Build Coastguard Worker                if cond == 2:
7977*da0073e9SAndroid Build Coastguard Worker                    if cond == 2:
7978*da0073e9SAndroid Build Coastguard Worker                        m = 2
7979*da0073e9SAndroid Build Coastguard Worker                        break
7980*da0073e9SAndroid Build Coastguard Worker                    k = 1
7981*da0073e9SAndroid Build Coastguard Worker                else:
7982*da0073e9SAndroid Build Coastguard Worker                    k = 2
7983*da0073e9SAndroid Build Coastguard Worker                m += k
7984*da0073e9SAndroid Build Coastguard Worker            return m
7985*da0073e9SAndroid Build Coastguard Worker
7986*da0073e9SAndroid Build Coastguard Worker        # use of k tests the pathway where we have to insert unitialized
7987*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_varexit, (3,))
7988*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_varexit, (2,))
7989*da0073e9SAndroid Build Coastguard Worker
7990*da0073e9SAndroid Build Coastguard Worker        def test_break_true():
7991*da0073e9SAndroid Build Coastguard Worker            i = 0
7992*da0073e9SAndroid Build Coastguard Worker            while True:
7993*da0073e9SAndroid Build Coastguard Worker                i += 1
7994*da0073e9SAndroid Build Coastguard Worker                if i == 3:
7995*da0073e9SAndroid Build Coastguard Worker                    break
7996*da0073e9SAndroid Build Coastguard Worker            while False:
7997*da0073e9SAndroid Build Coastguard Worker                i += 1
7998*da0073e9SAndroid Build Coastguard Worker            return i
7999*da0073e9SAndroid Build Coastguard Worker
8000*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_break_true, ())
8001*da0073e9SAndroid Build Coastguard Worker
8002*da0073e9SAndroid Build Coastguard Worker    def test_break_continue_error(self):
8003*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Syntax"):
8004*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit('''
8005*da0073e9SAndroid Build Coastguard Worker            def other_func(a):
8006*da0073e9SAndroid Build Coastguard Worker                break
8007*da0073e9SAndroid Build Coastguard Worker                ''')
8008*da0073e9SAndroid Build Coastguard Worker
8009*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Syntax"):
8010*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit('''
8011*da0073e9SAndroid Build Coastguard Worker            def other_func(a):
8012*da0073e9SAndroid Build Coastguard Worker                for i in range(5):
8013*da0073e9SAndroid Build Coastguard Worker                    def foo():
8014*da0073e9SAndroid Build Coastguard Worker                        break
8015*da0073e9SAndroid Build Coastguard Worker                ''')
8016*da0073e9SAndroid Build Coastguard Worker
8017*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "do not support break or continue inside"):
8018*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
8019*da0073e9SAndroid Build Coastguard Worker            def foo(x):
8020*da0073e9SAndroid Build Coastguard Worker                i = 0
8021*da0073e9SAndroid Build Coastguard Worker                for a in (1, "2", 1.5):
8022*da0073e9SAndroid Build Coastguard Worker                    b = a
8023*da0073e9SAndroid Build Coastguard Worker                    if x:
8024*da0073e9SAndroid Build Coastguard Worker                        break
8025*da0073e9SAndroid Build Coastguard Worker                return b
8026*da0073e9SAndroid Build Coastguard Worker
8027*da0073e9SAndroid Build Coastguard Worker    def test_python_call(self):
8028*da0073e9SAndroid Build Coastguard Worker        def pyfunc(a):
8029*da0073e9SAndroid Build Coastguard Worker            return a * 3.0
8030*da0073e9SAndroid Build Coastguard Worker
8031*da0073e9SAndroid Build Coastguard Worker        cu = torch.jit.CompilationUnit('''
8032*da0073e9SAndroid Build Coastguard Worker        def other_func(a):
8033*da0073e9SAndroid Build Coastguard Worker            return a + a
8034*da0073e9SAndroid Build Coastguard Worker
8035*da0073e9SAndroid Build Coastguard Worker        def test_call_python(a):
8036*da0073e9SAndroid Build Coastguard Worker            b = pyfunc(a)
8037*da0073e9SAndroid Build Coastguard Worker            b = other_func(b)
8038*da0073e9SAndroid Build Coastguard Worker            i = 0
8039*da0073e9SAndroid Build Coastguard Worker            step = 1
8040*da0073e9SAndroid Build Coastguard Worker            while i < 10:
8041*da0073e9SAndroid Build Coastguard Worker                b = pyfunc(b)
8042*da0073e9SAndroid Build Coastguard Worker                if bool(b > 3.0):
8043*da0073e9SAndroid Build Coastguard Worker                    b = pyfunc(b)
8044*da0073e9SAndroid Build Coastguard Worker                i = 11
8045*da0073e9SAndroid Build Coastguard Worker            return b
8046*da0073e9SAndroid Build Coastguard Worker        ''')
8047*da0073e9SAndroid Build Coastguard Worker        inputs = self._make_scalar_vars([1], torch.float)
8048*da0073e9SAndroid Build Coastguard Worker        outputs = self._make_scalar_vars([54], torch.float)
8049*da0073e9SAndroid Build Coastguard Worker
8050*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cu.test_call_python(*inputs), outputs[0])
8051*da0073e9SAndroid Build Coastguard Worker
8052*da0073e9SAndroid Build Coastguard Worker    def test_python_call_failure(self):
8053*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "undefined value pyfunc2"):
8054*da0073e9SAndroid Build Coastguard Worker            def pyfunc(a):
8055*da0073e9SAndroid Build Coastguard Worker                return a * 3.0
8056*da0073e9SAndroid Build Coastguard Worker
8057*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit('''
8058*da0073e9SAndroid Build Coastguard Worker            def other_func(a):
8059*da0073e9SAndroid Build Coastguard Worker                return a + a
8060*da0073e9SAndroid Build Coastguard Worker
8061*da0073e9SAndroid Build Coastguard Worker            def test_call_python(a):
8062*da0073e9SAndroid Build Coastguard Worker                b = pyfunc(a)
8063*da0073e9SAndroid Build Coastguard Worker                b = other_func(b)
8064*da0073e9SAndroid Build Coastguard Worker                i = 0
8065*da0073e9SAndroid Build Coastguard Worker                step = 1
8066*da0073e9SAndroid Build Coastguard Worker                while i < 10:
8067*da0073e9SAndroid Build Coastguard Worker                    b = pyfunc2(b)
8068*da0073e9SAndroid Build Coastguard Worker                    if b > 3.0:
8069*da0073e9SAndroid Build Coastguard Worker                        b = pyfunc(b)
8070*da0073e9SAndroid Build Coastguard Worker                    i = 11
8071*da0073e9SAndroid Build Coastguard Worker                return b
8072*da0073e9SAndroid Build Coastguard Worker            ''')
8073*da0073e9SAndroid Build Coastguard Worker            inputs = self._make_scalar_vars([1], torch.float)
8074*da0073e9SAndroid Build Coastguard Worker            outputs = self._make_scalar_vars([54], torch.float)
8075*da0073e9SAndroid Build Coastguard Worker
8076*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cu.test_call_python(*inputs), outputs)
8077*da0073e9SAndroid Build Coastguard Worker
8078*da0073e9SAndroid Build Coastguard Worker    def test_type_call_in_script(self):
8079*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
8080*da0073e9SAndroid Build Coastguard Worker        def fn(x):
8081*da0073e9SAndroid Build Coastguard Worker            return type(x)
8082*da0073e9SAndroid Build Coastguard Worker
8083*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "value of type _TensorMeta"):
8084*da0073e9SAndroid Build Coastguard Worker            fn(torch.tensor(.5))
8085*da0073e9SAndroid Build Coastguard Worker
8086*da0073e9SAndroid Build Coastguard Worker    def test_python_call_annotation(self):
8087*da0073e9SAndroid Build Coastguard Worker        def pyfunc(a):
8088*da0073e9SAndroid Build Coastguard Worker            return a * 3.0
8089*da0073e9SAndroid Build Coastguard Worker
8090*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
8091*da0073e9SAndroid Build Coastguard Worker        def foo(a):
8092*da0073e9SAndroid Build Coastguard Worker            return pyfunc(a) + pyfunc(a)
8093*da0073e9SAndroid Build Coastguard Worker
8094*da0073e9SAndroid Build Coastguard Worker        inputs = self._make_scalar_vars([1], torch.float)
8095*da0073e9SAndroid Build Coastguard Worker        outputs = self._make_scalar_vars([6], torch.float)
8096*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(*inputs), outputs[0])
8097*da0073e9SAndroid Build Coastguard Worker
8098*da0073e9SAndroid Build Coastguard Worker    def test_python_call_annoytation_failure(self):
8099*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "undefined value pyfunc2"):
8100*da0073e9SAndroid Build Coastguard Worker            def pyfunc(a):
8101*da0073e9SAndroid Build Coastguard Worker                return a * 3.0
8102*da0073e9SAndroid Build Coastguard Worker
8103*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
8104*da0073e9SAndroid Build Coastguard Worker            def foo(a):
8105*da0073e9SAndroid Build Coastguard Worker                return pyfunc2(a) + pyfunc(a)  # noqa: F821
8106*da0073e9SAndroid Build Coastguard Worker
8107*da0073e9SAndroid Build Coastguard Worker            inputs = self._make_scalar_vars([1], torch.float)
8108*da0073e9SAndroid Build Coastguard Worker            outputs = self._make_scalar_vars([6], torch.float)
8109*da0073e9SAndroid Build Coastguard Worker
8110*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(foo(*inputs), outputs[0])
8111*da0073e9SAndroid Build Coastguard Worker
8112*da0073e9SAndroid Build Coastguard Worker    def test_desugar_module(self):
8113*da0073e9SAndroid Build Coastguard Worker        import torch.nn.functional as F
8114*da0073e9SAndroid Build Coastguard Worker
8115*da0073e9SAndroid Build Coastguard Worker        def fn(x, slope):
8116*da0073e9SAndroid Build Coastguard Worker            a = torch.abs(x)
8117*da0073e9SAndroid Build Coastguard Worker            b = torch.nn.functional.prelu(x, slope)
8118*da0073e9SAndroid Build Coastguard Worker            c = F.prelu(x, slope)
8119*da0073e9SAndroid Build Coastguard Worker            return a, b, c
8120*da0073e9SAndroid Build Coastguard Worker
8121*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(-3., 4)
8122*da0073e9SAndroid Build Coastguard Worker        slope = torch.tensor([0.5])
8123*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, [x, slope], optimize=True)
8124*da0073e9SAndroid Build Coastguard Worker
8125*da0073e9SAndroid Build Coastguard Worker    def test_script_docstring(self):
8126*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
8127*da0073e9SAndroid Build Coastguard Worker        def with_docstring(x):
8128*da0073e9SAndroid Build Coastguard Worker            """test str"""
8129*da0073e9SAndroid Build Coastguard Worker            y = x
8130*da0073e9SAndroid Build Coastguard Worker            """y is the same as x"""
8131*da0073e9SAndroid Build Coastguard Worker            return y
8132*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(with_docstring.__doc__, 'test str')
8133*da0073e9SAndroid Build Coastguard Worker
8134*da0073e9SAndroid Build Coastguard Worker    def test_script_method_docstring(self):
8135*da0073e9SAndroid Build Coastguard Worker        class A(torch.jit.ScriptModule):
8136*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
8137*da0073e9SAndroid Build Coastguard Worker            def with_docstring(self, x):
8138*da0073e9SAndroid Build Coastguard Worker                """test str"""
8139*da0073e9SAndroid Build Coastguard Worker                y = x
8140*da0073e9SAndroid Build Coastguard Worker                """y is the same as x"""
8141*da0073e9SAndroid Build Coastguard Worker                return y
8142*da0073e9SAndroid Build Coastguard Worker        a = A()
8143*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.with_docstring.__doc__, 'test str')
8144*da0073e9SAndroid Build Coastguard Worker
    def test_script_module(self):
        # End-to-end check of a ScriptModule (M2) that combines: a scripted
        # submodule (M1), a plain nn.Module submodule (PModule), its own
        # parameters, a method defined from a source string via self.define,
        # and script_methods that call one another. The output is compared
        # against an eager recomputation, then every parameter is zeroed and
        # the output must become all zeros.

        # Scripted submodule: forward(thing) == weight + thing.
        class M1(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.weight = nn.Parameter(torch.randn(2))

            @torch.jit.script_method
            def forward(self, thing):
                return self.weight + thing

        # Plain (non-script) submodule: forward(a) == self.a.mm(a).
        class PModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = nn.Parameter(torch.randn(2, 3))

            def forward(self, a):
                return self.a.mm(a)

        class M2(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                # test submodule
                self.sub = M1()
                self.sub2 = PModule()
                # test parameters
                self.weight = nn.Parameter(torch.randn(2, 3))
                self.bias = nn.Parameter(torch.randn(2))
                # test defining a method from a string; `hi` computes
                # weight.mm(a), the same as doit/doit2 below
                self.define("""
                    def hi(self, a):
                        return self.weight.mm(a)
                """)

            # test script methods
            @torch.jit.script_method
            def doit(self, input):
                # test use of parameter
                return self.weight.mm(input)

            @torch.jit.script_method
            def doit2(self, input):
                return self.weight.mm(input)

            # Calls the other script methods, the string-defined `hi`, and both
            # submodules (the scripted M1 and the plain PModule).
            @torch.jit.script_method
            def forward(self, input):
                a = self.doit(input)
                b = self.doit2(input)
                c = self.hi(input)
                d = self.sub2(input)
                return a + b + self.bias + self.sub(a) + c + d
        # NOTE(review): optimized execution is switched off here, presumably to
        # exercise the unoptimized execution path — confirm intent.
        with torch.jit.optimized_execution(False):
            m2 = M2()
            # (3, 2) input so that weight (2, 3) .mm(input) yields a (2, 2) result.
            input = torch.randn(3, 2)
            # Eager recomputation of M2.forward: doit, doit2 and hi are all
            # weight.mm(input); self.sub(a) == sub.weight + a; sub2(input) is
            # sub2.a.mm(input).
            a = m2.weight.mm(input)
            b = m2.weight.mm(input)
            c = m2.weight.mm(input)
            d = m2.sub2.a.mm(input)
            ref = a + b + m2.bias + m2.sub.weight + a + c + d
            self.assertEqual(ref, m2.forward(input))
            # Zero every parameter: three are replaced with new Parameters and
            # sub2's is zeroed in place through .data. The compiled forward must
            # observe all of these updates, making the result all zeros.
            m2.weight = nn.Parameter(torch.zeros_like(m2.weight))
            m2.bias = nn.Parameter(torch.zeros_like(m2.bias))
            m2.sub.weight = nn.Parameter(torch.zeros_like(m2.sub.weight))
            m2.sub2.a.data.zero_()
            self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2)))
8209*da0073e9SAndroid Build Coastguard Worker
8210*da0073e9SAndroid Build Coastguard Worker    def test_irparser(self):
8211*da0073e9SAndroid Build Coastguard Worker        graph_str = """graph(%0 : Double(5, 5)):
8212*da0073e9SAndroid Build Coastguard Worker          # CHECK: aten::relu
8213*da0073e9SAndroid Build Coastguard Worker          %1 : Double(5, 5) = aten::relu(%0)
8214*da0073e9SAndroid Build Coastguard Worker          return (%1)
8215*da0073e9SAndroid Build Coastguard Worker        """
8216*da0073e9SAndroid Build Coastguard Worker        FileCheck().run(graph_str, parse_ir(graph_str))
8217*da0073e9SAndroid Build Coastguard Worker
8218*da0073e9SAndroid Build Coastguard Worker    def test_parse_tensor_constants(self):
8219*da0073e9SAndroid Build Coastguard Worker        def foo():
8220*da0073e9SAndroid Build Coastguard Worker            return torch.zeros([4, 4])
8221*da0073e9SAndroid Build Coastguard Worker
8222*da0073e9SAndroid Build Coastguard Worker        foo_s = torch.jit.script(foo)
8223*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_constant_propagation(foo_s.graph)
8224*da0073e9SAndroid Build Coastguard Worker
8225*da0073e9SAndroid Build Coastguard Worker        g = str(foo_s.graph)
8226*da0073e9SAndroid Build Coastguard Worker        g_parsed = parse_ir(g, parse_tensor_constants=True)
8227*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(str(canonical(g_parsed)), str(canonical(foo_s.graph)))
8228*da0073e9SAndroid Build Coastguard Worker        func = torch._C._create_function_from_graph("forward", g_parsed)
8229*da0073e9SAndroid Build Coastguard Worker
8230*da0073e9SAndroid Build Coastguard Worker        out_parsed = func()
8231*da0073e9SAndroid Build Coastguard Worker        out_func = foo()
8232*da0073e9SAndroid Build Coastguard Worker        # not checking data, just dtype, size etc
8233*da0073e9SAndroid Build Coastguard Worker        out_parsed[:] = 0
8234*da0073e9SAndroid Build Coastguard Worker        out_func[:] = 0
8235*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_func, out_parsed)
8236*da0073e9SAndroid Build Coastguard Worker
8237*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
8238*da0073e9SAndroid Build Coastguard Worker            parse_ir(g, parse_tensor_constants=False)
8239*da0073e9SAndroid Build Coastguard Worker
8240*da0073e9SAndroid Build Coastguard Worker    def test_parse_nested_names(self):
8241*da0073e9SAndroid Build Coastguard Worker        g_str = """
8242*da0073e9SAndroid Build Coastguard Worker    graph(%x.1 : Tensor):
8243*da0073e9SAndroid Build Coastguard Worker        %3 : int = prim::Constant[value=1]()
8244*da0073e9SAndroid Build Coastguard Worker        %2 : int = prim::Constant[value=2]()
8245*da0073e9SAndroid Build Coastguard Worker        %hi.submod.value.5 : Tensor = aten::add(%x.1, %2, %3)
8246*da0073e9SAndroid Build Coastguard Worker        return (%hi.submod.value.5)
8247*da0073e9SAndroid Build Coastguard Worker        """
8248*da0073e9SAndroid Build Coastguard Worker        g = parse_ir(g_str)
8249*da0073e9SAndroid Build Coastguard Worker        round_trip_g = parse_ir(str(g))
8250*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(canonical(g), canonical(round_trip_g))
8251*da0073e9SAndroid Build Coastguard Worker
8252*da0073e9SAndroid Build Coastguard Worker        func1 = torch._C._create_function_from_graph("forward", g)
8253*da0073e9SAndroid Build Coastguard Worker        func2 = torch._C._create_function_from_graph("forward", round_trip_g)
8254*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(func1(torch.ones([2])), func2(torch.ones([2])))
8255*da0073e9SAndroid Build Coastguard Worker
8256*da0073e9SAndroid Build Coastguard Worker    def test_is_after_use(self):
8257*da0073e9SAndroid Build Coastguard Worker        def sorted_input_use(g):
8258*da0073e9SAndroid Build Coastguard Worker            uses = list(next(g.inputs()).uses())
8259*da0073e9SAndroid Build Coastguard Worker            return sorted(uses, key=functools.cmp_to_key(type(uses[0]).isAfter))
8260*da0073e9SAndroid Build Coastguard Worker
8261*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
8262*da0073e9SAndroid Build Coastguard Worker        def foo(x):
8263*da0073e9SAndroid Build Coastguard Worker            a = x + 1
8264*da0073e9SAndroid Build Coastguard Worker            return (x, x, a)
8265*da0073e9SAndroid Build Coastguard Worker
8266*da0073e9SAndroid Build Coastguard Worker        uses_sorted = sorted_input_use(foo.graph)
8267*da0073e9SAndroid Build Coastguard Worker        # sorts last use to the end
8268*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(uses_sorted[0].isAfter(uses_sorted[1]))
8269*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(uses_sorted[0].user.kind() == "aten::add")
8270*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(uses_sorted[1].offset, 0)
8271*da0073e9SAndroid Build Coastguard Worker
8272*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
8273*da0073e9SAndroid Build Coastguard Worker        def foo(x, cond: bool):
8274*da0073e9SAndroid Build Coastguard Worker            if cond:
8275*da0073e9SAndroid Build Coastguard Worker                return x + 3
8276*da0073e9SAndroid Build Coastguard Worker            else:
8277*da0073e9SAndroid Build Coastguard Worker                return x - 3
8278*da0073e9SAndroid Build Coastguard Worker
8279*da0073e9SAndroid Build Coastguard Worker        uses_sorted = sorted_input_use(foo.graph)
8280*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(uses_sorted[0].user.kind() == "aten::add")
8281*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(uses_sorted[1].user.kind() == "aten::sub")
8282*da0073e9SAndroid Build Coastguard Worker
8283*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
8284*da0073e9SAndroid Build Coastguard Worker        def foo(x, cond: bool, cond2: bool):
8285*da0073e9SAndroid Build Coastguard Worker            if cond:
8286*da0073e9SAndroid Build Coastguard Worker                return x + 3
8287*da0073e9SAndroid Build Coastguard Worker            elif cond2 :
8288*da0073e9SAndroid Build Coastguard Worker                return x - 3
8289*da0073e9SAndroid Build Coastguard Worker
8290*da0073e9SAndroid Build Coastguard Worker            return x / 3
8291*da0073e9SAndroid Build Coastguard Worker
8292*da0073e9SAndroid Build Coastguard Worker        graph1 = foo.graph
8293*da0073e9SAndroid Build Coastguard Worker
8294*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
8295*da0073e9SAndroid Build Coastguard Worker        def foo(x, cond: bool, cond2: bool):
8296*da0073e9SAndroid Build Coastguard Worker            if cond:
8297*da0073e9SAndroid Build Coastguard Worker                return x + 3
8298*da0073e9SAndroid Build Coastguard Worker            else:
8299*da0073e9SAndroid Build Coastguard Worker                if cond2 :
8300*da0073e9SAndroid Build Coastguard Worker                    return x - 3
8301*da0073e9SAndroid Build Coastguard Worker                return x / 3
8302*da0073e9SAndroid Build Coastguard Worker
8303*da0073e9SAndroid Build Coastguard Worker        graph2 = foo.graph
8304*da0073e9SAndroid Build Coastguard Worker
8305*da0073e9SAndroid Build Coastguard Worker        for graph in [graph1, graph2]:
8306*da0073e9SAndroid Build Coastguard Worker            uses_sorted = sorted_input_use(graph)
8307*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(uses_sorted[0].user.kind() == "aten::add")
8308*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(uses_sorted[1].user.kind() == "aten::sub")
8309*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(uses_sorted[2].user.kind() == "aten::div")
8310*da0073e9SAndroid Build Coastguard Worker
8311*da0073e9SAndroid Build Coastguard Worker    def test_canonicalize_control_outputs(self):
8312*da0073e9SAndroid Build Coastguard Worker        def test_all_outputs(g):
8313*da0073e9SAndroid Build Coastguard Worker            ifs = g.findAllNodes("prim::If")
8314*da0073e9SAndroid Build Coastguard Worker            loops = g.findAllNodes("prim::Loop")
8315*da0073e9SAndroid Build Coastguard Worker
8316*da0073e9SAndroid Build Coastguard Worker            def contained_blocks(node):
8317*da0073e9SAndroid Build Coastguard Worker                return len(node.findAllNodes("prim::If")) * 2 + len(node.findAllNodes("prim::Loop"))
8318*da0073e9SAndroid Build Coastguard Worker            for node in ifs + loops:
8319*da0073e9SAndroid Build Coastguard Worker                outs = list(node.outputs())
8320*da0073e9SAndroid Build Coastguard Worker                out_name = [x.debugName() for x in outs]
8321*da0073e9SAndroid Build Coastguard Worker                if len(out_name) == 0:
8322*da0073e9SAndroid Build Coastguard Worker                    continue
8323*da0073e9SAndroid Build Coastguard Worker                fc = FileCheck()
8324*da0073e9SAndroid Build Coastguard Worker                # find the last output, then all subsequent uses
8325*da0073e9SAndroid Build Coastguard Worker                fc.check(out_name[-1] + " : ")
8326*da0073e9SAndroid Build Coastguard Worker                # skip past node body
8327*da0073e9SAndroid Build Coastguard Worker                for i in range(contained_blocks(node)):
8328*da0073e9SAndroid Build Coastguard Worker                    fc.check("->")
8329*da0073e9SAndroid Build Coastguard Worker                if (node.kind() == "prim::If"):
8330*da0073e9SAndroid Build Coastguard Worker                    fc.check("->").check("->").check("\n")
8331*da0073e9SAndroid Build Coastguard Worker                else:
8332*da0073e9SAndroid Build Coastguard Worker                    fc.check("->").check("\n")
8333*da0073e9SAndroid Build Coastguard Worker                # the canonical order is the same order as the first use
8334*da0073e9SAndroid Build Coastguard Worker                # appears in text
8335*da0073e9SAndroid Build Coastguard Worker                for name in out_name:
8336*da0073e9SAndroid Build Coastguard Worker                    fc.check(name)
8337*da0073e9SAndroid Build Coastguard Worker                fc.run(g)
8338*da0073e9SAndroid Build Coastguard Worker
8339*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
8340*da0073e9SAndroid Build Coastguard Worker        def test(x):
8341*da0073e9SAndroid Build Coastguard Worker            # type: (bool) -> Tuple[int, int]
8342*da0073e9SAndroid Build Coastguard Worker            b = 2
8343*da0073e9SAndroid Build Coastguard Worker            a = 1
8344*da0073e9SAndroid Build Coastguard Worker            if x:
8345*da0073e9SAndroid Build Coastguard Worker                a = 1
8346*da0073e9SAndroid Build Coastguard Worker                b = 2
8347*da0073e9SAndroid Build Coastguard Worker                x = False
8348*da0073e9SAndroid Build Coastguard Worker            if x:
8349*da0073e9SAndroid Build Coastguard Worker                b = a
8350*da0073e9SAndroid Build Coastguard Worker            else:
8351*da0073e9SAndroid Build Coastguard Worker                a = b
8352*da0073e9SAndroid Build Coastguard Worker
8353*da0073e9SAndroid Build Coastguard Worker            return a, b
8354*da0073e9SAndroid Build Coastguard Worker        test_all_outputs(test.graph)
8355*da0073e9SAndroid Build Coastguard Worker
8356*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
8357*da0073e9SAndroid Build Coastguard Worker        def test2(x):
8358*da0073e9SAndroid Build Coastguard Worker            # type: (bool) -> Tuple[int, int]
8359*da0073e9SAndroid Build Coastguard Worker            b = 2
8360*da0073e9SAndroid Build Coastguard Worker            a = 1
8361*da0073e9SAndroid Build Coastguard Worker            if x:
8362*da0073e9SAndroid Build Coastguard Worker                a = 1
8363*da0073e9SAndroid Build Coastguard Worker                b = 2
8364*da0073e9SAndroid Build Coastguard Worker                x = False
8365*da0073e9SAndroid Build Coastguard Worker            if x:
8366*da0073e9SAndroid Build Coastguard Worker                print(a)
8367*da0073e9SAndroid Build Coastguard Worker            else:
8368*da0073e9SAndroid Build Coastguard Worker                if x:
8369*da0073e9SAndroid Build Coastguard Worker                    print(b)
8370*da0073e9SAndroid Build Coastguard Worker
8371*da0073e9SAndroid Build Coastguard Worker            return a, b
8372*da0073e9SAndroid Build Coastguard Worker        test_all_outputs(test2.graph)
8373*da0073e9SAndroid Build Coastguard Worker
8374*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
8375*da0073e9SAndroid Build Coastguard Worker        def test_loop(x, iter):
8376*da0073e9SAndroid Build Coastguard Worker            # type: (bool, int) -> (None)
8377*da0073e9SAndroid Build Coastguard Worker            a = 1
8378*da0073e9SAndroid Build Coastguard Worker            b = 2
8379*da0073e9SAndroid Build Coastguard Worker            c = 3
8380*da0073e9SAndroid Build Coastguard Worker            for i in range(iter):
8381*da0073e9SAndroid Build Coastguard Worker                a = 4
8382*da0073e9SAndroid Build Coastguard Worker                b = 5
8383*da0073e9SAndroid Build Coastguard Worker                c = 6
8384*da0073e9SAndroid Build Coastguard Worker                x = True
8385*da0073e9SAndroid Build Coastguard Worker            print(c)
8386*da0073e9SAndroid Build Coastguard Worker            if x:
8387*da0073e9SAndroid Build Coastguard Worker                print(a, b)
8388*da0073e9SAndroid Build Coastguard Worker        test_all_outputs(test_loop.graph)
8389*da0073e9SAndroid Build Coastguard Worker
8390*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
8391*da0073e9SAndroid Build Coastguard Worker        def loop_unused(iter):
8392*da0073e9SAndroid Build Coastguard Worker            # type: (int) -> (None)
8393*da0073e9SAndroid Build Coastguard Worker            a = 1
8394*da0073e9SAndroid Build Coastguard Worker            b = 2
8395*da0073e9SAndroid Build Coastguard Worker            c = 3
8396*da0073e9SAndroid Build Coastguard Worker            for i in range(iter):
8397*da0073e9SAndroid Build Coastguard Worker                c = c + 1
8398*da0073e9SAndroid Build Coastguard Worker                b = b + 1
8399*da0073e9SAndroid Build Coastguard Worker                a = a + 1
8400*da0073e9SAndroid Build Coastguard Worker                print(a, b)
8401*da0073e9SAndroid Build Coastguard Worker            print(c)
8402*da0073e9SAndroid Build Coastguard Worker
8403*da0073e9SAndroid Build Coastguard Worker        # c is used, then unused should be ordered by alphabetical
8404*da0073e9SAndroid Build Coastguard Worker        FileCheck().check(r"%c : int, %a : int, %b : int").run(loop_unused.graph)
8405*da0073e9SAndroid Build Coastguard Worker
8406*da0073e9SAndroid Build Coastguard Worker    def test_filecheck(self):
8407*da0073e9SAndroid Build Coastguard Worker        def test_check():
8408*da0073e9SAndroid Build Coastguard Worker            file = "232"
8409*da0073e9SAndroid Build Coastguard Worker            FileCheck().check("2").check("3").check("2").run(file)
8410*da0073e9SAndroid Build Coastguard Worker            FileCheck().check("232").run(file)
8411*da0073e9SAndroid Build Coastguard Worker
8412*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'):
8413*da0073e9SAndroid Build Coastguard Worker                FileCheck().check("22").run(file)
8414*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "CHECK: 3"):
8415*da0073e9SAndroid Build Coastguard Worker                FileCheck().check("3").check("3").run(file)
8416*da0073e9SAndroid Build Coastguard Worker
8417*da0073e9SAndroid Build Coastguard Worker        test_check()
8418*da0073e9SAndroid Build Coastguard Worker
8419*da0073e9SAndroid Build Coastguard Worker        def test_check_count():
8420*da0073e9SAndroid Build Coastguard Worker            file = "22222"
8421*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_count("2", 5).run(file)
8422*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_count("22", 2).run(file)
8423*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_count("222", 1).run(file)
8424*da0073e9SAndroid Build Coastguard Worker
8425*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, 'Expected to not find'):
8426*da0073e9SAndroid Build Coastguard Worker                FileCheck().check_count("2", 4, exactly=True).run(file)
8427*da0073e9SAndroid Build Coastguard Worker
8428*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'):
8429*da0073e9SAndroid Build Coastguard Worker                FileCheck().check_count("22", 3).run(file)
8430*da0073e9SAndroid Build Coastguard Worker
8431*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "CHECK-COUNT-6: 2"):
8432*da0073e9SAndroid Build Coastguard Worker                FileCheck().check_count("2", 6).run(file)
8433*da0073e9SAndroid Build Coastguard Worker
8434*da0073e9SAndroid Build Coastguard Worker        test_check_count()
8435*da0073e9SAndroid Build Coastguard Worker
8436*da0073e9SAndroid Build Coastguard Worker        def test_check_same():
8437*da0073e9SAndroid Build Coastguard Worker            file = "22\n33"
8438*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_same("22").run(file)
8439*da0073e9SAndroid Build Coastguard Worker
8440*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Expected to not find"):
8441*da0073e9SAndroid Build Coastguard Worker                FileCheck().check_same("33").run(file)
8442*da0073e9SAndroid Build Coastguard Worker
8443*da0073e9SAndroid Build Coastguard Worker            file = "22  1  3"
8444*da0073e9SAndroid Build Coastguard Worker
8445*da0073e9SAndroid Build Coastguard Worker            FileCheck().check("2").check_same("3").run(file)
8446*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_count("2", 2).check_same("3").run(file)
8447*da0073e9SAndroid Build Coastguard Worker
8448*da0073e9SAndroid Build Coastguard Worker        test_check_same()
8449*da0073e9SAndroid Build Coastguard Worker
8450*da0073e9SAndroid Build Coastguard Worker        def test_check_next():
8451*da0073e9SAndroid Build Coastguard Worker            file = "\n1\n2\n3"
8452*da0073e9SAndroid Build Coastguard Worker            FileCheck().check("1").check_next("2").check_next("3").run(file)
8453*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_next("1").check_next("2").check_next("3").run(file)
8454*da0073e9SAndroid Build Coastguard Worker
8455*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Expected to find"):
8456*da0073e9SAndroid Build Coastguard Worker                FileCheck().check("1").check_next("2").run("12")
8457*da0073e9SAndroid Build Coastguard Worker
8458*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Expected to not find"):
8459*da0073e9SAndroid Build Coastguard Worker                FileCheck().check("1").check_next("2").run("1\n\n2")
8460*da0073e9SAndroid Build Coastguard Worker
8461*da0073e9SAndroid Build Coastguard Worker        test_check_next()
8462*da0073e9SAndroid Build Coastguard Worker
8463*da0073e9SAndroid Build Coastguard Worker        def test_check_dag():
8464*da0073e9SAndroid Build Coastguard Worker            fc = FileCheck().check_dag("1").check_dag("2").check_not("2")
8465*da0073e9SAndroid Build Coastguard Worker            fc.run("12")
8466*da0073e9SAndroid Build Coastguard Worker            fc.run("21")
8467*da0073e9SAndroid Build Coastguard Worker
8468*da0073e9SAndroid Build Coastguard Worker            fc = FileCheck()
8469*da0073e9SAndroid Build Coastguard Worker            fc.check_not("3").check_dag("1").check_dag("2").check_not("3")
8470*da0073e9SAndroid Build Coastguard Worker            fc.run("1 3 2")
8471*da0073e9SAndroid Build Coastguard Worker            fc.run("2 3 1")
8472*da0073e9SAndroid Build Coastguard Worker
8473*da0073e9SAndroid Build Coastguard Worker            fc = FileCheck().check_dag("1").check_dag("2").check("3")
8474*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, 'Expected to find "3" but did not find it'):
8475*da0073e9SAndroid Build Coastguard Worker                fc.run("1 3 2")
8476*da0073e9SAndroid Build Coastguard Worker
8477*da0073e9SAndroid Build Coastguard Worker        test_check_dag()
8478*da0073e9SAndroid Build Coastguard Worker
8479*da0073e9SAndroid Build Coastguard Worker        def test_check_not():
8480*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_not("2").check("1").run("12")
8481*da0073e9SAndroid Build Coastguard Worker            FileCheck().check("2").check_not("2").run("12")
8482*da0073e9SAndroid Build Coastguard Worker
8483*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, 'Expected to not find "2"'):
8484*da0073e9SAndroid Build Coastguard Worker                FileCheck().check_not("2").check("1").run("21")
8485*da0073e9SAndroid Build Coastguard Worker
8486*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'):
8487*da0073e9SAndroid Build Coastguard Worker                FileCheck().check("2").check_not("1").run("21")
8488*da0073e9SAndroid Build Coastguard Worker
8489*da0073e9SAndroid Build Coastguard Worker            # checks with distinct range matchings
8490*da0073e9SAndroid Build Coastguard Worker            fb = FileCheck().check_count("2", 2).check_count("2", 2).check_not("2")
8491*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, 'Expected to not find "2"'):
8492*da0073e9SAndroid Build Coastguard Worker                fb.run("22 2 22")
8493*da0073e9SAndroid Build Coastguard Worker
8494*da0073e9SAndroid Build Coastguard Worker            fb = FileCheck().check_count("2", 2).check_not("1").check_count("2", 2)
8495*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'):
8496*da0073e9SAndroid Build Coastguard Worker                fb.run("22 1 22")
8497*da0073e9SAndroid Build Coastguard Worker
8498*da0073e9SAndroid Build Coastguard Worker    def _dtype_to_jit_name(self, dtype):
8499*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.float32:
8500*da0073e9SAndroid Build Coastguard Worker            return "Float"
8501*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.float64:
8502*da0073e9SAndroid Build Coastguard Worker            return "Double"
8503*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.int64:
8504*da0073e9SAndroid Build Coastguard Worker            return "Long"
8505*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.int32:
8506*da0073e9SAndroid Build Coastguard Worker            return "Int"
8507*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.bool:
8508*da0073e9SAndroid Build Coastguard Worker            return "Bool"
8509*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError('dtype not handled')
8510*da0073e9SAndroid Build Coastguard Worker
8511*da0073e9SAndroid Build Coastguard Worker    def _dtype_to_expect(self, dtype, dim=0):
8512*da0073e9SAndroid Build Coastguard Worker        param = ', '.join(['*'] * dim + ['device=cpu'])
8513*da0073e9SAndroid Build Coastguard Worker        param = '(' + param + ')'
8514*da0073e9SAndroid Build Coastguard Worker        jit_type = self._dtype_to_jit_name(dtype)
8515*da0073e9SAndroid Build Coastguard Worker        if dim >= 0:
8516*da0073e9SAndroid Build Coastguard Worker            return jit_type + param
8517*da0073e9SAndroid Build Coastguard Worker        # special case representing wrapped number
8518*da0073e9SAndroid Build Coastguard Worker        else:
8519*da0073e9SAndroid Build Coastguard Worker            return jit_type.lower()
8520*da0073e9SAndroid Build Coastguard Worker
8521*da0073e9SAndroid Build Coastguard Worker
8522*da0073e9SAndroid Build Coastguard Worker    def _test_dtype_op_shape(self, ops, args, input_dims=1):
8523*da0073e9SAndroid Build Coastguard Worker        if input_dims < 1:
8524*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError("input dims must be at least 1")
8525*da0073e9SAndroid Build Coastguard Worker        dtypes = [torch.float32, torch.float64, torch.int64, torch.int32]
8526*da0073e9SAndroid Build Coastguard Worker        str_args = ', '.join([str(arg) for arg in args]) + (', ' if len(args) else '')
8527*da0073e9SAndroid Build Coastguard Worker        tensor_data = ('[' * input_dims) + '1, 2, 3' + (input_dims * ']')
8528*da0073e9SAndroid Build Coastguard Worker        template = dedent('''
8529*da0073e9SAndroid Build Coastguard Worker        def func():
8530*da0073e9SAndroid Build Coastguard Worker            return {return_line}
8531*da0073e9SAndroid Build Coastguard Worker        ''')
8532*da0073e9SAndroid Build Coastguard Worker
8533*da0073e9SAndroid Build Coastguard Worker        for op in ops:
8534*da0073e9SAndroid Build Coastguard Worker            for dtype in (dtypes + [None]):
8535*da0073e9SAndroid Build Coastguard Worker                for tensor_type in dtypes:
8536*da0073e9SAndroid Build Coastguard Worker                    # a couple of ops aren't implemented for non-floating types
8537*da0073e9SAndroid Build Coastguard Worker                    if not tensor_type.is_floating_point or (dtype is not None and not dtype.is_floating_point):
8538*da0073e9SAndroid Build Coastguard Worker                        if op in ['mean', 'softmax', 'log_softmax']:
8539*da0073e9SAndroid Build Coastguard Worker                            continue
8540*da0073e9SAndroid Build Coastguard Worker                    return_line = f"torch.tensor({tensor_data}, dtype={tensor_type}).{op}({str_args}dtype={dtype})"
8541*da0073e9SAndroid Build Coastguard Worker                    # uncomment for debugging a failed test:
8542*da0073e9SAndroid Build Coastguard Worker                    # print("testing {}".format(return_line))
8543*da0073e9SAndroid Build Coastguard Worker                    code = template.format(return_line=return_line)
8544*da0073e9SAndroid Build Coastguard Worker                    scope = {}
8545*da0073e9SAndroid Build Coastguard Worker                    exec(code, globals(), scope)
8546*da0073e9SAndroid Build Coastguard Worker                    cu = torch.jit.CompilationUnit(code)
8547*da0073e9SAndroid Build Coastguard Worker                    graph = cu.func.graph
8548*da0073e9SAndroid Build Coastguard Worker                    torch._C._jit_pass_complete_shape_analysis(graph, (), False)
8549*da0073e9SAndroid Build Coastguard Worker                    input_array = [1, 2, 3]
8550*da0073e9SAndroid Build Coastguard Worker                    for _ in range(1, input_dims):
8551*da0073e9SAndroid Build Coastguard Worker                        input_array = [input_array]
8552*da0073e9SAndroid Build Coastguard Worker                    t = torch.tensor(input_array, dtype=tensor_type)
8553*da0073e9SAndroid Build Coastguard Worker                    attr = getattr(t, op)
8554*da0073e9SAndroid Build Coastguard Worker                    kwargs = {'dtype': dtype}
8555*da0073e9SAndroid Build Coastguard Worker                    result = attr(*args, **kwargs)
8556*da0073e9SAndroid Build Coastguard Worker                    expect = self._dtype_to_expect(result.dtype, result.dim())
8557*da0073e9SAndroid Build Coastguard Worker                    FileCheck().check("aten::tensor").check(expect).run(graph)
8558*da0073e9SAndroid Build Coastguard Worker
8559*da0073e9SAndroid Build Coastguard Worker    def test_dtype_op_shape(self):
8560*da0073e9SAndroid Build Coastguard Worker        ops = ['prod']
8561*da0073e9SAndroid Build Coastguard Worker        self._test_dtype_op_shape(ops, args=[])
8562*da0073e9SAndroid Build Coastguard Worker        self._test_dtype_op_shape(ops, args=[0, False])
8563*da0073e9SAndroid Build Coastguard Worker        self._test_dtype_op_shape(ops, args=[0, False])
8564*da0073e9SAndroid Build Coastguard Worker        self._test_dtype_op_shape(ops, args=[0, True])
8565*da0073e9SAndroid Build Coastguard Worker
8566*da0073e9SAndroid Build Coastguard Worker    def test_dtype_op_shape2(self):
8567*da0073e9SAndroid Build Coastguard Worker        ops = ['cumprod', 'cumsum', 'softmax', 'log_softmax']
8568*da0073e9SAndroid Build Coastguard Worker        self._test_dtype_op_shape(ops, args=[0])
8569*da0073e9SAndroid Build Coastguard Worker
8570*da0073e9SAndroid Build Coastguard Worker        self._test_dtype_op_shape(ops, args=[1], input_dims=4)
8571*da0073e9SAndroid Build Coastguard Worker
8572*da0073e9SAndroid Build Coastguard Worker
8573*da0073e9SAndroid Build Coastguard Worker    def _test_binary_op_shape(self, ops, input_dims=1):
8574*da0073e9SAndroid Build Coastguard Worker
8575*da0073e9SAndroid Build Coastguard Worker        dtypes = [torch.float32, torch.float64, torch.int64, torch.int32, torch.bool]
8576*da0073e9SAndroid Build Coastguard Worker
8577*da0073e9SAndroid Build Coastguard Worker        if input_dims == 0:
8578*da0073e9SAndroid Build Coastguard Worker            shape = '1'
8579*da0073e9SAndroid Build Coastguard Worker        else:
8580*da0073e9SAndroid Build Coastguard Worker            shape = '[' + ('1,' * 4) + ']'
8581*da0073e9SAndroid Build Coastguard Worker            for _ in range(1, input_dims):
8582*da0073e9SAndroid Build Coastguard Worker                shape = '[' + ",".join([shape] * 4) + ']'
8583*da0073e9SAndroid Build Coastguard Worker
8584*da0073e9SAndroid Build Coastguard Worker        template = dedent('''
8585*da0073e9SAndroid Build Coastguard Worker        def func():
8586*da0073e9SAndroid Build Coastguard Worker            arg1 = {}
8587*da0073e9SAndroid Build Coastguard Worker            arg2 = {}
8588*da0073e9SAndroid Build Coastguard Worker            return torch.{}(arg1, arg2)
8589*da0073e9SAndroid Build Coastguard Worker        ''')
8590*da0073e9SAndroid Build Coastguard Worker
8591*da0073e9SAndroid Build Coastguard Worker        args = []
8592*da0073e9SAndroid Build Coastguard Worker        for dtype in dtypes:
8593*da0073e9SAndroid Build Coastguard Worker            args = args + [f"torch.tensor({shape}, dtype={dtype})"]
8594*da0073e9SAndroid Build Coastguard Worker        args = args + [1, 1.5]
8595*da0073e9SAndroid Build Coastguard Worker
8596*da0073e9SAndroid Build Coastguard Worker        def isBool(arg):
8597*da0073e9SAndroid Build Coastguard Worker            return type(arg) == bool or (type(arg) == str and "torch.bool" in arg)
8598*da0073e9SAndroid Build Coastguard Worker
8599*da0073e9SAndroid Build Coastguard Worker        for op in ops:
8600*da0073e9SAndroid Build Coastguard Worker            for first_arg in args:
8601*da0073e9SAndroid Build Coastguard Worker                for second_arg in args:
8602*da0073e9SAndroid Build Coastguard Worker                    # subtract not supported for bool
8603*da0073e9SAndroid Build Coastguard Worker                    if (op == 'sub' or op == 'div') and (isBool(first_arg) or isBool(second_arg)):
8604*da0073e9SAndroid Build Coastguard Worker                        continue
8605*da0073e9SAndroid Build Coastguard Worker                    # div is not implemented correctly for mixed-type or int params
8606*da0073e9SAndroid Build Coastguard Worker                    if (op == 'div' and (type(first_arg) != type(second_arg) or
8607*da0073e9SAndroid Build Coastguard Worker                       isinstance(first_arg, int) or
8608*da0073e9SAndroid Build Coastguard Worker                       (isinstance(first_arg, str) and 'int' in first_arg))):
8609*da0073e9SAndroid Build Coastguard Worker                        continue
8610*da0073e9SAndroid Build Coastguard Worker                    return_line = f"torch.{op}({first_arg}, {second_arg})"
8611*da0073e9SAndroid Build Coastguard Worker                    # uncomment for debugging a failed test:
8612*da0073e9SAndroid Build Coastguard Worker                    # print("testing {}".format(return_line))
8613*da0073e9SAndroid Build Coastguard Worker                    code = template.format(first_arg, second_arg, op)
8614*da0073e9SAndroid Build Coastguard Worker                    scope = {}
8615*da0073e9SAndroid Build Coastguard Worker                    exec(code, globals(), scope)
8616*da0073e9SAndroid Build Coastguard Worker                    non_jit_result = scope['func']()
8617*da0073e9SAndroid Build Coastguard Worker
8618*da0073e9SAndroid Build Coastguard Worker                    cu = torch.jit.CompilationUnit(code)
8619*da0073e9SAndroid Build Coastguard Worker                    graph = cu.func.graph
8620*da0073e9SAndroid Build Coastguard Worker                    torch._C._jit_pass_complete_shape_analysis(graph, (), False)
8621*da0073e9SAndroid Build Coastguard Worker                    # use dim=-1 to represent a python/jit scalar.
8622*da0073e9SAndroid Build Coastguard Worker                    dim = -1 if type(first_arg) != str and type(second_arg) != str else non_jit_result.dim()
8623*da0073e9SAndroid Build Coastguard Worker                    dtype = non_jit_result.dtype
8624*da0073e9SAndroid Build Coastguard Worker                    # jit only supports int/float scalars.
8625*da0073e9SAndroid Build Coastguard Worker                    if dim < 0:
8626*da0073e9SAndroid Build Coastguard Worker                        if dtype == torch.int64:
8627*da0073e9SAndroid Build Coastguard Worker                            dtype = torch.int32
8628*da0073e9SAndroid Build Coastguard Worker                        if dtype == torch.float64:
8629*da0073e9SAndroid Build Coastguard Worker                            dtype = torch.float32
8630*da0073e9SAndroid Build Coastguard Worker                    expect = self._dtype_to_expect(dtype, dim)
8631*da0073e9SAndroid Build Coastguard Worker                    jit_output = next(graph.outputs())
8632*da0073e9SAndroid Build Coastguard Worker
8633*da0073e9SAndroid Build Coastguard Worker                    check = FileCheck()
8634*da0073e9SAndroid Build Coastguard Worker                    check.check(expect).run(str(jit_output))
8635*da0073e9SAndroid Build Coastguard Worker
8636*da0073e9SAndroid Build Coastguard Worker    def test_binary_op_shape(self):
8637*da0073e9SAndroid Build Coastguard Worker        self._test_binary_op_shape(['mul', 'div', 'add', 'sub'], 0)
8638*da0073e9SAndroid Build Coastguard Worker        self._test_binary_op_shape(['mul', 'div', 'add', 'sub'], 3)
8639*da0073e9SAndroid Build Coastguard Worker
8640*da0073e9SAndroid Build Coastguard Worker    def test_no_dtype_shape(self):
8641*da0073e9SAndroid Build Coastguard Worker
8642*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
8643*da0073e9SAndroid Build Coastguard Worker        def foo(x):
8644*da0073e9SAndroid Build Coastguard Worker            scalar_number = x.item()
8645*da0073e9SAndroid Build Coastguard Worker            return x.add(scalar_number)
8646*da0073e9SAndroid Build Coastguard Worker
8647*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
8648*da0073e9SAndroid Build Coastguard Worker        def foo2(x):
8649*da0073e9SAndroid Build Coastguard Worker            scalar_number = x.item()
8650*da0073e9SAndroid Build Coastguard Worker            return torch.tensor(1).add(scalar_number)
8651*da0073e9SAndroid Build Coastguard Worker
8652*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor(5)
8653*da0073e9SAndroid Build Coastguard Worker        g = foo.graph_for(t)
8654*da0073e9SAndroid Build Coastguard Worker        type = next(g.outputs())
8655*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(type.type() == torch._C.TensorType.get())
8656*da0073e9SAndroid Build Coastguard Worker        g2 = foo2.graph_for(t)
8657*da0073e9SAndroid Build Coastguard Worker        type = next(g.outputs())
8658*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(type.type() == torch._C.TensorType.get())
8659*da0073e9SAndroid Build Coastguard Worker
8660*da0073e9SAndroid Build Coastguard Worker
8661*da0073e9SAndroid Build Coastguard Worker    def test_filecheck_parse(self):
8662*da0073e9SAndroid Build Coastguard Worker        def test_check():
8663*da0073e9SAndroid Build Coastguard Worker            file = """
8664*da0073e9SAndroid Build Coastguard Worker                # CHECK: 2
8665*da0073e9SAndroid Build Coastguard Worker                # CHECK: 3
8666*da0073e9SAndroid Build Coastguard Worker                # CHECK: 2
8667*da0073e9SAndroid Build Coastguard Worker                232
8668*da0073e9SAndroid Build Coastguard Worker                """
8669*da0073e9SAndroid Build Coastguard Worker            FileCheck().run(checks_file=file, test_file=file)
8670*da0073e9SAndroid Build Coastguard Worker            file = """
8671*da0073e9SAndroid Build Coastguard Worker                # CHECK: 232
8672*da0073e9SAndroid Build Coastguard Worker                232
8673*da0073e9SAndroid Build Coastguard Worker                """
8674*da0073e9SAndroid Build Coastguard Worker            FileCheck().run(file, "232")
8675*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, 'Expected to find "232"'):
8676*da0073e9SAndroid Build Coastguard Worker                FileCheck().run(file, "22")
8677*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'):
8678*da0073e9SAndroid Build Coastguard Worker                FileCheck().run("# CHECK: 22", "23")
8679*da0073e9SAndroid Build Coastguard Worker        test_check()
8680*da0073e9SAndroid Build Coastguard Worker
8681*da0073e9SAndroid Build Coastguard Worker        def test_check_count():
8682*da0073e9SAndroid Build Coastguard Worker            file = "22222"
8683*da0073e9SAndroid Build Coastguard Worker            FileCheck().run("# CHECK-COUNT-5: 2", file)
8684*da0073e9SAndroid Build Coastguard Worker            FileCheck().run("# CHECK-COUNT-EXACTLY-5: 2", file)
8685*da0073e9SAndroid Build Coastguard Worker            FileCheck().run("# CHECK-COUNT-2: 22", file)
8686*da0073e9SAndroid Build Coastguard Worker            FileCheck().run("# CHECK-COUNT-1: 222", file)
8687*da0073e9SAndroid Build Coastguard Worker
8688*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, 'Expected to not find'):
8689*da0073e9SAndroid Build Coastguard Worker                FileCheck().run("# CHECK-COUNT-EXACTLY-2: 2", file)
8690*da0073e9SAndroid Build Coastguard Worker        test_check_count()
8691*da0073e9SAndroid Build Coastguard Worker
8692*da0073e9SAndroid Build Coastguard Worker        def test_check_same():
8693*da0073e9SAndroid Build Coastguard Worker            file = "22\n33"
8694*da0073e9SAndroid Build Coastguard Worker            FileCheck().run("# CHECK-SAME: 22", file)
8695*da0073e9SAndroid Build Coastguard Worker
8696*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Expected to not find"):
8697*da0073e9SAndroid Build Coastguard Worker                FileCheck().run("# CHECK-SAME: 33", file)
8698*da0073e9SAndroid Build Coastguard Worker
8699*da0073e9SAndroid Build Coastguard Worker            file = "22  1  3"
8700*da0073e9SAndroid Build Coastguard Worker
8701*da0073e9SAndroid Build Coastguard Worker            FileCheck().run("# CHECK: 2\n # CHECK-SAME: 3", file)
8702*da0073e9SAndroid Build Coastguard Worker            FileCheck().run("# CHECK-COUNT-2: 2\n # CHECK-SAME: 3", file)
8703*da0073e9SAndroid Build Coastguard Worker        test_check_same()
8704*da0073e9SAndroid Build Coastguard Worker
8705*da0073e9SAndroid Build Coastguard Worker        def test_bad_input():
8706*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Check for bad input"):
8707*da0073e9SAndroid Build Coastguard Worker                FileCheck().run("", "1")
8708*da0073e9SAndroid Build Coastguard Worker
8709*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Could not parse check"):
8710*da0073e9SAndroid Build Coastguard Worker                FileCheck().run("# CHECK1", "")
8711*da0073e9SAndroid Build Coastguard Worker
8712*da0073e9SAndroid Build Coastguard Worker        test_bad_input()
8713*da0073e9SAndroid Build Coastguard Worker
8714*da0073e9SAndroid Build Coastguard Worker    def test_script_module_call_noscript(self):
8715*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
8716*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
8717*da0073e9SAndroid Build Coastguard Worker                super().__init__()
8718*da0073e9SAndroid Build Coastguard Worker                self.value = 1
8719*da0073e9SAndroid Build Coastguard Worker
8720*da0073e9SAndroid Build Coastguard Worker            @torch.jit.ignore
8721*da0073e9SAndroid Build Coastguard Worker            def foo(self):
8722*da0073e9SAndroid Build Coastguard Worker                return torch.ones(2, 2) + self.value
8723*da0073e9SAndroid Build Coastguard Worker
8724*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
8725*da0073e9SAndroid Build Coastguard Worker            def forward(self, input):
8726*da0073e9SAndroid Build Coastguard Worker                return input + self.foo()
8727*da0073e9SAndroid Build Coastguard Worker
8728*da0073e9SAndroid Build Coastguard Worker        with torch.jit.optimized_execution(False):
8729*da0073e9SAndroid Build Coastguard Worker            m = M()
8730*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(2, 2)
8731*da0073e9SAndroid Build Coastguard Worker            o = m(input)
8732*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(o, input + torch.ones(2, 2) + 1)
8733*da0073e9SAndroid Build Coastguard Worker            # check that we can change python attributes
8734*da0073e9SAndroid Build Coastguard Worker            # and that those changes are picked up in script methods
8735*da0073e9SAndroid Build Coastguard Worker            m.value = 2
8736*da0073e9SAndroid Build Coastguard Worker            o = m(input)
8737*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(o, input + torch.ones(2, 2) + 2)
8738*da0073e9SAndroid Build Coastguard Worker
8739*da0073e9SAndroid Build Coastguard Worker    def test_script_module_nochange_submodule(self):
8740*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
8741*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
8742*da0073e9SAndroid Build Coastguard Worker                super().__init__()
8743*da0073e9SAndroid Build Coastguard Worker                self.sub = nn.Linear(5, 5)
8744*da0073e9SAndroid Build Coastguard Worker
8745*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
8746*da0073e9SAndroid Build Coastguard Worker            def forward(self, input):
8747*da0073e9SAndroid Build Coastguard Worker                return self.sub(input)
8748*da0073e9SAndroid Build Coastguard Worker        with torch.jit.optimized_execution(False):
8749*da0073e9SAndroid Build Coastguard Worker            m = M()
8750*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(1, 5, 5)
8751*da0073e9SAndroid Build Coastguard Worker            o = m(input)
8752*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(o, m.sub(input))
8753*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Cannot re-assign"):
8754*da0073e9SAndroid Build Coastguard Worker                m.sub = nn.Linear(5, 5)
8755*da0073e9SAndroid Build Coastguard Worker
8756*da0073e9SAndroid Build Coastguard Worker    def test_module_apis(self):
8757*da0073e9SAndroid Build Coastguard Worker        class Sub(torch.nn.Module):
8758*da0073e9SAndroid Build Coastguard Worker            def forward(self, thing):
8759*da0073e9SAndroid Build Coastguard Worker                return thing - 2
8760*da0073e9SAndroid Build Coastguard Worker
8761*da0073e9SAndroid Build Coastguard Worker        class Double(torch.nn.Module):
8762*da0073e9SAndroid Build Coastguard Worker            def forward(self, thing):
8763*da0073e9SAndroid Build Coastguard Worker                return thing * 2
8764*da0073e9SAndroid Build Coastguard Worker
8765*da0073e9SAndroid Build Coastguard Worker        class MyMod(torch.nn.Module):
8766*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
8767*da0073e9SAndroid Build Coastguard Worker                super().__init__()
8768*da0073e9SAndroid Build Coastguard Worker                self.mod = (Sub())
8769*da0073e9SAndroid Build Coastguard Worker                self.mod2 = (Sub())
8770*da0073e9SAndroid Build Coastguard Worker                self.mod3 = nn.Sequential(nn.Sequential(Sub()))
8771*da0073e9SAndroid Build Coastguard Worker                self.mod4 = nn.Sequential(Sub(), Double())
8772*da0073e9SAndroid Build Coastguard Worker
8773*da0073e9SAndroid Build Coastguard Worker            @torch.jit.export
8774*da0073e9SAndroid Build Coastguard Worker            def method(self, x, x1, y, y1):
8775*da0073e9SAndroid Build Coastguard Worker                mod_names = ""
8776*da0073e9SAndroid Build Coastguard Worker                for name, mod in self.named_modules():
8777*da0073e9SAndroid Build Coastguard Worker                    mod_names = mod_names + " " + name
8778*da0073e9SAndroid Build Coastguard Worker                    x = mod(x)
8779*da0073e9SAndroid Build Coastguard Worker
8780*da0073e9SAndroid Build Coastguard Worker                children_names = ""
8781*da0073e9SAndroid Build Coastguard Worker                for name, mod in self.named_children():
8782*da0073e9SAndroid Build Coastguard Worker                    children_names = children_names + " " + name
8783*da0073e9SAndroid Build Coastguard Worker                    x1 = mod(x1)
8784*da0073e9SAndroid Build Coastguard Worker
8785*da0073e9SAndroid Build Coastguard Worker                for mod in self.modules():
8786*da0073e9SAndroid Build Coastguard Worker                    y = mod(y)
8787*da0073e9SAndroid Build Coastguard Worker
8788*da0073e9SAndroid Build Coastguard Worker                for mod in self.children():
8789*da0073e9SAndroid Build Coastguard Worker                    y1 = mod(y1)
8790*da0073e9SAndroid Build Coastguard Worker
8791*da0073e9SAndroid Build Coastguard Worker                return mod_names, children_names, x, x1, y, y1
8792*da0073e9SAndroid Build Coastguard Worker
8793*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
8794*da0073e9SAndroid Build Coastguard Worker                return x + 2
8795*da0073e9SAndroid Build Coastguard Worker
8796*da0073e9SAndroid Build Coastguard Worker        mod = torch.jit.script(MyMod())
8797*da0073e9SAndroid Build Coastguard Worker        inps = tuple([torch.tensor(i) for i in range(1, 5)])
8798*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(mod.method(*inps), MyMod().method(*inps))
8799*da0073e9SAndroid Build Coastguard Worker
8800*da0073e9SAndroid Build Coastguard Worker    def test_script_module_const(self):
8801*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
8802*da0073e9SAndroid Build Coastguard Worker
8803*da0073e9SAndroid Build Coastguard Worker            __constants__ = ['b', 'i', 'c', 's']
8804*da0073e9SAndroid Build Coastguard Worker
8805*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
8806*da0073e9SAndroid Build Coastguard Worker                super().__init__()
8807*da0073e9SAndroid Build Coastguard Worker                self.b = False
8808*da0073e9SAndroid Build Coastguard Worker                self.i = 1
8809*da0073e9SAndroid Build Coastguard Worker                self.c = 3.5
8810*da0073e9SAndroid Build Coastguard Worker                self.s = ["hello"]
8811*da0073e9SAndroid Build Coastguard Worker
8812*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
8813*da0073e9SAndroid Build Coastguard Worker            def forward(self):
8814*da0073e9SAndroid Build Coastguard Worker                return self.b, self.i, self.c
8815*da0073e9SAndroid Build Coastguard Worker
8816*da0073e9SAndroid Build Coastguard Worker        with torch.jit.optimized_execution(False):
8817*da0073e9SAndroid Build Coastguard Worker            m = M()
8818*da0073e9SAndroid Build Coastguard Worker            o0, o1, o2 = m()
8819*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(o0, 0)
8820*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(o1, 1)
8821*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(o2, 3.5)
8822*da0073e9SAndroid Build Coastguard Worker
8823*da0073e9SAndroid Build Coastguard Worker    def test_script_module_fail_exist(self):
8824*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
8825*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
8826*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
8827*da0073e9SAndroid Build Coastguard Worker                return x + self.whatisgoingon
8828*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Module 'M' has no attribute"):
8829*da0073e9SAndroid Build Coastguard Worker            M()
8830*da0073e9SAndroid Build Coastguard Worker
8831*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("[module dedupe] currently NoneType refinement on optional attributes doesn't work.")
8832*da0073e9SAndroid Build Coastguard Worker    def test_script_module_none_exist_fail(self):
8833*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
8834*da0073e9SAndroid Build Coastguard Worker            def __init__(self, my_optional):
8835*da0073e9SAndroid Build Coastguard Worker                super().__init__()
8836*da0073e9SAndroid Build Coastguard Worker                self.my_optional = my_optional
8837*da0073e9SAndroid Build Coastguard Worker
8838*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
8839*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
8840*da0073e9SAndroid Build Coastguard Worker                if self.my_optional is not None:
8841*da0073e9SAndroid Build Coastguard Worker                    return torch.neg(x) + self.my_optional
8842*da0073e9SAndroid Build Coastguard Worker                return torch.neg(x)
8843*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "has no attribute 'my_optional'"):
8844*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(3, 4)
8845*da0073e9SAndroid Build Coastguard Worker            fb = M(None)
8846*da0073e9SAndroid Build Coastguard Worker            fb(x)
8847*da0073e9SAndroid Build Coastguard Worker
8848*da0073e9SAndroid Build Coastguard Worker    def test_script_module_invalid_consts(self):
8849*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.jit.ScriptModule):
8850*da0073e9SAndroid Build Coastguard Worker            __constants__ = ['invalid']
8851*da0073e9SAndroid Build Coastguard Worker
8852*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
8853*da0073e9SAndroid Build Coastguard Worker                super().__init__()
8854*da0073e9SAndroid Build Coastguard Worker                self.invalid = [nn.Linear(3, 4)]
8855*da0073e9SAndroid Build Coastguard Worker
8856*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
8857*da0073e9SAndroid Build Coastguard Worker                TypeError,
8858*da0073e9SAndroid Build Coastguard Worker                "Linear' object in attribute 'Foo.invalid' is not a valid constant"):
8859*da0073e9SAndroid Build Coastguard Worker            Foo()
8860*da0073e9SAndroid Build Coastguard Worker
8861*da0073e9SAndroid Build Coastguard Worker        class Foo2(torch.jit.ScriptModule):
8862*da0073e9SAndroid Build Coastguard Worker            __constants__ = ['invalid']
8863*da0073e9SAndroid Build Coastguard Worker
8864*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
8865*da0073e9SAndroid Build Coastguard Worker                super().__init__()
8866*da0073e9SAndroid Build Coastguard Worker                self.invalid = int
8867*da0073e9SAndroid Build Coastguard Worker
8868*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, "not a valid constant"):
8869*da0073e9SAndroid Build Coastguard Worker            Foo2()
8870*da0073e9SAndroid Build Coastguard Worker
8871*da0073e9SAndroid Build Coastguard Worker        class Foo3(torch.jit.ScriptModule):
8872*da0073e9SAndroid Build Coastguard Worker            __constants__ = ['invalid']
8873*da0073e9SAndroid Build Coastguard Worker
8874*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
8875*da0073e9SAndroid Build Coastguard Worker                super().__init__()
8876*da0073e9SAndroid Build Coastguard Worker                self.invalid = (3, 4, {})
8877*da0073e9SAndroid Build Coastguard Worker
8878*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, "not a valid constant"):
8879*da0073e9SAndroid Build Coastguard Worker            Foo3()
8880*da0073e9SAndroid Build Coastguard Worker
8881*da0073e9SAndroid Build Coastguard Worker        class Foo4(torch.jit.ScriptModule):
8882*da0073e9SAndroid Build Coastguard Worker            __constants__ = ['invalid']
8883*da0073e9SAndroid Build Coastguard Worker
8884*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
8885*da0073e9SAndroid Build Coastguard Worker                super().__init__()
8886*da0073e9SAndroid Build Coastguard Worker                self.invalid = np.int64(5)
8887*da0073e9SAndroid Build Coastguard Worker
8888*da0073e9SAndroid Build Coastguard Worker        # verify that we capture human understandable class name
8889*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, "numpy.int64"):
8890*da0073e9SAndroid Build Coastguard Worker            Foo4()
8891*da0073e9SAndroid Build Coastguard Worker
8892*da0073e9SAndroid Build Coastguard Worker    def test_script_module_param_buffer_mutation(self):
8893*da0073e9SAndroid Build Coastguard Worker        # TODO: add param mutation test case after JIT support it
8894*da0073e9SAndroid Build Coastguard Worker        class ModuleBufferMutate(torch.jit.ScriptModule):
8895*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
8896*da0073e9SAndroid Build Coastguard Worker                super().__init__()
8897*da0073e9SAndroid Build Coastguard Worker                self.running_var = nn.Buffer(torch.tensor(0, dtype=torch.long))
8898*da0073e9SAndroid Build Coastguard Worker
8899*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
8900*da0073e9SAndroid Build Coastguard Worker            def forward(self):
8901*da0073e9SAndroid Build Coastguard Worker                if self.training:
8902*da0073e9SAndroid Build Coastguard Worker                    self.running_var += 1
8903*da0073e9SAndroid Build Coastguard Worker                return self.running_var
8904*da0073e9SAndroid Build Coastguard Worker
8905*da0073e9SAndroid Build Coastguard Worker        with torch.jit.optimized_execution(False):
8906*da0073e9SAndroid Build Coastguard Worker            m = ModuleBufferMutate()
8907*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m(), 1)
8908*da0073e9SAndroid Build Coastguard Worker            m.eval()
8909*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m(), 1)
8910*da0073e9SAndroid Build Coastguard Worker
8911*da0073e9SAndroid Build Coastguard Worker    def test_script_module_for(self):
8912*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
8913*da0073e9SAndroid Build Coastguard Worker            __constants__ = ['b']
8914*da0073e9SAndroid Build Coastguard Worker
8915*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
8916*da0073e9SAndroid Build Coastguard Worker                super().__init__()
8917*da0073e9SAndroid Build Coastguard Worker                self.b = [1, 2, 3, 4]
8918*da0073e9SAndroid Build Coastguard Worker
8919*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
8920*da0073e9SAndroid Build Coastguard Worker            def forward(self):
8921*da0073e9SAndroid Build Coastguard Worker                sum = 0
8922*da0073e9SAndroid Build Coastguard Worker                for i in self.b:
8923*da0073e9SAndroid Build Coastguard Worker                    sum += i
8924*da0073e9SAndroid Build Coastguard Worker                return sum
8925*da0073e9SAndroid Build Coastguard Worker
8926*da0073e9SAndroid Build Coastguard Worker        with torch.jit.optimized_execution(False):
8927*da0073e9SAndroid Build Coastguard Worker            m = M()
8928*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m(), 10)
8929*da0073e9SAndroid Build Coastguard Worker
8930*da0073e9SAndroid Build Coastguard Worker    def test_override_magic(self):
8931*da0073e9SAndroid Build Coastguard Worker        class OverrideMagic(nn.Module):
8932*da0073e9SAndroid Build Coastguard Worker            @torch.jit.export
8933*da0073e9SAndroid Build Coastguard Worker            def __len__(self):
8934*da0073e9SAndroid Build Coastguard Worker                return 10
8935*da0073e9SAndroid Build Coastguard Worker
8936*da0073e9SAndroid Build Coastguard Worker        mod = OverrideMagic()
8937*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(mod), len(torch.jit.script(mod)))
8938*da0073e9SAndroid Build Coastguard Worker
8939*da0073e9SAndroid Build Coastguard Worker        class OverrideMagicSeq(nn.Sequential):
8940*da0073e9SAndroid Build Coastguard Worker            @torch.jit.export
8941*da0073e9SAndroid Build Coastguard Worker            def __len__(self):
8942*da0073e9SAndroid Build Coastguard Worker                return 10
8943*da0073e9SAndroid Build Coastguard Worker
8944*da0073e9SAndroid Build Coastguard Worker        mod = OverrideMagicSeq()
8945*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(mod), len(torch.jit.script(mod)))
8946*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.jit.script(mod))
8947*da0073e9SAndroid Build Coastguard Worker
8948*da0073e9SAndroid Build Coastguard Worker    def test_script_module_for2(self):
8949*da0073e9SAndroid Build Coastguard Worker        class Sub(torch.jit.ScriptModule):
8950*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
8951*da0073e9SAndroid Build Coastguard Worker                super().__init__()
8952*da0073e9SAndroid Build Coastguard Worker                self.weight = nn.Parameter(torch.randn(2))
8953*da0073e9SAndroid Build Coastguard Worker
8954*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
8955*da0073e9SAndroid Build Coastguard Worker            def forward(self, thing):
8956*da0073e9SAndroid Build Coastguard Worker                return self.weight + thing
8957*da0073e9SAndroid Build Coastguard Worker
8958*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
8959*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
8960*da0073e9SAndroid Build Coastguard Worker                super().__init__()
8961*da0073e9SAndroid Build Coastguard Worker                self.mods = nn.ModuleList([Sub() for i in range(10)])
8962*da0073e9SAndroid Build Coastguard Worker
8963*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
8964*da0073e9SAndroid Build Coastguard Worker            def forward(self, v):
8965*da0073e9SAndroid Build Coastguard Worker                for m in self.mods:
8966*da0073e9SAndroid Build Coastguard Worker                    v = m(v)
8967*da0073e9SAndroid Build Coastguard Worker                return v
8968*da0073e9SAndroid Build Coastguard Worker
8969*da0073e9SAndroid Build Coastguard Worker        with torch.jit.optimized_execution(False):
8970*da0073e9SAndroid Build Coastguard Worker            i = torch.empty(2)
8971*da0073e9SAndroid Build Coastguard Worker            m = M()
8972*da0073e9SAndroid Build Coastguard Worker            o = m(i)
8973*da0073e9SAndroid Build Coastguard Worker            v = i
8974*da0073e9SAndroid Build Coastguard Worker            for sub in m.mods:
8975*da0073e9SAndroid Build Coastguard Worker                v = sub(v)
8976*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(o, v)
8977*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(Exception, "object is not iterable"):
8978*da0073e9SAndroid Build Coastguard Worker                print(list(m))
8979*da0073e9SAndroid Build Coastguard Worker
8980*da0073e9SAndroid Build Coastguard Worker    def test_attr_qscheme_script(self):
8981*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
8982*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
8983*da0073e9SAndroid Build Coastguard Worker                super().__init__()
8984*da0073e9SAndroid Build Coastguard Worker                self.qscheme = torch.per_tensor_affine
8985*da0073e9SAndroid Build Coastguard Worker
8986*da0073e9SAndroid Build Coastguard Worker            def forward(self):
8987*da0073e9SAndroid Build Coastguard Worker                if self.qscheme == torch.per_tensor_symmetric:
8988*da0073e9SAndroid Build Coastguard Worker                    return 3
8989*da0073e9SAndroid Build Coastguard Worker                else:
8990*da0073e9SAndroid Build Coastguard Worker                    return 4
8991*da0073e9SAndroid Build Coastguard Worker
8992*da0073e9SAndroid Build Coastguard Worker        f = Foo()
8993*da0073e9SAndroid Build Coastguard Worker        scripted = torch.jit.script(f)
8994*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(f(), scripted())
8995*da0073e9SAndroid Build Coastguard Worker
8996*da0073e9SAndroid Build Coastguard Worker    def test_script_module_const_submodule_fail(self):
8997*da0073e9SAndroid Build Coastguard Worker        class Sub(torch.jit.ScriptModule):
8998*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
8999*da0073e9SAndroid Build Coastguard Worker                super().__init__()
9000*da0073e9SAndroid Build Coastguard Worker                self.weight = nn.Parameter(torch.randn(2))
9001*da0073e9SAndroid Build Coastguard Worker
9002*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
9003*da0073e9SAndroid Build Coastguard Worker            def forward(self, thing):
9004*da0073e9SAndroid Build Coastguard Worker                return self.weight + thing
9005*da0073e9SAndroid Build Coastguard Worker
9006*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
9007*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
9008*da0073e9SAndroid Build Coastguard Worker                super().__init__()
9009*da0073e9SAndroid Build Coastguard Worker                self.mods = [Sub() for _ in range(10)]
9010*da0073e9SAndroid Build Coastguard Worker
9011*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
9012*da0073e9SAndroid Build Coastguard Worker            def forward(self):
9013*da0073e9SAndroid Build Coastguard Worker                for _ in self.mods:
9014*da0073e9SAndroid Build Coastguard Worker                    print(1)
9015*da0073e9SAndroid Build Coastguard Worker                return 4
9016*da0073e9SAndroid Build Coastguard Worker
9017*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "has no attribute 'mods'"):
9018*da0073e9SAndroid Build Coastguard Worker            M()
9019*da0073e9SAndroid Build Coastguard Worker
9020*da0073e9SAndroid Build Coastguard Worker    class DerivedStateModule(torch.jit.ScriptModule):
9021*da0073e9SAndroid Build Coastguard Worker        def __init__(self) -> None:
9022*da0073e9SAndroid Build Coastguard Worker            super(TestScript.DerivedStateModule, self).__init__()
9023*da0073e9SAndroid Build Coastguard Worker            self.param = torch.nn.Parameter(torch.ones(3, 4, dtype=torch.float))
9024*da0073e9SAndroid Build Coastguard Worker            self.derived = nn.Buffer(torch.neg(self.param).detach().clone())
9025*da0073e9SAndroid Build Coastguard Worker
9026*da0073e9SAndroid Build Coastguard Worker            # This is a flag so we can test that the pack method was called
9027*da0073e9SAndroid Build Coastguard Worker            self.pack_called = nn.Buffer(torch.zeros(1, dtype=torch.long))
9028*da0073e9SAndroid Build Coastguard Worker            # This is a flag so we can test that the unpack method was called
9029*da0073e9SAndroid Build Coastguard Worker            self.unpack_called = nn.Buffer(torch.zeros(1, dtype=torch.long))
9030*da0073e9SAndroid Build Coastguard Worker
9031*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script_method
9032*da0073e9SAndroid Build Coastguard Worker        def _pack(self):
9033*da0073e9SAndroid Build Coastguard Worker            self.pack_called.set_(torch.ones(1, dtype=torch.long))
9034*da0073e9SAndroid Build Coastguard Worker            self.derived.set_(torch.rand(1).detach())
9035*da0073e9SAndroid Build Coastguard Worker
9036*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script_method
9037*da0073e9SAndroid Build Coastguard Worker        def _unpack(self):
9038*da0073e9SAndroid Build Coastguard Worker            self.unpack_called.set_(torch.ones(1, dtype=torch.long))
9039*da0073e9SAndroid Build Coastguard Worker            self.derived.set_(torch.neg(self.param).detach())
9040*da0073e9SAndroid Build Coastguard Worker
9041*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script_method
9042*da0073e9SAndroid Build Coastguard Worker        def forward(self, x):
9043*da0073e9SAndroid Build Coastguard Worker            return x + self.derived
9044*da0073e9SAndroid Build Coastguard Worker
9045*da0073e9SAndroid Build Coastguard Worker    def test_pack_unpack_state(self):
9046*da0073e9SAndroid Build Coastguard Worker        sm = TestScript.DerivedStateModule()
9047*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
9048*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(sm(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float)))
9049*da0073e9SAndroid Build Coastguard Worker
9050*da0073e9SAndroid Build Coastguard Worker        # Test save path
9051*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(sm.pack_called.item())
9052*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(sm.unpack_called.item())
9053*da0073e9SAndroid Build Coastguard Worker        imported = self.getExportImportCopyWithPacking(sm)
9054*da0073e9SAndroid Build Coastguard Worker        # ensure pack was called before serialization
9055*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(sm.pack_called.item())
9056*da0073e9SAndroid Build Coastguard Worker        # ensure unpack was called after serialization so as to leave the module in an initialized state
9057*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(sm.unpack_called.item())
9058*da0073e9SAndroid Build Coastguard Worker
9059*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(sm.derived, torch.neg(sm.param))
9060*da0073e9SAndroid Build Coastguard Worker
9061*da0073e9SAndroid Build Coastguard Worker        # Test load paths
9062*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(imported.unpack_called.item())
9063*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(imported(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float)))
9064*da0073e9SAndroid Build Coastguard Worker
9065*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
9066*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(True, "Skipping while landing PR stack")
9067*da0073e9SAndroid Build Coastguard Worker    def test_torch_functional(self):
9068*da0073e9SAndroid Build Coastguard Worker        def stft(input, n_fft):
9069*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor, int) -> Tensor
9070*da0073e9SAndroid Build Coastguard Worker            return torch.stft(input, n_fft, return_complex=True)
9071*da0073e9SAndroid Build Coastguard Worker
9072*da0073e9SAndroid Build Coastguard Worker        inps = (torch.randn(10), 7)
9073*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(stft(*inps), torch.jit.script(stft)(*inps))
9074*da0073e9SAndroid Build Coastguard Worker
9075*da0073e9SAndroid Build Coastguard Worker        def istft(input, n_fft):
9076*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor, int) -> Tensor
9077*da0073e9SAndroid Build Coastguard Worker            return torch.istft(input, n_fft)
9078*da0073e9SAndroid Build Coastguard Worker
9079*da0073e9SAndroid Build Coastguard Worker        inps2 = (stft(*inps), inps[1])
9080*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(istft(*inps2), torch.jit.script(istft)(*inps2))
9081*da0073e9SAndroid Build Coastguard Worker
9082*da0073e9SAndroid Build Coastguard Worker        def lu_unpack(x):
9083*da0073e9SAndroid Build Coastguard Worker            A_LU, pivots = torch.linalg.lu_factor(x)
9084*da0073e9SAndroid Build Coastguard Worker            return torch.lu_unpack(A_LU, pivots)
9085*da0073e9SAndroid Build Coastguard Worker
9086*da0073e9SAndroid Build Coastguard Worker        for shape in ((3, 3), (5, 3, 3), (7, 3, 5, 5), (7, 5, 3, 3, 3)):
9087*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(*shape)
9088*da0073e9SAndroid Build Coastguard Worker            self.checkScript(lu_unpack, (a,))
9089*da0073e9SAndroid Build Coastguard Worker
9090*da0073e9SAndroid Build Coastguard Worker        def cdist_fn():
9091*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor([[0.9041, 0.0196], [-0.3108, -2.4423], [-0.4821, 1.059]])
9092*da0073e9SAndroid Build Coastguard Worker            b = torch.tensor([[-2.1763, -0.4713], [-0.6986, 1.3702]])
9093*da0073e9SAndroid Build Coastguard Worker            return torch.cdist(a, b, compute_mode="use_mm_for_euclid_dist")
9094*da0073e9SAndroid Build Coastguard Worker
9095*da0073e9SAndroid Build Coastguard Worker        self.checkScript(cdist_fn, ())
9096*da0073e9SAndroid Build Coastguard Worker
9097*da0073e9SAndroid Build Coastguard Worker        def norm():
9098*da0073e9SAndroid Build Coastguard Worker            c = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float)
9099*da0073e9SAndroid Build Coastguard Worker            return torch.norm(c, p="fro"), torch.norm(c, p="nuc"), torch.norm(c), torch.norm(c, p=.5)
9100*da0073e9SAndroid Build Coastguard Worker
9101*da0073e9SAndroid Build Coastguard Worker        self.checkScript(norm, ())
9102*da0073e9SAndroid Build Coastguard Worker
9103*da0073e9SAndroid Build Coastguard Worker        def torch_unique(dim: Optional[int]):
9104*da0073e9SAndroid Build Coastguard Worker            ten = torch.unique(torch.tensor([[1, 3], [2, 3]], dtype=torch.long))
9105*da0073e9SAndroid Build Coastguard Worker            a = torch.unique(ten, dim=dim)
9106*da0073e9SAndroid Build Coastguard Worker            b = torch.unique(ten, return_counts=True, dim=dim)
9107*da0073e9SAndroid Build Coastguard Worker            c = torch.unique(ten, return_inverse=True, dim=dim)
9108*da0073e9SAndroid Build Coastguard Worker            d = torch.unique(ten, return_counts=True, return_inverse=True, dim=dim)
9109*da0073e9SAndroid Build Coastguard Worker            return a, b, c, d
9110*da0073e9SAndroid Build Coastguard Worker
9111*da0073e9SAndroid Build Coastguard Worker        self.checkScript(torch_unique, (None,))
9112*da0073e9SAndroid Build Coastguard Worker        self.checkScript(torch_unique, (0,))
9113*da0073e9SAndroid Build Coastguard Worker
9114*da0073e9SAndroid Build Coastguard Worker        def torch_unique_consecutive(dim: Optional[int]):
9115*da0073e9SAndroid Build Coastguard Worker            ten = torch.unique(torch.tensor([[1, 3], [3, 2], [3, 2], [2, 3]], dtype=torch.long))
9116*da0073e9SAndroid Build Coastguard Worker            a = torch.unique_consecutive(ten, dim=dim)
9117*da0073e9SAndroid Build Coastguard Worker            b = torch.unique_consecutive(ten, return_counts=True, dim=dim)
9118*da0073e9SAndroid Build Coastguard Worker            c = torch.unique_consecutive(ten, return_inverse=True, dim=dim)
9119*da0073e9SAndroid Build Coastguard Worker            d = torch.unique_consecutive(ten, return_counts=True, return_inverse=True, dim=dim)
9120*da0073e9SAndroid Build Coastguard Worker            return a, b, c, d
9121*da0073e9SAndroid Build Coastguard Worker
9122*da0073e9SAndroid Build Coastguard Worker        self.checkScript(torch_unique_consecutive, (None,))
9123*da0073e9SAndroid Build Coastguard Worker        self.checkScript(torch_unique_consecutive, (0,))
9124*da0073e9SAndroid Build Coastguard Worker
9125*da0073e9SAndroid Build Coastguard Worker    def test_torch_functional_tensordot_int(self):
9126*da0073e9SAndroid Build Coastguard Worker        def tensordot_dims_int(a: torch.Tensor, b: torch.Tensor, dims: int):
9127*da0073e9SAndroid Build Coastguard Worker            return torch.tensordot(a, b, dims=dims)
9128*da0073e9SAndroid Build Coastguard Worker
9129*da0073e9SAndroid Build Coastguard Worker        a = torch.arange(120.).reshape(2, 3, 4, 5)
9130*da0073e9SAndroid Build Coastguard Worker        b = torch.arange(840.).reshape(4, 5, 6, 7)
9131*da0073e9SAndroid Build Coastguard Worker        dims = 2
9132*da0073e9SAndroid Build Coastguard Worker        self.checkScript(tensordot_dims_int, (a, b, dims))
9133*da0073e9SAndroid Build Coastguard Worker
9134*da0073e9SAndroid Build Coastguard Worker        for dims in [-1, 5]:
9135*da0073e9SAndroid Build Coastguard Worker            try:
9136*da0073e9SAndroid Build Coastguard Worker                tensordot_dims_int(a, b, dims)
9137*da0073e9SAndroid Build Coastguard Worker            except RuntimeError as error:
9138*da0073e9SAndroid Build Coastguard Worker                if dims < 0:
9139*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(str(error), "tensordot expects dims >= 0, but got dims=" + str(dims))
9140*da0073e9SAndroid Build Coastguard Worker                if dims > min(a.dim(), b.dim()):
9141*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(str(error), "tensordot expects dims < ndim_a or ndim_b, but got dims=" + str(dims))
9142*da0073e9SAndroid Build Coastguard Worker
9143*da0073e9SAndroid Build Coastguard Worker    def test_torch_functional_tensordot_tensor(self):
9144*da0073e9SAndroid Build Coastguard Worker        def tensordot_dims_tensor(a: torch.Tensor, b: torch.Tensor, dims: torch.Tensor):
9145*da0073e9SAndroid Build Coastguard Worker            return torch.tensordot(a, b, dims=dims)
9146*da0073e9SAndroid Build Coastguard Worker
9147*da0073e9SAndroid Build Coastguard Worker        a = torch.arange(120.).reshape(2, 3, 4, 5)
9148*da0073e9SAndroid Build Coastguard Worker        b = torch.arange(840.).reshape(4, 5, 6, 7)
9149*da0073e9SAndroid Build Coastguard Worker        dims = torch.tensor([2])
9150*da0073e9SAndroid Build Coastguard Worker        self.checkScript(tensordot_dims_tensor, (a, b, dims))
9151*da0073e9SAndroid Build Coastguard Worker
9152*da0073e9SAndroid Build Coastguard Worker        a = torch.arange(60.).reshape(3, 4, 5)
9153*da0073e9SAndroid Build Coastguard Worker        b = torch.arange(24.).reshape(4, 3, 2)
9154*da0073e9SAndroid Build Coastguard Worker        dims = torch.tensor([[1, 0], [0, 1]], dtype=torch.long)
9155*da0073e9SAndroid Build Coastguard Worker        self.checkScript(tensordot_dims_tensor, (a, b, dims))
9156*da0073e9SAndroid Build Coastguard Worker
9157*da0073e9SAndroid Build Coastguard Worker    def test_torch_functional_tensordot_list(self):
9158*da0073e9SAndroid Build Coastguard Worker        def tensordot_dims_list(a: torch.Tensor, b: torch.Tensor, dims: List[List[int]]):
9159*da0073e9SAndroid Build Coastguard Worker            return torch.tensordot(a, b, dims=dims)
9160*da0073e9SAndroid Build Coastguard Worker
9161*da0073e9SAndroid Build Coastguard Worker        a = torch.arange(60.).reshape(3, 4, 5)
9162*da0073e9SAndroid Build Coastguard Worker        b = torch.arange(24.).reshape(4, 3, 2)
9163*da0073e9SAndroid Build Coastguard Worker        dims = [[1, 0], [0, 1]]
9164*da0073e9SAndroid Build Coastguard Worker        self.checkScript(tensordot_dims_list, (a, b, dims))
9165*da0073e9SAndroid Build Coastguard Worker
9166*da0073e9SAndroid Build Coastguard Worker    def test_torch_functional_tensordot_tuple(self):
9167*da0073e9SAndroid Build Coastguard Worker        def tensordot_dims_tuple(a: torch.Tensor, b: torch.Tensor, dims: Tuple[List[int], List[int]]):
9168*da0073e9SAndroid Build Coastguard Worker            return torch.tensordot(a, b, dims=dims)
9169*da0073e9SAndroid Build Coastguard Worker
9170*da0073e9SAndroid Build Coastguard Worker        a = torch.arange(60.).reshape(3, 4, 5)
9171*da0073e9SAndroid Build Coastguard Worker        b = torch.arange(24.).reshape(4, 3, 2)
9172*da0073e9SAndroid Build Coastguard Worker        dims = ([1, 0], [0, 1])
9173*da0073e9SAndroid Build Coastguard Worker        self.checkScript(tensordot_dims_tuple, (a, b, dims))
9174*da0073e9SAndroid Build Coastguard Worker
9175*da0073e9SAndroid Build Coastguard Worker    def test_missing_getstate(self):
9176*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
9177*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
9178*da0073e9SAndroid Build Coastguard Worker                super().__init__()
9179*da0073e9SAndroid Build Coastguard Worker                self.x = 1
9180*da0073e9SAndroid Build Coastguard Worker
9181*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
9182*da0073e9SAndroid Build Coastguard Worker                return x * self.x
9183*da0073e9SAndroid Build Coastguard Worker
9184*da0073e9SAndroid Build Coastguard Worker            @torch.jit.export
9185*da0073e9SAndroid Build Coastguard Worker            def __setstate__(self, state):
9186*da0073e9SAndroid Build Coastguard Worker                self.x = state[0]
9187*da0073e9SAndroid Build Coastguard Worker                self.training = state[1]
9188*da0073e9SAndroid Build Coastguard Worker
9189*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "getstate"):
9190*da0073e9SAndroid Build Coastguard Worker            scripted = torch.jit.script(Foo())
9191*da0073e9SAndroid Build Coastguard Worker
9192*da0073e9SAndroid Build Coastguard Worker    def test_inlining_cleanup(self):
9193*da0073e9SAndroid Build Coastguard Worker        def foo(x):
9194*da0073e9SAndroid Build Coastguard Worker            return F.linear(x, x)
9195*da0073e9SAndroid Build Coastguard Worker
9196*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
9197*da0073e9SAndroid Build Coastguard Worker        def fee(x):
9198*da0073e9SAndroid Build Coastguard Worker            return foo(x)
9199*da0073e9SAndroid Build Coastguard Worker
9200*da0073e9SAndroid Build Coastguard Worker        # inlining optimizations should have cleaned up linear if statement
9201*da0073e9SAndroid Build Coastguard Worker        self.run_pass("inline", fee.graph)
9202*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("prim::If").run(fee.graph)
9203*da0073e9SAndroid Build Coastguard Worker
    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
    def test_pack_unpack_nested(self):
        # Three nested ScriptModules (Mod -> SubMod -> SubSubMod) each own a
        # 3x4 buffer and define scripted _pack/_unpack methods. _pack resets the
        # buffer in place (Tensor.set_) to torch.zeros(1); _unpack restores the
        # original constant. The assertions below only hold if Module.apply
        # invokes _pack/_unpack on every level of the hierarchy.
        class SubSubMod(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                # Innermost buffer: all 3s.
                self.buf = nn.Buffer(torch.ones(3, 4) * 3)

            @torch.jit.script_method
            def _pack(self):
                # Replace the buffer contents with a single zero element.
                self.buf.set_(torch.zeros(1))

            @torch.jit.script_method
            def _unpack(self):
                # Restore the original 3x4 buffer of 3s.
                self.buf.set_(torch.ones(3, 4) * 3)

            @torch.jit.script_method
            def forward(self, x):
                return x + self.buf

        class SubMod(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                # Middle buffer: all 2s.
                self.buf = nn.Buffer(torch.ones(3, 4) * 2)
                self.ssm = SubSubMod()

            @torch.jit.script_method
            def _pack(self):
                self.buf.set_(torch.zeros(1))

            @torch.jit.script_method
            def _unpack(self):
                self.buf.set_(torch.ones(3, 4) * 2)

            @torch.jit.script_method
            def forward(self, x):
                # Add this level's buffer, then delegate to the inner module.
                return self.ssm(x + self.buf)

        class Mod(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.submod = SubMod()
                # Outermost buffer: all 1s.
                self.buf = nn.Buffer(torch.ones(3, 4) * 1)

            @torch.jit.script_method
            def _pack(self):
                self.buf.set_(torch.zeros(1))

            @torch.jit.script_method
            def _unpack(self):
                self.buf.set_(torch.ones(3, 4))

            @torch.jit.script_method
            def forward(self, x):
                return self.submod(x + self.buf)

        m = Mod()
        # Unpacked: 0 + 1 (Mod) + 2 (SubMod) + 3 (SubSubMod) == 6 everywhere.
        torch.testing.assert_close(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6)
        # After packing, every buffer is zeros(1), which broadcasts against the
        # 3x4 input, so the output is all zeros.
        m.apply(lambda s: s._pack())
        torch.testing.assert_close(m(torch.zeros(3, 4)), torch.zeros(3, 4))
        # Unpacking restores the buffers, and therefore the original result.
        m.apply(lambda s: s._unpack())
        torch.testing.assert_close(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6)
9265*da0073e9SAndroid Build Coastguard Worker
9266*da0073e9SAndroid Build Coastguard Worker    def test_torch_any(self):
9267*da0073e9SAndroid Build Coastguard Worker        def fn(x):
9268*da0073e9SAndroid Build Coastguard Worker            return torch.any(x)
9269*da0073e9SAndroid Build Coastguard Worker
9270*da0073e9SAndroid Build Coastguard Worker        def fn1(x, dim: int):
9271*da0073e9SAndroid Build Coastguard Worker            return torch.any(x, dim)
9272*da0073e9SAndroid Build Coastguard Worker
9273*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (torch.randn(3, 4), ))
9274*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (torch.empty(3), ))
9275*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (torch.empty(1), ))
9276*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (torch.ones(3, 4),))
9277*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (torch.zeros(5, 7, 1),))
9278*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn1, (torch.empty(3, 4), -2))
9279*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn1, (torch.randn(3, 8), 1))
9280*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn1, (torch.zeros(3, 6, 9), -3))
9281*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn1, (torch.empty(5), 0))
9282*da0073e9SAndroid Build Coastguard Worker
9283*da0073e9SAndroid Build Coastguard Worker    def test_any(self):
9284*da0073e9SAndroid Build Coastguard Worker        def fn(x: List[int]):
9285*da0073e9SAndroid Build Coastguard Worker            return any(x)
9286*da0073e9SAndroid Build Coastguard Worker
9287*da0073e9SAndroid Build Coastguard Worker        def fn1(x: List[float]):
9288*da0073e9SAndroid Build Coastguard Worker            return any(x)
9289*da0073e9SAndroid Build Coastguard Worker
9290*da0073e9SAndroid Build Coastguard Worker        def fn2(x: List[bool]):
9291*da0073e9SAndroid Build Coastguard Worker            return any(x)
9292*da0073e9SAndroid Build Coastguard Worker
9293*da0073e9SAndroid Build Coastguard Worker        def fn3(x: List[str]):
9294*da0073e9SAndroid Build Coastguard Worker            return any(x)
9295*da0073e9SAndroid Build Coastguard Worker
9296*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ([0, 0, 0, 0], ))
9297*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ([0, 3, 0], ))
9298*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ([], ))
9299*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn1, ([1.0, 2.0, 3.0], ))
9300*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn1, ([0.0, 0.0, 0.0], ))
9301*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn1, ([0, 0, 0], ))
9302*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn1, ([], ))
9303*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn2, ([True, False, False], ))
9304*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn2, ([False, False, False], ))
9305*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn2, ([True, True, True, True], ))
9306*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn2, ([], ))
9307*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn3, (["", "", ""], ))
9308*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn3, (["", "", "", "-1"], ))
9309*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn3, ([], ))
9310*da0073e9SAndroid Build Coastguard Worker
9311*da0073e9SAndroid Build Coastguard Worker    def test_script_module_not_tuple(self):
9312*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
9313*da0073e9SAndroid Build Coastguard Worker            __constants__ = ['mods']
9314*da0073e9SAndroid Build Coastguard Worker
9315*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
9316*da0073e9SAndroid Build Coastguard Worker                super().__init__()
9317*da0073e9SAndroid Build Coastguard Worker                self.mods = 1
9318*da0073e9SAndroid Build Coastguard Worker
9319*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
9320*da0073e9SAndroid Build Coastguard Worker            def forward(self, v):
9321*da0073e9SAndroid Build Coastguard Worker                for m in self.mods:
9322*da0073e9SAndroid Build Coastguard Worker                    print(m)
9323*da0073e9SAndroid Build Coastguard Worker                return v
9324*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"):
9325*da0073e9SAndroid Build Coastguard Worker            M()
9326*da0073e9SAndroid Build Coastguard Worker
9327*da0073e9SAndroid Build Coastguard Worker    def test_attr_module_constants(self):
9328*da0073e9SAndroid Build Coastguard Worker        class M2(torch.jit.ScriptModule):
9329*da0073e9SAndroid Build Coastguard Worker            def __init__(self, mod_list):
9330*da0073e9SAndroid Build Coastguard Worker                super().__init__()
9331*da0073e9SAndroid Build Coastguard Worker                self.mods = mod_list
9332*da0073e9SAndroid Build Coastguard Worker
9333*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
9334*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
9335*da0073e9SAndroid Build Coastguard Worker                return self.mods.forward(x)
9336*da0073e9SAndroid Build Coastguard Worker
9337*da0073e9SAndroid Build Coastguard Worker        with torch.jit.optimized_execution(False):
9338*da0073e9SAndroid Build Coastguard Worker            m = M2(nn.Sequential(nn.ReLU()))
9339*da0073e9SAndroid Build Coastguard Worker            self.assertExportImportModule(m, (torch.randn(2, 2),))
9340*da0073e9SAndroid Build Coastguard Worker
9341*da0073e9SAndroid Build Coastguard Worker    def test_script_sequential_for(self):
9342*da0073e9SAndroid Build Coastguard Worker        class Sub(torch.jit.ScriptModule):
9343*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
9344*da0073e9SAndroid Build Coastguard Worker                super().__init__()
9345*da0073e9SAndroid Build Coastguard Worker                self.weight = nn.Parameter(torch.randn(2))
9346*da0073e9SAndroid Build Coastguard Worker
9347*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
9348*da0073e9SAndroid Build Coastguard Worker            def forward(self, thing):
9349*da0073e9SAndroid Build Coastguard Worker                return self.weight + thing
9350*da0073e9SAndroid Build Coastguard Worker
9351*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
9352*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
9353*da0073e9SAndroid Build Coastguard Worker                super().__init__()
9354*da0073e9SAndroid Build Coastguard Worker                self.mods = nn.Sequential(Sub(), Sub(), Sub())
9355*da0073e9SAndroid Build Coastguard Worker
9356*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
9357*da0073e9SAndroid Build Coastguard Worker            def forward(self, v):
9358*da0073e9SAndroid Build Coastguard Worker                for m in self.mods:
9359*da0073e9SAndroid Build Coastguard Worker                    v = m(v)
9360*da0073e9SAndroid Build Coastguard Worker                return v
9361*da0073e9SAndroid Build Coastguard Worker
9362*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
9363*da0073e9SAndroid Build Coastguard Worker            def forward2(self, v):
9364*da0073e9SAndroid Build Coastguard Worker                return self.mods(v)
9365*da0073e9SAndroid Build Coastguard Worker
9366*da0073e9SAndroid Build Coastguard Worker        with torch.jit.optimized_execution(False):
9367*da0073e9SAndroid Build Coastguard Worker            i = torch.empty(2)
9368*da0073e9SAndroid Build Coastguard Worker            m = M()
9369*da0073e9SAndroid Build Coastguard Worker            o = m(i)
9370*da0073e9SAndroid Build Coastguard Worker            v = i
9371*da0073e9SAndroid Build Coastguard Worker            for sub in m.mods._modules.values():
9372*da0073e9SAndroid Build Coastguard Worker                v = sub(v)
9373*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(o, v)
9374*da0073e9SAndroid Build Coastguard Worker
9375*da0073e9SAndroid Build Coastguard Worker            o2 = m.forward2(i)
9376*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(o2, v)
9377*da0073e9SAndroid Build Coastguard Worker
9378*da0073e9SAndroid Build Coastguard Worker    def test_script_sequential_sliced_iteration(self):
9379*da0073e9SAndroid Build Coastguard Worker        class seq_mod(nn.Module):
9380*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
9381*da0073e9SAndroid Build Coastguard Worker                super().__init__()
9382*da0073e9SAndroid Build Coastguard Worker                self.layers = [nn.ReLU(), nn.ReLU(), nn.ReLU()]
9383*da0073e9SAndroid Build Coastguard Worker                self.layers = nn.Sequential(*self.layers)
9384*da0073e9SAndroid Build Coastguard Worker
9385*da0073e9SAndroid Build Coastguard Worker            def forward(self, input):
9386*da0073e9SAndroid Build Coastguard Worker                x = self.layers[0].forward(input)
9387*da0073e9SAndroid Build Coastguard Worker                for layer in self.layers[1:3]:
9388*da0073e9SAndroid Build Coastguard Worker                    x = layer.forward(x)
9389*da0073e9SAndroid Build Coastguard Worker                for layer in self.layers[2:]:
9390*da0073e9SAndroid Build Coastguard Worker                    x = layer.forward(x)
9391*da0073e9SAndroid Build Coastguard Worker                return x
9392*da0073e9SAndroid Build Coastguard Worker
9393*da0073e9SAndroid Build Coastguard Worker        seq = seq_mod()
9394*da0073e9SAndroid Build Coastguard Worker        self.checkModule(seq, [torch.tensor([-2, 1, -1, 2])])
9395*da0073e9SAndroid Build Coastguard Worker
9396*da0073e9SAndroid Build Coastguard Worker    def test_script_sequential_orderdict(self):
9397*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
9398*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
9399*da0073e9SAndroid Build Coastguard Worker                super().__init__()
9400*da0073e9SAndroid Build Coastguard Worker                self.mods = nn.Sequential(OrderedDict([
9401*da0073e9SAndroid Build Coastguard Worker                    ("conv", nn.Conv2d(1, 20, 5)),
9402*da0073e9SAndroid Build Coastguard Worker                    ("relu", nn.ReLU())
9403*da0073e9SAndroid Build Coastguard Worker                ]))
9404*da0073e9SAndroid Build Coastguard Worker
9405*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
9406*da0073e9SAndroid Build Coastguard Worker            def forward(self, input):
9407*da0073e9SAndroid Build Coastguard Worker                return self.mods(input)
9408*da0073e9SAndroid Build Coastguard Worker
9409*da0073e9SAndroid Build Coastguard Worker        m = M()
9410*da0073e9SAndroid Build Coastguard Worker        self.assertTrue('mods.conv.weight' in m.state_dict().keys())
9411*da0073e9SAndroid Build Coastguard Worker
    def test_script_sequential_multi_output_fail(self):
        # A Sequential whose middle module returns a 3-tuple cannot feed that
        # tuple into the next Sub (whose forward takes a single `thing`), so
        # scripting/running HaveSequential must raise a RuntimeError that
        # mentions the tuple type.
        class Sub(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.weight = nn.Parameter(torch.randn(2))

            @torch.jit.script_method
            def forward(self, thing):
                return self.weight + thing

        class ReturnMulti(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                # Three outputs: the mismatch the test relies on.
                return x, x, x

        class HaveSequential(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.someseq = nn.Sequential(
                    Sub(),
                    ReturnMulti(),
                    Sub()
                )

            @torch.jit.script_method
            def forward(self, x):
                return self.someseq(x)

        # NOTE(review): the parentheses in this pattern are regex grouping, so
        # it matches the bare text "Tensor, Tensor, Tensor" anywhere in the
        # message; escape them if the literal "(...)" should be required.
        # Construction and the call are both inside the block, so the error
        # may come from either compilation or execution.
        with self.assertRaisesRegex(RuntimeError, "(Tensor, Tensor, Tensor)"):
            with torch.jit.optimized_execution(False):
                hs = HaveSequential()
                i = torch.empty(2)
                hs(i)
9445*da0073e9SAndroid Build Coastguard Worker
9446*da0073e9SAndroid Build Coastguard Worker    @_tmp_donotuse_dont_inline_everything
9447*da0073e9SAndroid Build Coastguard Worker    def test_script_sequential_in_mod_list(self):
9448*da0073e9SAndroid Build Coastguard Worker        class Sub(torch.jit.ScriptModule):
9449*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
9450*da0073e9SAndroid Build Coastguard Worker                super().__init__()
9451*da0073e9SAndroid Build Coastguard Worker                self.weight = nn.Parameter(torch.randn(2))
9452*da0073e9SAndroid Build Coastguard Worker
9453*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
9454*da0073e9SAndroid Build Coastguard Worker            def forward(self, thing):
9455*da0073e9SAndroid Build Coastguard Worker                return self.weight + thing
9456*da0073e9SAndroid Build Coastguard Worker
9457*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
9458*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
9459*da0073e9SAndroid Build Coastguard Worker                super().__init__()
9460*da0073e9SAndroid Build Coastguard Worker                self.mods = nn.ModuleList([Sub(), nn.Sequential(Sub(), nn.Sequential(Sub(), Sub()), Sub())])
9461*da0073e9SAndroid Build Coastguard Worker
9462*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
9463*da0073e9SAndroid Build Coastguard Worker            def forward(self, v):
9464*da0073e9SAndroid Build Coastguard Worker                for mod in self.mods:
9465*da0073e9SAndroid Build Coastguard Worker                    v = mod(v)
9466*da0073e9SAndroid Build Coastguard Worker                return v
9467*da0073e9SAndroid Build Coastguard Worker
9468*da0073e9SAndroid Build Coastguard Worker        m = M()
9469*da0073e9SAndroid Build Coastguard Worker        graph = str(m.graph)
9470*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(graph.count("prim::CallMethod") == 2)
9471*da0073e9SAndroid Build Coastguard Worker        self.assertTrue("python" not in graph)
9472*da0073e9SAndroid Build Coastguard Worker
9473*da0073e9SAndroid Build Coastguard Worker    @_tmp_donotuse_dont_inline_everything
9474*da0073e9SAndroid Build Coastguard Worker    def test_script_nested_mod_list(self):
9475*da0073e9SAndroid Build Coastguard Worker        class Sub(torch.jit.ScriptModule):
9476*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
9477*da0073e9SAndroid Build Coastguard Worker                super().__init__()
9478*da0073e9SAndroid Build Coastguard Worker                self.weight = nn.Parameter(torch.randn(2))
9479*da0073e9SAndroid Build Coastguard Worker
9480*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
9481*da0073e9SAndroid Build Coastguard Worker            def forward(self, thing):
9482*da0073e9SAndroid Build Coastguard Worker                return self.weight + thing
9483*da0073e9SAndroid Build Coastguard Worker
9484*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
9485*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
9486*da0073e9SAndroid Build Coastguard Worker                super().__init__()
9487*da0073e9SAndroid Build Coastguard Worker                self.mods = nn.ModuleList([nn.ModuleList([Sub()]), nn.Sequential(Sub()), nn.ModuleList([Sub(), Sub()])])
9488*da0073e9SAndroid Build Coastguard Worker
9489*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
9490*da0073e9SAndroid Build Coastguard Worker            def forward(self, v):
9491*da0073e9SAndroid Build Coastguard Worker                for mod in self.mods:
9492*da0073e9SAndroid Build Coastguard Worker                    for m in mod:
9493*da0073e9SAndroid Build Coastguard Worker                        v = m(v)
9494*da0073e9SAndroid Build Coastguard Worker                return v
9495*da0073e9SAndroid Build Coastguard Worker
9496*da0073e9SAndroid Build Coastguard Worker        m = M()
9497*da0073e9SAndroid Build Coastguard Worker        graph = str(m.graph)
9498*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(graph.count("prim::CallMethod") == 4)
9499*da0073e9SAndroid Build Coastguard Worker        self.assertTrue("python" not in graph)
9500*da0073e9SAndroid Build Coastguard Worker
9501*da0073e9SAndroid Build Coastguard Worker    def test_constant_as_attr(self):
9502*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
9503*da0073e9SAndroid Build Coastguard Worker            __constants__ = ['dim']
9504*da0073e9SAndroid Build Coastguard Worker
9505*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
9506*da0073e9SAndroid Build Coastguard Worker                super().__init__()
9507*da0073e9SAndroid Build Coastguard Worker                self.dim = 1
9508*da0073e9SAndroid Build Coastguard Worker
9509*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
9510*da0073e9SAndroid Build Coastguard Worker            def forward(self, v):
9511*da0073e9SAndroid Build Coastguard Worker                return torch.cat([v, v, v], dim=self.dim)
9512*da0073e9SAndroid Build Coastguard Worker        v = torch.zeros(1, 1)
9513*da0073e9SAndroid Build Coastguard Worker        with torch.jit.optimized_execution(False):
9514*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.cat([v, v, v], dim=1), M()(v))
9515*da0073e9SAndroid Build Coastguard Worker
9516*da0073e9SAndroid Build Coastguard Worker    class StarTestSumStarred(torch.nn.Module):
9517*da0073e9SAndroid Build Coastguard Worker        def __init__(self) -> None:
9518*da0073e9SAndroid Build Coastguard Worker            super(TestScript.StarTestSumStarred, self).__init__()
9519*da0073e9SAndroid Build Coastguard Worker
9520*da0073e9SAndroid Build Coastguard Worker        def forward(self, *inputs):
9521*da0073e9SAndroid Build Coastguard Worker            output = inputs[0]
9522*da0073e9SAndroid Build Coastguard Worker            for i in range(1, len(inputs)):
9523*da0073e9SAndroid Build Coastguard Worker                output += inputs[i]
9524*da0073e9SAndroid Build Coastguard Worker            return output
9525*da0073e9SAndroid Build Coastguard Worker
9526*da0073e9SAndroid Build Coastguard Worker    class StarTestReturnThree(torch.nn.Module):
9527*da0073e9SAndroid Build Coastguard Worker        def __init__(self) -> None:
9528*da0073e9SAndroid Build Coastguard Worker            super(TestScript.StarTestReturnThree, self).__init__()
9529*da0073e9SAndroid Build Coastguard Worker
9530*da0073e9SAndroid Build Coastguard Worker        def forward(self, rep):
9531*da0073e9SAndroid Build Coastguard Worker            return rep, rep, rep
9532*da0073e9SAndroid Build Coastguard Worker
    def test_script_star_expr(self):
        # A script_method can star-expand (`*tup`) the tuple returned by one
        # traced submodule into the varargs of another traced submodule.

        class M2(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                # Traced with three example tensors, matching the 3-tuple that
                # `self.g` produces in forward().
                self.m = torch.jit.trace(TestScript.StarTestSumStarred(),
                                         (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)))
                # Traced to return its input three times.
                self.g = torch.jit.trace(TestScript.StarTestReturnThree(), torch.ones(4, 3))

            @torch.jit.script_method
            def forward(self, rep):
                tup = self.g(rep)
                return self.m(*tup)

        m = M2()
        # With a zero input the expected sum is zero, regardless of the traced
        # sum accumulating in place into aliases of the same tensor.
        self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
9549*da0073e9SAndroid Build Coastguard Worker
    def test_script_star_expr_string(self):
        # Same scenario as test_script_star_expr, but forward() is compiled
        # from a source string via ScriptModule.define() rather than from a
        # @torch.jit.script_method.
        class M2(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.m = torch.jit.trace(TestScript.StarTestSumStarred(),
                                         (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)))
                self.g = torch.jit.trace(TestScript.StarTestReturnThree(), torch.ones(4, 3))

                # The string below is TorchScript source; `self.m` / `self.g`
                # refer to the traced submodules assigned above.
                self.define('''
            def forward(self, rep):
                tup = self.g(rep)
                return self.m(*tup)
                ''')

        m = M2()
        self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
9566*da0073e9SAndroid Build Coastguard Worker
9567*da0073e9SAndroid Build Coastguard Worker    class StarTestSumAndReturnThree(torch.nn.Module):
9568*da0073e9SAndroid Build Coastguard Worker        def __init__(self) -> None:
9569*da0073e9SAndroid Build Coastguard Worker            super(TestScript.StarTestSumAndReturnThree, self).__init__()
9570*da0073e9SAndroid Build Coastguard Worker
9571*da0073e9SAndroid Build Coastguard Worker        def forward(self, *inputs):
9572*da0073e9SAndroid Build Coastguard Worker            output = inputs[0]
9573*da0073e9SAndroid Build Coastguard Worker            for i in range(1, len(inputs)):
9574*da0073e9SAndroid Build Coastguard Worker                output += inputs[i]
9575*da0073e9SAndroid Build Coastguard Worker            return output, output, output
9576*da0073e9SAndroid Build Coastguard Worker
    def test_script_star_assign(self):
        # Starred assignment (`head, *tail = ...`) on the tuple returned by a
        # traced submodule, with forward() compiled via define(). The module
        # is traced with a single input, so its sum loop adds nothing and it
        # returns that input three times; `head` is therefore just `rep`.
        class M2(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.g = torch.jit.trace(TestScript.StarTestSumAndReturnThree(), torch.ones(4, 3))
                self.define('''
            def forward(self, rep):
                head, *tail = self.g(rep)
                return head
                ''')

        m = M2()
        # 3 * zeros is zeros, so this checks that the (zero) input comes back.
        self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
9590*da0073e9SAndroid Build Coastguard Worker
    def test_script_module_star_assign2(self):
        # Starred assignment with the star on the leading target
        # (`*head, tail = ...`) picks the last element of the traced 3-tuple.
        class M2(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                # Traced with _force_outplace=True; presumably the `+=` is then
                # recorded as an out-of-place add, so aliased inputs sum to 3x.
                self.g = torch.jit.trace(
                    TestScript.StarTestSumAndReturnThree(),
                    (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)),
                    _force_outplace=True)
                self.define('''
            def forward(self, rep):
                *head, tail = self.g(rep, rep, rep)
                return tail
                ''')

        m = M2()
        # rep is passed three times, so the out-of-place sum is 3 * rep.
        self.assertEqual(m(torch.ones(4, 3)), 3 * torch.ones(4, 3))
9607*da0073e9SAndroid Build Coastguard Worker
    def test_script_module_star_assign2_inplace(self):
        # Same as test_script_module_star_assign2, but traced with
        # _force_outplace=False so the traced `+=` stays an in-place update.
        class M2(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.g = torch.jit.trace(
                    TestScript.StarTestSumAndReturnThree(),
                    (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)),
                    _force_outplace=False)
                self.define('''
            def forward(self, rep):
                *head, tail = self.g(rep, rep, rep)
                return tail
                ''')

        m = M2()
        # since forward() makes three aliases to the input `rep` before passing
        # it to StarTestSumAndReturnThree(), in-place behavior will be different
        # than the above out of place.
        # (rep += rep doubles it to 2, and the second rep += rep doubles again to 4.)
        self.assertEqual(m(torch.ones(4, 3)), 4 * torch.ones(4, 3))
9627*da0073e9SAndroid Build Coastguard Worker
    def test_script_module_star_assign_fail_pythonop(self):
        # Starred unpacking of the result of a @torch.jit.ignore'd Python
        # function must be rejected with "cannot be used as a tuple".

        with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
            class M2(torch.jit.ScriptModule):
                def __init__(self) -> None:
                    super().__init__()

                    # Local helper left to run in Python. The define() source
                    # below refers to `myfunc` by name; presumably define()
                    # resolves it from this calling scope.
                    @torch.jit.ignore
                    def myfunc():
                        return torch.zeros(1, 2, 3), torch.zeros(1, 2, 3)

                    self.define('''
                def forward(self, rep):
                    a, *b = myfunc()
                    return a
                    ''')

            # define() runs inside __init__, so compilation (and the expected
            # error) happens in M2(); the call below is presumably never reached.
            m = M2()
            m(torch.zeros(4, 3))
9647*da0073e9SAndroid Build Coastguard Worker
    def test_script_module_star_assign_fail_builtin(self):
        # Starred unpacking of a builtin op's single Tensor result
        # (torch.neg) must be rejected with "cannot be used as a tuple".
        with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
            class M2(torch.jit.ScriptModule):
                def __init__(self) -> None:
                    super().__init__()

                    self.define('''
                def forward(self, rep):
                    a, *b = torch.neg(rep)
                    return a
                    ''')

            # define() runs inside __init__, so the expected error is raised by
            # M2(); the call below is presumably never reached.
            m = M2()
            m(torch.zeros(4, 3))
9662*da0073e9SAndroid Build Coastguard Worker
    def test_script_pack_padded_sequence(self):
        # Round-trips a batch through pack_padded_sequence / pad_packed_sequence
        # eagerly and under torch.jit.script and checks that both agree. It then
        # scripts a module that calls the keyword-argument forms of the same API.
        from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

        def pack_padded_pad_packed_script(x, seq_lens):
            x = pack_padded_sequence(x, seq_lens)
            x, lengths = pad_packed_sequence(x)
            return x, lengths

        # x is indexed as x[t, b, :] below, i.e. (time, batch, channel).
        T, B, C = 3, 5, 7
        x = torch.ones((T, B, C))
        seq_lens = torch.tensor([3, 3, 2, 2, 1])
        # set padding value so we can test equivalence:
        # zero every time step at or beyond each sequence's length.
        for b in range(B):
            if seq_lens[b] < T:
                x[seq_lens[b]:, b, :] = 0

        eager_seq, eager_lengths = pack_padded_pad_packed_script(x, seq_lens)
        # NOTE(review): emit hooks are disabled around scripting here; the
        # reason is not visible in this chunk. Presumably the harness's
        # emitted-code check does not handle these rnn utilities.
        with torch._jit_internal._disable_emit_hooks():
            scripted_pack_padded_seq = torch.jit.script(pack_padded_pad_packed_script)
        script_seq, script_lengths = scripted_pack_padded_seq(x, seq_lens)
        self.assertEqual(eager_seq, script_seq)
        self.assertEqual(eager_lengths, script_lengths)

        # Uses keyword arguments (enforce_sorted=False with unsorted lengths,
        # total_length=2) inside a scripted forward. input_dim / hidden_dim are
        # accepted but unused.
        class ExperimentalLSTM(torch.nn.Module):
            def __init__(self, input_dim, hidden_dim):
                super().__init__()

            def forward(self, input):
                # type: (Tensor)
                packed = pack_padded_sequence(
                    input=input, lengths=torch.tensor([1, 2]), enforce_sorted=False
                )
                output, lengths = pad_packed_sequence(
                    sequence=packed, total_length=2
                )
                # lengths is flipped, so is output
                return output[0]

        lstm = ExperimentalLSTM(input_dim=2, hidden_dim=2)

        with torch._jit_internal._disable_emit_hooks():
            self.checkModule(lstm, [torch.ones(2, 2)])
9705*da0073e9SAndroid Build Coastguard Worker
9706*da0073e9SAndroid Build Coastguard Worker    def test_script_pad_sequence_pack_sequence(self):
9707*da0073e9SAndroid Build Coastguard Worker        from torch.nn.utils.rnn import pad_sequence, pack_sequence, pad_packed_sequence
9708*da0073e9SAndroid Build Coastguard Worker
9709*da0073e9SAndroid Build Coastguard Worker        def pad_sequence_func(tensor_list, batch_first=False, padding_value=0.0, padding_side="right"):
9710*da0073e9SAndroid Build Coastguard Worker            # type: (List[Tensor], bool, float, str) -> Tensor
9711*da0073e9SAndroid Build Coastguard Worker            return pad_sequence(tensor_list, batch_first, padding_value, padding_side)
9712*da0073e9SAndroid Build Coastguard Worker
9713*da0073e9SAndroid Build Coastguard Worker        def pack_sequence_func(tensor_list, enforce_sorted=True):
9714*da0073e9SAndroid Build Coastguard Worker            # type: (List[Tensor], bool) -> Tensor
9715*da0073e9SAndroid Build Coastguard Worker            return pad_packed_sequence(pack_sequence(tensor_list, enforce_sorted))[0]
9716*da0073e9SAndroid Build Coastguard Worker
9717*da0073e9SAndroid Build Coastguard Worker        ones3 = torch.ones(3, 5)
9718*da0073e9SAndroid Build Coastguard Worker        ones4 = torch.ones(4, 5)
9719*da0073e9SAndroid Build Coastguard Worker        ones5 = torch.ones(5, 5)
9720*da0073e9SAndroid Build Coastguard Worker        tensor1 = torch.tensor([1, 2, 3])
9721*da0073e9SAndroid Build Coastguard Worker        tensor2 = torch.tensor([4, 5])
9722*da0073e9SAndroid Build Coastguard Worker        tensor3 = torch.tensor([6])
9723*da0073e9SAndroid Build Coastguard Worker        with torch._jit_internal._disable_emit_hooks():
9724*da0073e9SAndroid Build Coastguard Worker            self.checkScript(pad_sequence_func,
9725*da0073e9SAndroid Build Coastguard Worker                             ([ones3, ones4, ones5],))
9726*da0073e9SAndroid Build Coastguard Worker            self.checkScript(pad_sequence_func,
9727*da0073e9SAndroid Build Coastguard Worker                             ([ones3, ones4, ones5], True))
9728*da0073e9SAndroid Build Coastguard Worker            self.checkScript(pad_sequence_func,
9729*da0073e9SAndroid Build Coastguard Worker                             ([ones3, ones4, ones5], True, 2.5))
9730*da0073e9SAndroid Build Coastguard Worker            self.checkScript(pad_sequence_func,
9731*da0073e9SAndroid Build Coastguard Worker                             ([ones3, ones4, ones5], True, 2.5, "left"))
9732*da0073e9SAndroid Build Coastguard Worker            self.checkScript(pad_sequence_func,
9733*da0073e9SAndroid Build Coastguard Worker                             ([ones3, ones4, ones5], False, 2.5, "left"))
9734*da0073e9SAndroid Build Coastguard Worker            self.checkScript(pack_sequence_func,
9735*da0073e9SAndroid Build Coastguard Worker                             ([tensor1, tensor2, tensor3],))
9736*da0073e9SAndroid Build Coastguard Worker            self.checkScript(pack_sequence_func,
9737*da0073e9SAndroid Build Coastguard Worker                             ([tensor1, tensor2, tensor3], False))
9738*da0073e9SAndroid Build Coastguard Worker
    def test_script_get_tracing_state(self):
        # torch._C._get_tracing_state() must be usable as a condition in
        # scripted code. Outside of tracing it is presumably falsy, so both the
        # eager and the scripted run take the `x - 1` branch; checkScript
        # (defined elsewhere) runs and compares the two.
        def test_if_tracing(x):
            if torch._C._get_tracing_state():
                return x + 1
            else:
                return x - 1

        inp = torch.randn(3, 3)
        self.checkScript(test_if_tracing, (inp,))
9748*da0073e9SAndroid Build Coastguard Worker
    def test_script_is_tracing(self):
        # Public counterpart of test_script_get_tracing_state:
        # torch.jit.is_tracing() must compile in a scripted function and agree
        # with eager mode when not tracing (compared by checkScript).
        def test_is_tracing(x):
            if torch.jit.is_tracing():
                return x + 1
            else:
                return x - 1

        inp = torch.randn(3, 3)
        self.checkScript(test_is_tracing, (inp,))
9758*da0073e9SAndroid Build Coastguard Worker
    def test_is_scripting(self):
        # torch.jit.is_scripting() is False when called eagerly and True when
        # the same function runs as compiled TorchScript.
        def foo():
            return torch.jit.is_scripting()

        self.assertFalse(foo())
        scripted = torch.jit.script(foo)
        self.assertTrue(scripted())
9766*da0073e9SAndroid Build Coastguard Worker
    def test_comment_ignore_indent(self):
        # Model.__init__ deliberately contains a comment indented less than the
        # surrounding code. Scripting the module must not trip over it. The
        # misindented line below is the subject of the test; leave it as is.
        class Model(torch.nn.Module):
            def __init__(self) -> None:
    # useless comment that is not indented correctly  # noqa: E115
                super().__init__()

            def forward(self):
                return 5

        # should compile without an error
        self.checkModule(Model(), ())
9778*da0073e9SAndroid Build Coastguard Worker
    def test_script_outputs(self):
        # Tuple-unpacking errors are reported when the function is compiled
        # (no call happens inside either `with` block).
        # 1) A single Tensor (a + a) cannot be unpacked into two names.
        with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
            @torch.jit.script
            def foo(a):
                c, d = a + a
                return c + d

        # Valid helper returning a 3-tuple.
        @torch.jit.script
        def return3():
            return 1, 2, 3

        # 2) A 3-tuple cannot be unpacked into two names.
        with self.assertRaisesRegex(RuntimeError, "too many values to unpack"):
            @torch.jit.script
            def bind2():
                a, b = return3()
                print(a)
                print(b)
9796*da0073e9SAndroid Build Coastguard Worker
    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
    def test_script_get_device_cuda(self):
        # Tensor.get_device() must compile and return the CUDA device index.
        @torch.jit.script
        def foo(a):
            return a.get_device()

        # device='cuda' uses the current CUDA device; the expected index 0
        # presumably assumes the test runs with cuda:0 as the current device.
        v = torch.randn(1, device='cuda')
        self.assertEqual(foo(v), 0)
9805*da0073e9SAndroid Build Coastguard Worker
    def test_script_chunk(self):
        # Unpacking the result of torch.chunk(..., chunks=2) into two names
        # compiles in TorchScript, and the first chunk matches eager mode.
        @torch.jit.script
        def foo(a):
            b, c = torch.chunk(a, dim=0, chunks=2)
            return b
        v = torch.rand(10, 3)
        self.assertEqual(torch.chunk(v, dim=0, chunks=2)[0], foo(v))
9813*da0073e9SAndroid Build Coastguard Worker
    def test_script_copy(self):
        # copy.copy / copy.deepcopy of a scripted module must not raise. The
        # module has an Optional[Tensor] attribute and a helper method. Only
        # the absence of an exception is checked; the copies are not inspected.
        class M(torch.nn.Module):
            # Declares `val` as Optional[Tensor] so it can start as None.
            __annotations__ = {
                "val": Optional[torch.Tensor]
            }

            def __init__(self) -> None:
                super().__init__()
                self.val = None

            def some_method(self):
                return 3

            def forward(self, x):
                # type: (Tensor) -> Tensor
                self.val = x + self.some_method()
                return x

        m = torch.jit.script(M())
        # test copy
        copy.copy(m)
        copy.deepcopy(m)
9836*da0073e9SAndroid Build Coastguard Worker
    def test_script_forward_method_replacement(self):
        # We want to support the use case of attaching a different `forward` method
        # to an instance. TestModule rebinds `self.forward` in __init__ to
        # LowLevelModule.forward. Both eager calls and torch.jit.script must use
        # that replacement (123 * 2 = 246) rather than TestModule.forward.
        class LowLevelModule(torch.nn.Module):
            def forward(self, input: torch.Tensor):
                # Generic forward dispatch
                return self.forward_pytorch(input) * 2

        class TestModule(LowLevelModule):
            def __init__(self) -> None:
                super().__init__()
                # Replace the forward method
                self.forward = types.MethodType(LowLevelModule.forward, self)

            def forward_pytorch(self, input: torch.Tensor):
                return torch.tensor(123)

            # Must never run; raises if the replacement is ignored. The return
            # after the raise is unreachable.
            def forward(self, input: torch.Tensor):
                # Should not use this forward method
                raise AssertionError("This method should not be used")
                return self.forward_pytorch(input)

        m = TestModule()
        self.assertEqual(m(torch.tensor(1)), torch.tensor(246))

        m_scripted = torch.jit.script(m)
        self.assertEqual(m_scripted(torch.tensor(1)), torch.tensor(246))
9863*da0073e9SAndroid Build Coastguard Worker
    def test_python_call_non_tensor(self):
        # A scripted function can call a plain Python function whose signature
        # (given by its annotation comment) mixes Tensors, ints and tuples.
        # Expected: (3 + 3, x + x) == (6, ones + 1) for x = ones.
        def foo(a, b, c):
            # type: (Tensor, int, Tuple[Tensor, int]) -> Tuple[int, Tensor]
            d, e = c
            return b + e, a + d

        @torch.jit.script
        def bar():
            x = torch.ones(3, 4)
            a, b = foo(x, 3, (x, 3))
            return a, b

        self.assertEqual((6, torch.ones(3, 4) + 1), bar())
9877*da0073e9SAndroid Build Coastguard Worker
    def test_python_call_non_tensor_wrong(self):
        # An ignored Python function that declares a Tensor return but actually
        # returns a tuple must produce an error mentioning the mismatch.
        # Presumably the check happens when bar() runs and converts foo's
        # Python return value.
        with self.assertRaisesRegex(RuntimeError, r"but instead got value of type tuple"):
            @torch.jit.ignore
            def foo():
                # type: () -> Tensor
                return ((3, 4),)  # noqa: T484

            @torch.jit.script
            def bar():
                return foo()

            bar()
9890*da0073e9SAndroid Build Coastguard Worker
    def test_if_different_type(self):
        # Type unification of variables assigned in if/else branches:
        # 1) c0 is int in one branch and float in the other, then used after
        #    the if, which is a compile error.
        with self.assertRaisesRegex(RuntimeError, "c0 is set to type "
                                    "int in the true branch and type "
                                    "float in the false branch"):
            @torch.jit.script
            def diff_type_used():
                if 1 == 2:
                    c0 = 1
                else:
                    c0 = 1.0
                return c0

        # 2) Re-assigning an existing float variable to an int inside a branch
        #    is a compile error.
        with self.assertRaisesRegex(RuntimeError, "Variable 'c0' previously had type float"):
            @torch.jit.script
            def diff_existing_type(x):
                c0 = 1.0
                if 1 == 2:
                    c0 = 1
                    print(x)
                return x

        # 3) Differing branch types are allowed when c0 is not used after the
        #    if; this must compile without error.
        @torch.jit.script
        def diff_type_unused():
            if 1 == 1:
                c0 = 1
                print(c0)
            else:
                c0 = 1.0
                print(c0)
            return 1
9921*da0073e9SAndroid Build Coastguard Worker
9922*da0073e9SAndroid Build Coastguard Worker    def test_if_not_defined_error(self):
9923*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "c0 is not defined in the false branch"):
9924*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
9925*da0073e9SAndroid Build Coastguard Worker            def test():
9926*da0073e9SAndroid Build Coastguard Worker                if 1 == 1:
9927*da0073e9SAndroid Build Coastguard Worker                    c0 = 1
9928*da0073e9SAndroid Build Coastguard Worker                return c0
9929*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "c0 is not defined in the true branch"):
9930*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
9931*da0073e9SAndroid Build Coastguard Worker            def test2():
9932*da0073e9SAndroid Build Coastguard Worker                if 1 == 1:
9933*da0073e9SAndroid Build Coastguard Worker                    pass
9934*da0073e9SAndroid Build Coastguard Worker                else:
9935*da0073e9SAndroid Build Coastguard Worker                    c0 = 1
9936*da0073e9SAndroid Build Coastguard Worker                return c0
9937*da0073e9SAndroid Build Coastguard Worker
9938*da0073e9SAndroid Build Coastguard Worker    def test_if_list_cat(self):
9939*da0073e9SAndroid Build Coastguard Worker        # testing that different length lists don't throw error on cat in shape prop
9940*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
9941*da0073e9SAndroid Build Coastguard Worker        def test_list(x):
9942*da0073e9SAndroid Build Coastguard Worker            if bool(x.sum() < 1):
9943*da0073e9SAndroid Build Coastguard Worker                c = [x, x]
9944*da0073e9SAndroid Build Coastguard Worker            else:
9945*da0073e9SAndroid Build Coastguard Worker                c = [x, x, x]
9946*da0073e9SAndroid Build Coastguard Worker            return torch.cat(c)
9947*da0073e9SAndroid Build Coastguard Worker
9948*da0073e9SAndroid Build Coastguard Worker        b = torch.zeros(2, 4)
9949*da0073e9SAndroid Build Coastguard Worker        _propagate_shapes(test_list.graph, (b,), False)
9950*da0073e9SAndroid Build Coastguard Worker
9951*da0073e9SAndroid Build Coastguard Worker    def test_if_supertype(self):
9952*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
9953*da0073e9SAndroid Build Coastguard Worker        def tensor_unifying(x, y, z):
9954*da0073e9SAndroid Build Coastguard Worker            # testing dynamic is appropriately set for y and z
9955*da0073e9SAndroid Build Coastguard Worker            if bool(x):
9956*da0073e9SAndroid Build Coastguard Worker                x, y, z = x + 1, y, z
9957*da0073e9SAndroid Build Coastguard Worker            else:
9958*da0073e9SAndroid Build Coastguard Worker                x, y, z = x + 1, x, y
9959*da0073e9SAndroid Build Coastguard Worker
9960*da0073e9SAndroid Build Coastguard Worker            return x, y, z
9961*da0073e9SAndroid Build Coastguard Worker
9962*da0073e9SAndroid Build Coastguard Worker        a = torch.zeros(2, 2, dtype=torch.float)
9963*da0073e9SAndroid Build Coastguard Worker        b = torch.zeros(2, 4, dtype=torch.long)
9964*da0073e9SAndroid Build Coastguard Worker        c = torch.zeros(2, 4, dtype=torch.float)
9965*da0073e9SAndroid Build Coastguard Worker
9966*da0073e9SAndroid Build Coastguard Worker        graph = _propagate_shapes(tensor_unifying.graph, (a, b, c), False)
9967*da0073e9SAndroid Build Coastguard Worker        if_outputs = list(graph.findNode("prim::If").outputs())
9968*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(if_outputs[0].type().str() == "Float(*, *, requires_grad=0, device=cpu)")
9969*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(if_outputs[1].type().str() == "Tensor(*, *, requires_grad=0, device=cpu)")
9970*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(if_outputs[2].type().str() == "Tensor(*, *, requires_grad=0, device=cpu)")
9971*da0073e9SAndroid Build Coastguard Worker
9972*da0073e9SAndroid Build Coastguard Worker    def test_list_unify(self):
9973*da0073e9SAndroid Build Coastguard Worker        # allowing a unififed int?[] would cause a runtime error b/c
9974*da0073e9SAndroid Build Coastguard Worker        # the index operation expects int?[] to be a generic list,
9975*da0073e9SAndroid Build Coastguard Worker        # but in the true branch the IValue will be a int list
9976*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "int[] in the true branch and type None[]"):
9977*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
9978*da0073e9SAndroid Build Coastguard Worker            def list_optional_fails(x):
9979*da0073e9SAndroid Build Coastguard Worker                # type: (bool) -> Optional[int]
9980*da0073e9SAndroid Build Coastguard Worker                if x:
9981*da0073e9SAndroid Build Coastguard Worker                    y = [1]
9982*da0073e9SAndroid Build Coastguard Worker                else:
9983*da0073e9SAndroid Build Coastguard Worker                    y = [None]  # noqa: T484
9984*da0073e9SAndroid Build Coastguard Worker                return y[0]
9985*da0073e9SAndroid Build Coastguard Worker
9986*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
9987*da0073e9SAndroid Build Coastguard Worker        def list_tensors(x):
9988*da0073e9SAndroid Build Coastguard Worker            # type: (bool) -> Tuple[Tensor, List[Tensor]]
9989*da0073e9SAndroid Build Coastguard Worker            if x:
9990*da0073e9SAndroid Build Coastguard Worker                a = torch.zeros([1, 1])
9991*da0073e9SAndroid Build Coastguard Worker                y = [a]
9992*da0073e9SAndroid Build Coastguard Worker            else:
9993*da0073e9SAndroid Build Coastguard Worker                a = torch.zeros([1, 2])
9994*da0073e9SAndroid Build Coastguard Worker                y = [a]
9995*da0073e9SAndroid Build Coastguard Worker            return a, y
9996*da0073e9SAndroid Build Coastguard Worker
9997*da0073e9SAndroid Build Coastguard Worker        self.run_pass('constant_propagation', list_tensors.graph)
9998*da0073e9SAndroid Build Coastguard Worker        m = self.createFunctionFromGraph(list_tensors.graph)
9999*da0073e9SAndroid Build Coastguard Worker        # testing that tensor type of lists is unified
10000*da0073e9SAndroid Build Coastguard Worker        self.getExportImportCopy(m)
10001*da0073e9SAndroid Build Coastguard Worker
10002*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
10003*da0073e9SAndroid Build Coastguard Worker    @_inline_everything
10004*da0073e9SAndroid Build Coastguard Worker    def test_import_constants_not_specialized(self):
10005*da0073e9SAndroid Build Coastguard Worker        class Mod(torch.nn.Module):
10006*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
10007*da0073e9SAndroid Build Coastguard Worker                return torch.cat(2 * [x], dim=0)
10008*da0073e9SAndroid Build Coastguard Worker
10009*da0073e9SAndroid Build Coastguard Worker        class ScriptMod(torch.jit.ScriptModule):
10010*da0073e9SAndroid Build Coastguard Worker            def __init__(self, mod):
10011*da0073e9SAndroid Build Coastguard Worker                super().__init__()
10012*da0073e9SAndroid Build Coastguard Worker                x = torch.zeros(1, 3)
10013*da0073e9SAndroid Build Coastguard Worker                mod_fn = lambda : mod(x)  # noqa: E731
10014*da0073e9SAndroid Build Coastguard Worker                self.mod = torch.jit.trace(mod_fn, ())
10015*da0073e9SAndroid Build Coastguard Worker
10016*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
10017*da0073e9SAndroid Build Coastguard Worker            def forward(self):
10018*da0073e9SAndroid Build Coastguard Worker                return self.mod()
10019*da0073e9SAndroid Build Coastguard Worker
10020*da0073e9SAndroid Build Coastguard Worker        cm = ScriptMod(Mod())
10021*da0073e9SAndroid Build Coastguard Worker        # specialized tensor in graph
10022*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("Float(1, 3, strides=[3, 1], requires_grad=0, device=cpu)").run(cm.forward.graph)
10023*da0073e9SAndroid Build Coastguard Worker        buffer = io.BytesIO()
10024*da0073e9SAndroid Build Coastguard Worker        torch.jit.save(cm, buffer)
10025*da0073e9SAndroid Build Coastguard Worker        buffer.seek(0)
10026*da0073e9SAndroid Build Coastguard Worker        # when tensor is loaded as constant it isnt specialized
10027*da0073e9SAndroid Build Coastguard Worker        cm_load = torch.jit.load(buffer)
10028*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("Float(1, 3)").run(cm_load.forward.graph)
10029*da0073e9SAndroid Build Coastguard Worker
10030*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
10031*da0073e9SAndroid Build Coastguard Worker    def test_type_annotations_repeated_list(self):
10032*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
10033*da0073e9SAndroid Build Coastguard Worker        def float_fn(x, y):
10034*da0073e9SAndroid Build Coastguard Worker            # type: (float, BroadcastingList3[float]) -> List[float]
10035*da0073e9SAndroid Build Coastguard Worker            return y
10036*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(float_fn(2.0, 1.0), float_fn(2.0, [1.0, 1.0, 1.0]))
10037*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(float_fn(2.0, 1.0), float_fn(2.0, (1.0, 1.0, 1.0)))
10038*da0073e9SAndroid Build Coastguard Worker
10039*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
10040*da0073e9SAndroid Build Coastguard Worker        def float_fn_call():
10041*da0073e9SAndroid Build Coastguard Worker            print(float_fn(1.0, 1.0))
10042*da0073e9SAndroid Build Coastguard Worker            print(float_fn(1.0, (1.0, 1.0, 1.0)))
10043*da0073e9SAndroid Build Coastguard Worker
10044*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
10045*da0073e9SAndroid Build Coastguard Worker        def int_fn(x):
10046*da0073e9SAndroid Build Coastguard Worker            # type: (BroadcastingList3[int]) -> List[int]
10047*da0073e9SAndroid Build Coastguard Worker            return x
10048*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(int_fn(1), int_fn([1, 1, 1]))
10049*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(int_fn(1), int_fn((1, 1, 1)))
10050*da0073e9SAndroid Build Coastguard Worker
10051*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
10052*da0073e9SAndroid Build Coastguard Worker        def int_fn_call():
10053*da0073e9SAndroid Build Coastguard Worker            print(int_fn(1))
10054*da0073e9SAndroid Build Coastguard Worker            print(int_fn((1, 1, 1)))
10055*da0073e9SAndroid Build Coastguard Worker
10056*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "must be a positive integer:"):
10057*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script  # noqa: T484
10058*da0073e9SAndroid Build Coastguard Worker            def fn(x):
10059*da0073e9SAndroid Build Coastguard Worker                # type: (BroadcastingListx[int]) -> List[int]  # noqa: T484
10060*da0073e9SAndroid Build Coastguard Worker                return x
10061*da0073e9SAndroid Build Coastguard Worker
10062*da0073e9SAndroid Build Coastguard Worker        # using CU so that flake8 error on int[2] is not raised (noqa not working)
10063*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Unknown type constructor"):
10064*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit('''
10065*da0073e9SAndroid Build Coastguard Worker                def nested(x, y):
10066*da0073e9SAndroid Build Coastguard Worker                    # type: (int, Tuple[int, int[2]]) -> List[int]
10067*da0073e9SAndroid Build Coastguard Worker                    return x  # noqa: T484
10068*da0073e9SAndroid Build Coastguard Worker            ''')
10069*da0073e9SAndroid Build Coastguard Worker
10070*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
10071*da0073e9SAndroid Build Coastguard Worker        def f(x: BroadcastingList2[int]):
10072*da0073e9SAndroid Build Coastguard Worker            return x
10073*da0073e9SAndroid Build Coastguard Worker
10074*da0073e9SAndroid Build Coastguard Worker        out = f(1)
10075*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(out[0], int))
10076*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, [1, 1])
10077*da0073e9SAndroid Build Coastguard Worker
10078*da0073e9SAndroid Build Coastguard Worker    def test_ntuple_builtins(self):
10079*da0073e9SAndroid Build Coastguard Worker        from torch.nn.modules.utils import _single, _pair, _triple, _quadruple
10080*da0073e9SAndroid Build Coastguard Worker
10081*da0073e9SAndroid Build Coastguard Worker        def test_ints():
10082*da0073e9SAndroid Build Coastguard Worker            return _single(1), _pair(2), _triple(3), _quadruple(4)
10083*da0073e9SAndroid Build Coastguard Worker
10084*da0073e9SAndroid Build Coastguard Worker        def test_floats():
10085*da0073e9SAndroid Build Coastguard Worker            return _single(1), _pair(2.1), _triple(3.1), _quadruple(4.1)
10086*da0073e9SAndroid Build Coastguard Worker
10087*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_ints, ())
10088*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_floats, ())
10089*da0073e9SAndroid Build Coastguard Worker
10090*da0073e9SAndroid Build Coastguard Worker    def test_embedding_renorm_grad_error(self):
10091*da0073e9SAndroid Build Coastguard Worker        # Testing that the builtin call to embedding_renorm_ correctly throws
10092*da0073e9SAndroid Build Coastguard Worker        # Error when .backward() is called on its input
10093*da0073e9SAndroid Build Coastguard Worker
10094*da0073e9SAndroid Build Coastguard Worker        def embedding_norm(input, embedding_matrix, max_norm):
10095*da0073e9SAndroid Build Coastguard Worker            F.embedding(input, embedding_matrix, max_norm=0.01)
10096*da0073e9SAndroid Build Coastguard Worker
10097*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
10098*da0073e9SAndroid Build Coastguard Worker        def embedding_norm_script(input, embedding_matrix, max_norm):
10099*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor, Tensor, float) -> None
10100*da0073e9SAndroid Build Coastguard Worker            F.embedding(input, embedding_matrix, max_norm=0.01)
10101*da0073e9SAndroid Build Coastguard Worker
10102*da0073e9SAndroid Build Coastguard Worker        for _ in [embedding_norm, embedding_norm_script]:
10103*da0073e9SAndroid Build Coastguard Worker            input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
10104*da0073e9SAndroid Build Coastguard Worker            embedding_matrix = torch.randn(10, 3)
10105*da0073e9SAndroid Build Coastguard Worker
10106*da0073e9SAndroid Build Coastguard Worker            var1 = torch.randn(10, 3, requires_grad=True)
10107*da0073e9SAndroid Build Coastguard Worker            var2 = var1.detach().requires_grad_()
10108*da0073e9SAndroid Build Coastguard Worker            output1 = var1 * embedding_matrix
10109*da0073e9SAndroid Build Coastguard Worker            output2 = var2 * embedding_matrix
10110*da0073e9SAndroid Build Coastguard Worker
10111*da0073e9SAndroid Build Coastguard Worker            output1.sum().backward()
10112*da0073e9SAndroid Build Coastguard Worker
10113*da0073e9SAndroid Build Coastguard Worker            ignore = F.embedding(input, embedding_matrix, max_norm=0.01)
10114*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "modified"):
10115*da0073e9SAndroid Build Coastguard Worker                output2.sum().backward()
10116*da0073e9SAndroid Build Coastguard Worker
10117*da0073e9SAndroid Build Coastguard Worker    def test_type_annotations(self):
10118*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
10119*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
10120*da0073e9SAndroid Build Coastguard Worker            return x, x * 2, x * 3
10121*da0073e9SAndroid Build Coastguard Worker
10122*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"need 4 values .* found only 3"):
10123*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
10124*da0073e9SAndroid Build Coastguard Worker            def script_fn(x):
10125*da0073e9SAndroid Build Coastguard Worker                x, y, z, w = fn(x, x)
10126*da0073e9SAndroid Build Coastguard Worker
10127*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"too many values .* need 2 but found 3"):
10128*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
10129*da0073e9SAndroid Build Coastguard Worker            def script_fn2(x):
10130*da0073e9SAndroid Build Coastguard Worker                x, y = fn(x, x)
10131*da0073e9SAndroid Build Coastguard Worker
10132*da0073e9SAndroid Build Coastguard Worker        def fn_unpack(x):
10133*da0073e9SAndroid Build Coastguard Worker            y, z, w = fn(x, x)
10134*da0073e9SAndroid Build Coastguard Worker            return y
10135*da0073e9SAndroid Build Coastguard Worker
10136*da0073e9SAndroid Build Coastguard Worker        def fn_index(x):
10137*da0073e9SAndroid Build Coastguard Worker            q = fn(x, x)
10138*da0073e9SAndroid Build Coastguard Worker            return x
10139*da0073e9SAndroid Build Coastguard Worker
10140*da0073e9SAndroid Build Coastguard Worker        def fn_string(str, strpair):
10141*da0073e9SAndroid Build Coastguard Worker            # type: (str, Tuple[str, str]) -> Tuple[str, int, str, str]
10142*da0073e9SAndroid Build Coastguard Worker            str1, str2 = strpair
10143*da0073e9SAndroid Build Coastguard Worker            return str, 2, str1, str2
10144*da0073e9SAndroid Build Coastguard Worker
10145*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(2, 2)
10146*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn_unpack, (x,), optimize=True)
10147*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn_index, (x,), optimize=True)
10148*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn_string, ("1", ("3", "4")), optimize=True)
10149*da0073e9SAndroid Build Coastguard Worker
10150*da0073e9SAndroid Build Coastguard Worker    def test_type_annotations_varargs(self):
10151*da0073e9SAndroid Build Coastguard Worker        @torch.jit.ignore
10152*da0073e9SAndroid Build Coastguard Worker        def fn_varargs(x, *args):
10153*da0073e9SAndroid Build Coastguard Worker            return args[0] if args else x
10154*da0073e9SAndroid Build Coastguard Worker
10155*da0073e9SAndroid Build Coastguard Worker        def fn1(x, y, z):
10156*da0073e9SAndroid Build Coastguard Worker            return fn_varargs(x)
10157*da0073e9SAndroid Build Coastguard Worker
10158*da0073e9SAndroid Build Coastguard Worker        def fn2(x, y, z):
10159*da0073e9SAndroid Build Coastguard Worker            return fn_varargs(x, y)
10160*da0073e9SAndroid Build Coastguard Worker
10161*da0073e9SAndroid Build Coastguard Worker        def fn3(x, y, z):
10162*da0073e9SAndroid Build Coastguard Worker            return fn_varargs(x, y, z)
10163*da0073e9SAndroid Build Coastguard Worker
10164*da0073e9SAndroid Build Coastguard Worker        x, y, z = (torch.randn(2, 2) for _ in range(3))
10165*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn1, (x, y, z), optimize=True)
10166*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn2, (x, y, z), optimize=True)
10167*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn3, (x, y, z), optimize=True)
10168*da0073e9SAndroid Build Coastguard Worker
10169*da0073e9SAndroid Build Coastguard Worker    def test_type_annotation_py3(self):
10170*da0073e9SAndroid Build Coastguard Worker        code = dedent("""
10171*da0073e9SAndroid Build Coastguard Worker        import torch
10172*da0073e9SAndroid Build Coastguard Worker        from torch import Tensor
10173*da0073e9SAndroid Build Coastguard Worker        from typing import Tuple
10174*da0073e9SAndroid Build Coastguard Worker
10175*da0073e9SAndroid Build Coastguard Worker        def fn(x : torch.Tensor, y : Tensor, z) -> Tuple[Tensor, Tensor, Tensor]:
10176*da0073e9SAndroid Build Coastguard Worker            return (x, y + z, z)
10177*da0073e9SAndroid Build Coastguard Worker        """)
10178*da0073e9SAndroid Build Coastguard Worker
10179*da0073e9SAndroid Build Coastguard Worker        with tempfile.TemporaryDirectory() as tmp_dir:
10180*da0073e9SAndroid Build Coastguard Worker            script_path = os.path.join(tmp_dir, 'script.py')
10181*da0073e9SAndroid Build Coastguard Worker            with open(script_path, 'w') as f:
10182*da0073e9SAndroid Build Coastguard Worker                f.write(code)
10183*da0073e9SAndroid Build Coastguard Worker            fn = get_fn('test_type_annotation_py3', script_path)
10184*da0073e9SAndroid Build Coastguard Worker            fn = torch.jit.ignore(fn)
10185*da0073e9SAndroid Build Coastguard Worker
10186*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, r"Expected a value of type 'Tensor' for argument"
10187*da0073e9SAndroid Build Coastguard Worker                                                      r" 'x' but instead found type 'Tuple\[Tensor,"):
10188*da0073e9SAndroid Build Coastguard Worker                @torch.jit.script
10189*da0073e9SAndroid Build Coastguard Worker                def bad_fn(x):
10190*da0073e9SAndroid Build Coastguard Worker                    x, y = fn((x, x), x, x)
10191*da0073e9SAndroid Build Coastguard Worker                    return y
10192*da0073e9SAndroid Build Coastguard Worker
10193*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, r"too many values .* need 2 but found 3"):
10194*da0073e9SAndroid Build Coastguard Worker                @torch.jit.script
10195*da0073e9SAndroid Build Coastguard Worker                def bad_fn2(x):
10196*da0073e9SAndroid Build Coastguard Worker                    x, y = fn(x, x, x)
10197*da0073e9SAndroid Build Coastguard Worker                    return y
10198*da0073e9SAndroid Build Coastguard Worker
10199*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, r"need 4 values .* found only 3"):
10200*da0073e9SAndroid Build Coastguard Worker                @torch.jit.script
10201*da0073e9SAndroid Build Coastguard Worker                def bad_fn3(x):
10202*da0073e9SAndroid Build Coastguard Worker                    x, y, z, w = fn(x, x, x)
10203*da0073e9SAndroid Build Coastguard Worker                    return y
10204*da0073e9SAndroid Build Coastguard Worker
10205*da0073e9SAndroid Build Coastguard Worker            def good_fn(x):
10206*da0073e9SAndroid Build Coastguard Worker                y, z, w = fn(x, x, x)
10207*da0073e9SAndroid Build Coastguard Worker                return y, z, w
10208*da0073e9SAndroid Build Coastguard Worker
10209*da0073e9SAndroid Build Coastguard Worker            self.checkScript(good_fn, (torch.ones(2, 2),), optimize=True)
10210*da0073e9SAndroid Build Coastguard Worker
10211*da0073e9SAndroid Build Coastguard Worker    def test_type_annotation_module(self):
10212*da0073e9SAndroid Build Coastguard Worker        class BaseModule(torch.jit.ScriptModule):
10213*da0073e9SAndroid Build Coastguard Worker            @torch.jit.ignore
10214*da0073e9SAndroid Build Coastguard Worker            def foo(self, x):
10215*da0073e9SAndroid Build Coastguard Worker                # type: (Tensor) -> Tensor
10216*da0073e9SAndroid Build Coastguard Worker                return x + 1
10217*da0073e9SAndroid Build Coastguard Worker
10218*da0073e9SAndroid Build Coastguard Worker            @torch.jit.ignore
10219*da0073e9SAndroid Build Coastguard Worker            def bar(self, x, y):
10220*da0073e9SAndroid Build Coastguard Worker                # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor]
10221*da0073e9SAndroid Build Coastguard Worker                return x + y, y
10222*da0073e9SAndroid Build Coastguard Worker
10223*da0073e9SAndroid Build Coastguard Worker            @torch.jit.ignore
10224*da0073e9SAndroid Build Coastguard Worker            def baz(self, x, y):
10225*da0073e9SAndroid Build Coastguard Worker                return x
10226*da0073e9SAndroid Build Coastguard Worker
10227*da0073e9SAndroid Build Coastguard Worker        class ModuleTooMany(BaseModule):
10228*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
10229*da0073e9SAndroid Build Coastguard Worker            def method(self, x):
10230*da0073e9SAndroid Build Coastguard Worker                return self.foo(x, x)
10231*da0073e9SAndroid Build Coastguard Worker
10232*da0073e9SAndroid Build Coastguard Worker        class ModuleTooFew(BaseModule):
10233*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
10234*da0073e9SAndroid Build Coastguard Worker            def method(self, x):
10235*da0073e9SAndroid Build Coastguard Worker                return self.bar(x)
10236*da0073e9SAndroid Build Coastguard Worker
10237*da0073e9SAndroid Build Coastguard Worker        class ModuleTooManyAssign(BaseModule):
10238*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
10239*da0073e9SAndroid Build Coastguard Worker            def method(self, x):
10240*da0073e9SAndroid Build Coastguard Worker                y, z, w = self.bar(x, x)
10241*da0073e9SAndroid Build Coastguard Worker                return x
10242*da0073e9SAndroid Build Coastguard Worker
10243*da0073e9SAndroid Build Coastguard Worker        class ModuleDefault(BaseModule):
10244*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
10245*da0073e9SAndroid Build Coastguard Worker            def method(self, x):
10246*da0073e9SAndroid Build Coastguard Worker                y = self.baz(x)
10247*da0073e9SAndroid Build Coastguard Worker                return x
10248*da0073e9SAndroid Build Coastguard Worker
10249*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Expected at most 2 arguments but found 3"):
10250*da0073e9SAndroid Build Coastguard Worker            ModuleTooMany()
10251*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Argument y not provided"):
10252*da0073e9SAndroid Build Coastguard Worker            ModuleTooFew()
10253*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "need 3 values .* found only 2"):
10254*da0073e9SAndroid Build Coastguard Worker            ModuleTooManyAssign()
10255*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Argument y not provided."):
10256*da0073e9SAndroid Build Coastguard Worker            ModuleDefault()
10257*da0073e9SAndroid Build Coastguard Worker
10258*da0073e9SAndroid Build Coastguard Worker    def test_type_inferred_from_empty_annotation(self):
10259*da0073e9SAndroid Build Coastguard Worker        """
10260*da0073e9SAndroid Build Coastguard Worker        Test that the type inferred from an empty or missing annotation is Torch.Tensor wtih `inferred=true`
10261*da0073e9SAndroid Build Coastguard Worker        """
10262*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
10263*da0073e9SAndroid Build Coastguard Worker        def fn(x):
10264*da0073e9SAndroid Build Coastguard Worker            return x
10265*da0073e9SAndroid Build Coastguard Worker
10266*da0073e9SAndroid Build Coastguard Worker        graph = fn.graph
10267*da0073e9SAndroid Build Coastguard Worker        n = next(graph.inputs())
10268*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(n.type() == torch._C.TensorType.getInferred())
10269*da0073e9SAndroid Build Coastguard Worker
10270*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Inferred 'x' to be of type 'Tensor"):
10271*da0073e9SAndroid Build Coastguard Worker            fn("1")
10272*da0073e9SAndroid Build Coastguard Worker
10273*da0073e9SAndroid Build Coastguard Worker    def test_script_define_order(self):
10274*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
10275*da0073e9SAndroid Build Coastguard Worker
10276*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
10277*da0073e9SAndroid Build Coastguard Worker            def call_foo(self, input):
10278*da0073e9SAndroid Build Coastguard Worker                return self.foo(input)
10279*da0073e9SAndroid Build Coastguard Worker
10280*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
10281*da0073e9SAndroid Build Coastguard Worker            def foo(self, input):
10282*da0073e9SAndroid Build Coastguard Worker                return input + 1
10283*da0073e9SAndroid Build Coastguard Worker        m = M()
10284*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64)))
10285*da0073e9SAndroid Build Coastguard Worker
10286*da0073e9SAndroid Build Coastguard Worker    def test_script_define_order_recursive_fail(self):
10287*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
10288*da0073e9SAndroid Build Coastguard Worker
10289*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
10290*da0073e9SAndroid Build Coastguard Worker            def call_foo(self, input):
10291*da0073e9SAndroid Build Coastguard Worker                return self.foo(input)
10292*da0073e9SAndroid Build Coastguard Worker
10293*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
10294*da0073e9SAndroid Build Coastguard Worker            def foo(self, input):
10295*da0073e9SAndroid Build Coastguard Worker                self.call_foo(input)
10296*da0073e9SAndroid Build Coastguard Worker
10297*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'called recursively'):
10298*da0073e9SAndroid Build Coastguard Worker            M()
10299*da0073e9SAndroid Build Coastguard Worker
10300*da0073e9SAndroid Build Coastguard Worker    def test_script_kwargs_fn_call(self):
10301*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
10302*da0073e9SAndroid Build Coastguard Worker
10303*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
10304*da0073e9SAndroid Build Coastguard Worker            def call_foo(self, input):
10305*da0073e9SAndroid Build Coastguard Worker                return self.foo(input=input, bar=1)
10306*da0073e9SAndroid Build Coastguard Worker
10307*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
10308*da0073e9SAndroid Build Coastguard Worker            def foo(self, bar, input):
10309*da0073e9SAndroid Build Coastguard Worker                # type: (int, Tensor) -> Tensor
10310*da0073e9SAndroid Build Coastguard Worker                return input + bar
10311*da0073e9SAndroid Build Coastguard Worker        m = M()
10312*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64)))
10313*da0073e9SAndroid Build Coastguard Worker
10314*da0073e9SAndroid Build Coastguard Worker    def test_if_define(self):
10315*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
10316*da0073e9SAndroid Build Coastguard Worker        def foo(a):
10317*da0073e9SAndroid Build Coastguard Worker            if bool(a == 0):
10318*da0073e9SAndroid Build Coastguard Worker                b = 1
10319*da0073e9SAndroid Build Coastguard Worker            else:
10320*da0073e9SAndroid Build Coastguard Worker                b = 0
10321*da0073e9SAndroid Build Coastguard Worker            return b + 1
10322*da0073e9SAndroid Build Coastguard Worker
10323*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
10324*da0073e9SAndroid Build Coastguard Worker        def foo2(a):
10325*da0073e9SAndroid Build Coastguard Worker            b = 0
10326*da0073e9SAndroid Build Coastguard Worker            if bool(a == 0):
10327*da0073e9SAndroid Build Coastguard Worker                b = 1
10328*da0073e9SAndroid Build Coastguard Worker            return b + 1
10329*da0073e9SAndroid Build Coastguard Worker
10330*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
10331*da0073e9SAndroid Build Coastguard Worker        def foo3(a):
10332*da0073e9SAndroid Build Coastguard Worker            b = 1
10333*da0073e9SAndroid Build Coastguard Worker            if bool(a == 0):
10334*da0073e9SAndroid Build Coastguard Worker                c = 4
10335*da0073e9SAndroid Build Coastguard Worker            else:
10336*da0073e9SAndroid Build Coastguard Worker                b = 0
10337*da0073e9SAndroid Build Coastguard Worker            return b + 1
10338*da0073e9SAndroid Build Coastguard Worker
10339*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(1, dtype=torch.long)
10340*da0073e9SAndroid Build Coastguard Worker        b = torch.zeros(1, dtype=torch.long)
10341*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(1, foo(a))
10342*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(2, foo(b))
10343*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(1, foo2(a))
10344*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(2, foo2(b))
10345*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(1, foo3(a))
10346*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(2, foo3(b))
10347*da0073e9SAndroid Build Coastguard Worker
10348*da0073e9SAndroid Build Coastguard Worker    def test_script_module_export_submodule(self):
10349*da0073e9SAndroid Build Coastguard Worker        class M1(torch.jit.ScriptModule):
10350*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
10351*da0073e9SAndroid Build Coastguard Worker                super().__init__()
10352*da0073e9SAndroid Build Coastguard Worker                self.weight = nn.Parameter(torch.randn(2))
10353*da0073e9SAndroid Build Coastguard Worker
10354*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
10355*da0073e9SAndroid Build Coastguard Worker            def forward(self, thing):
10356*da0073e9SAndroid Build Coastguard Worker                return self.weight + thing
10357*da0073e9SAndroid Build Coastguard Worker
10358*da0073e9SAndroid Build Coastguard Worker        class M2(torch.jit.ScriptModule):
10359*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
10360*da0073e9SAndroid Build Coastguard Worker                super().__init__()
10361*da0073e9SAndroid Build Coastguard Worker                # test submodule
10362*da0073e9SAndroid Build Coastguard Worker                self.sub = M1()
10363*da0073e9SAndroid Build Coastguard Worker                self.weight = nn.Parameter(torch.randn(2, 3))
10364*da0073e9SAndroid Build Coastguard Worker                self.bias = nn.Parameter(torch.randn(2))
10365*da0073e9SAndroid Build Coastguard Worker                self.define("""
10366*da0073e9SAndroid Build Coastguard Worker                    def hi(self, a):
10367*da0073e9SAndroid Build Coastguard Worker                        return self.weight.mm(a)
10368*da0073e9SAndroid Build Coastguard Worker                """)
10369*da0073e9SAndroid Build Coastguard Worker
10370*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
10371*da0073e9SAndroid Build Coastguard Worker            def doit(self, input):
10372*da0073e9SAndroid Build Coastguard Worker                return self.weight.mm(input)
10373*da0073e9SAndroid Build Coastguard Worker
10374*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
10375*da0073e9SAndroid Build Coastguard Worker            def doit2(self, input):
10376*da0073e9SAndroid Build Coastguard Worker                return self.weight.mm(input)
10377*da0073e9SAndroid Build Coastguard Worker
10378*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
10379*da0073e9SAndroid Build Coastguard Worker            def doit3(self, input):
10380*da0073e9SAndroid Build Coastguard Worker                return input + torch.ones([1], dtype=torch.double)
10381*da0073e9SAndroid Build Coastguard Worker
10382*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
10383*da0073e9SAndroid Build Coastguard Worker            def forward(self, input):
10384*da0073e9SAndroid Build Coastguard Worker                a = self.doit(input)
10385*da0073e9SAndroid Build Coastguard Worker                b = self.doit2(input)
10386*da0073e9SAndroid Build Coastguard Worker                c = self.hi(input)
10387*da0073e9SAndroid Build Coastguard Worker                return a + b + self.bias + c
10388*da0073e9SAndroid Build Coastguard Worker
10389*da0073e9SAndroid Build Coastguard Worker        with torch.jit.optimized_execution(False):
10390*da0073e9SAndroid Build Coastguard Worker            m_orig = M2()
10391*da0073e9SAndroid Build Coastguard Worker            m_import = self.getExportImportCopy(m_orig)
10392*da0073e9SAndroid Build Coastguard Worker
10393*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(3, 2)
10394*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m_orig.doit(input), m_import.doit(input))
10395*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m_orig.hi(input), m_import.hi(input))
10396*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m_orig.doit3(input), m_import.doit3(input))
10397*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m_orig.forward(input), m_import.forward(input))
10398*da0073e9SAndroid Build Coastguard Worker
10399*da0073e9SAndroid Build Coastguard Worker    @slowTest
10400*da0073e9SAndroid Build Coastguard Worker    def test_compile_module_with_constant(self):
10401*da0073e9SAndroid Build Coastguard Worker        class Double(nn.Module):
10402*da0073e9SAndroid Build Coastguard Worker            def __init__(self, downsample=None):
10403*da0073e9SAndroid Build Coastguard Worker                super().__init__()
10404*da0073e9SAndroid Build Coastguard Worker
10405*da0073e9SAndroid Build Coastguard Worker            def forward(self, input):
10406*da0073e9SAndroid Build Coastguard Worker                return input * 2
10407*da0073e9SAndroid Build Coastguard Worker
10408*da0073e9SAndroid Build Coastguard Worker        class Mod(nn.Module):
10409*da0073e9SAndroid Build Coastguard Worker            __constants__ = ['downsample']
10410*da0073e9SAndroid Build Coastguard Worker
10411*da0073e9SAndroid Build Coastguard Worker            def __init__(self, downsample=None):
10412*da0073e9SAndroid Build Coastguard Worker                super().__init__()
10413*da0073e9SAndroid Build Coastguard Worker                self.downsample = downsample
10414*da0073e9SAndroid Build Coastguard Worker
10415*da0073e9SAndroid Build Coastguard Worker            def forward(self, input):
10416*da0073e9SAndroid Build Coastguard Worker                if self.downsample is not None:
10417*da0073e9SAndroid Build Coastguard Worker                    return self.downsample(input)
10418*da0073e9SAndroid Build Coastguard Worker                return input
10419*da0073e9SAndroid Build Coastguard Worker
10420*da0073e9SAndroid Build Coastguard Worker        none_mod = torch.jit.script(Mod(None))
10421*da0073e9SAndroid Build Coastguard Worker        double_mod = torch.jit.script(Mod(Double()))
10422*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(none_mod(torch.tensor(1)), torch.tensor(1))
10423*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(double_mod(torch.tensor(1)), torch.tensor(1) * 2)
10424*da0073e9SAndroid Build Coastguard Worker
10425*da0073e9SAndroid Build Coastguard Worker    def test_device_kwarg(self):
10426*da0073e9SAndroid Build Coastguard Worker        from torch import device
10427*da0073e9SAndroid Build Coastguard Worker
10428*da0073e9SAndroid Build Coastguard Worker        def f():
10429*da0073e9SAndroid Build Coastguard Worker            return device(type='cuda'), torch.device(type='cpu')
10430*da0073e9SAndroid Build Coastguard Worker        self.checkScript(f, ())
10431*da0073e9SAndroid Build Coastguard Worker
10432*da0073e9SAndroid Build Coastguard Worker    def test_script_module_export_tensor_type(self):
10433*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
10434*da0073e9SAndroid Build Coastguard Worker            def __init__(self, type):
10435*da0073e9SAndroid Build Coastguard Worker                super().__init__()
10436*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.zeros((5, 5), dtype=type).random_())
10437*da0073e9SAndroid Build Coastguard Worker
10438*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
10439*da0073e9SAndroid Build Coastguard Worker            def foo(self):
10440*da0073e9SAndroid Build Coastguard Worker                return self.param
10441*da0073e9SAndroid Build Coastguard Worker
10442*da0073e9SAndroid Build Coastguard Worker        with torch.jit.optimized_execution(False):
10443*da0073e9SAndroid Build Coastguard Worker            for type in [torch.float, torch.double]:
10444*da0073e9SAndroid Build Coastguard Worker                m_orig = M(type)
10445*da0073e9SAndroid Build Coastguard Worker                m_import = self.getExportImportCopy(m_orig)
10446*da0073e9SAndroid Build Coastguard Worker                # check to make sure the storage wasn't resized
10447*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(m_orig.param.storage().size() == 25)
10448*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(m_orig.foo(), m_import.foo())
10449*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype)
10450*da0073e9SAndroid Build Coastguard Worker
10451*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "testing cuda tensors require CUDA")
10452*da0073e9SAndroid Build Coastguard Worker    def test_script_module_export_tensor_cuda(self):
10453*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
10454*da0073e9SAndroid Build Coastguard Worker
10455*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
10456*da0073e9SAndroid Build Coastguard Worker                super().__init__()
10457*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.zeros((5, 5), device='cuda:0').random_())
10458*da0073e9SAndroid Build Coastguard Worker
10459*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
10460*da0073e9SAndroid Build Coastguard Worker            def foo(self):
10461*da0073e9SAndroid Build Coastguard Worker                return self.param
10462*da0073e9SAndroid Build Coastguard Worker
10463*da0073e9SAndroid Build Coastguard Worker        m_orig = M()
10464*da0073e9SAndroid Build Coastguard Worker        m_import = self.getExportImportCopy(m_orig)
10465*da0073e9SAndroid Build Coastguard Worker        # check to make sure the storage wasn't resized
10466*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m_orig.param.storage().size() == 25)
10467*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m_import.foo().device == torch.device('cuda:0'))
10468*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m_orig.foo(), m_import.foo())
10469*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype)
10470*da0073e9SAndroid Build Coastguard Worker
10471*da0073e9SAndroid Build Coastguard Worker    def test_script_module_export_blocks(self):
10472*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
10473*da0073e9SAndroid Build Coastguard Worker            def __init__(self, n, m):
10474*da0073e9SAndroid Build Coastguard Worker                super().__init__()
10475*da0073e9SAndroid Build Coastguard Worker                self.weight = torch.nn.Parameter(torch.rand(n, m))
10476*da0073e9SAndroid Build Coastguard Worker
10477*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
10478*da0073e9SAndroid Build Coastguard Worker            def forward(self, input):
10479*da0073e9SAndroid Build Coastguard Worker                if bool(input.sum() > 0):
10480*da0073e9SAndroid Build Coastguard Worker                    output = self.weight.mv(input)
10481*da0073e9SAndroid Build Coastguard Worker                else:
10482*da0073e9SAndroid Build Coastguard Worker                    output = self.weight + input
10483*da0073e9SAndroid Build Coastguard Worker                return output
10484*da0073e9SAndroid Build Coastguard Worker
10485*da0073e9SAndroid Build Coastguard Worker        m_orig = M(200, 200)
10486*da0073e9SAndroid Build Coastguard Worker        m_import = self.getExportImportCopy(m_orig)
10487*da0073e9SAndroid Build Coastguard Worker
10488*da0073e9SAndroid Build Coastguard Worker        t = torch.rand(200)
10489*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m_orig(t), m_import(t))
10490*da0073e9SAndroid Build Coastguard Worker
10491*da0073e9SAndroid Build Coastguard Worker    def test_script_module_export_shared_storage(self):
10492*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
10493*da0073e9SAndroid Build Coastguard Worker
10494*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
10495*da0073e9SAndroid Build Coastguard Worker                super().__init__()
10496*da0073e9SAndroid Build Coastguard Worker                self.param1 = torch.nn.Parameter(torch.rand(5, 5))
10497*da0073e9SAndroid Build Coastguard Worker                self.param2 = torch.nn.Parameter(self.param1[3])
10498*da0073e9SAndroid Build Coastguard Worker                self.param3 = torch.nn.Parameter(torch.rand(5, 5))
10499*da0073e9SAndroid Build Coastguard Worker                self.param4 = torch.nn.Parameter(torch.rand(11, 5)[1:6])
10500*da0073e9SAndroid Build Coastguard Worker
10501*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
10502*da0073e9SAndroid Build Coastguard Worker            def foo(self):
10503*da0073e9SAndroid Build Coastguard Worker                return self.param1 + self.param2 + self.param3 + self.param4
10504*da0073e9SAndroid Build Coastguard Worker
10505*da0073e9SAndroid Build Coastguard Worker        with torch.jit.optimized_execution(False):
10506*da0073e9SAndroid Build Coastguard Worker            m_orig = M()
10507*da0073e9SAndroid Build Coastguard Worker            m_import = self.getExportImportCopy(m_orig)
10508*da0073e9SAndroid Build Coastguard Worker
10509*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m_orig.foo(), m_import.foo())
10510*da0073e9SAndroid Build Coastguard Worker
10511*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(m_import.param1.storage().data_ptr() == m_import.param2.storage().data_ptr())
10512*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(m_import.param1.storage().data_ptr() != m_import.param3.storage().data_ptr())
10513*da0073e9SAndroid Build Coastguard Worker
10514*da0073e9SAndroid Build Coastguard Worker    def test_sequential_intermediary_types(self):
10515*da0073e9SAndroid Build Coastguard Worker        class A(torch.nn.Module):
10516*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
10517*da0073e9SAndroid Build Coastguard Worker                return x + 3
10518*da0073e9SAndroid Build Coastguard Worker
10519*da0073e9SAndroid Build Coastguard Worker        class B(torch.nn.Module):
10520*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
10521*da0073e9SAndroid Build Coastguard Worker                return {"1": x}
10522*da0073e9SAndroid Build Coastguard Worker
10523*da0073e9SAndroid Build Coastguard Worker        class C(torch.nn.Module):
10524*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
10525*da0073e9SAndroid Build Coastguard Worker                super().__init__()
10526*da0073e9SAndroid Build Coastguard Worker                self.foo = torch.nn.Sequential(A(), B())
10527*da0073e9SAndroid Build Coastguard Worker
10528*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
10529*da0073e9SAndroid Build Coastguard Worker                return self.foo(x)
10530*da0073e9SAndroid Build Coastguard Worker
10531*da0073e9SAndroid Build Coastguard Worker        self.checkModule(C(), (torch.tensor(1),))
10532*da0073e9SAndroid Build Coastguard Worker
10533*da0073e9SAndroid Build Coastguard Worker    def test_ellipsis_const_mid(self):
10534*da0073e9SAndroid Build Coastguard Worker        def ellipsize(x):
10535*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor) -> List[int]
10536*da0073e9SAndroid Build Coastguard Worker            return x[2, Ellipsis, 0:4, 4:8].size()
10537*da0073e9SAndroid Build Coastguard Worker
10538*da0073e9SAndroid Build Coastguard Worker        dummy = torch.zeros(8, 8, 8, 8, 8)
10539*da0073e9SAndroid Build Coastguard Worker        self.checkScript(ellipsize, (dummy,), optimize=True)
10540*da0073e9SAndroid Build Coastguard Worker
10541*da0073e9SAndroid Build Coastguard Worker    def test_ellipsis_const_mid_select(self):
10542*da0073e9SAndroid Build Coastguard Worker        def ellipsize(x):
10543*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor) -> List[int]
10544*da0073e9SAndroid Build Coastguard Worker            return x[2, Ellipsis, 4, 4, 4:8, 2].size()
10545*da0073e9SAndroid Build Coastguard Worker
10546*da0073e9SAndroid Build Coastguard Worker        dummy = torch.zeros(8, 8, 8, 8, 8, 8, 8)
10547*da0073e9SAndroid Build Coastguard Worker        self.checkScript(ellipsize, (dummy,), optimize=True)
10548*da0073e9SAndroid Build Coastguard Worker
10549*da0073e9SAndroid Build Coastguard Worker    def test_ellipsis_const_start(self):
10550*da0073e9SAndroid Build Coastguard Worker        def ellipsize(x):
10551*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor) -> List[int]
10552*da0073e9SAndroid Build Coastguard Worker            return x[Ellipsis, 0:4, 4:8].size()
10553*da0073e9SAndroid Build Coastguard Worker        dummy = torch.zeros(8, 8, 8, 8, 8)
10554*da0073e9SAndroid Build Coastguard Worker        self.checkScript(ellipsize, (dummy,), optimize=True)
10555*da0073e9SAndroid Build Coastguard Worker
10556*da0073e9SAndroid Build Coastguard Worker    def test_ellipsis_const_end(self):
10557*da0073e9SAndroid Build Coastguard Worker        def ellipsize(x):
10558*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor) -> List[int]
10559*da0073e9SAndroid Build Coastguard Worker            return x[0:4, 2, Ellipsis].size()
10560*da0073e9SAndroid Build Coastguard Worker        dummy = torch.zeros(8, 8, 8, 8, 8)
10561*da0073e9SAndroid Build Coastguard Worker        self.checkScript(ellipsize, (dummy,), optimize=True)
10562*da0073e9SAndroid Build Coastguard Worker
10563*da0073e9SAndroid Build Coastguard Worker    def test_ellipsis_mid(self):
10564*da0073e9SAndroid Build Coastguard Worker        def ellipsize(x):
10565*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor) -> List[int]
10566*da0073e9SAndroid Build Coastguard Worker            return x[2, ..., 0:4, 4:8].size()
10567*da0073e9SAndroid Build Coastguard Worker
10568*da0073e9SAndroid Build Coastguard Worker        dummy = torch.zeros(8, 8, 8, 8, 8)
10569*da0073e9SAndroid Build Coastguard Worker        self.checkScript(ellipsize, (dummy,), optimize=True)
10570*da0073e9SAndroid Build Coastguard Worker
10571*da0073e9SAndroid Build Coastguard Worker    def test_ellipsis_mid_select(self):
10572*da0073e9SAndroid Build Coastguard Worker        def ellipsize(x):
10573*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor) -> List[int]
10574*da0073e9SAndroid Build Coastguard Worker            return x[2, ..., 4, 4, 4:8, 2].size()
10575*da0073e9SAndroid Build Coastguard Worker
10576*da0073e9SAndroid Build Coastguard Worker        dummy = torch.zeros(8, 8, 8, 8, 8, 8, 8)
10577*da0073e9SAndroid Build Coastguard Worker        self.checkScript(ellipsize, (dummy,), optimize=True)
10578*da0073e9SAndroid Build Coastguard Worker
10579*da0073e9SAndroid Build Coastguard Worker    def test_ellipsis_start(self):
10580*da0073e9SAndroid Build Coastguard Worker        def ellipsize(x):
10581*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor) -> List[int]
10582*da0073e9SAndroid Build Coastguard Worker            return x[..., 0:4, 4:8].size()
10583*da0073e9SAndroid Build Coastguard Worker        dummy = torch.zeros(8, 8, 8, 8, 8)
10584*da0073e9SAndroid Build Coastguard Worker        self.checkScript(ellipsize, (dummy,), optimize=True)
10585*da0073e9SAndroid Build Coastguard Worker
10586*da0073e9SAndroid Build Coastguard Worker    def test_ellipsis_end(self):
10587*da0073e9SAndroid Build Coastguard Worker        def ellipsize(x):
10588*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor) -> List[int]
10589*da0073e9SAndroid Build Coastguard Worker            return x[0:4, 2, ...].size()
10590*da0073e9SAndroid Build Coastguard Worker        dummy = torch.zeros(8, 8, 8, 8, 8)
10591*da0073e9SAndroid Build Coastguard Worker        self.checkScript(ellipsize, (dummy,), optimize=True)
10592*da0073e9SAndroid Build Coastguard Worker
10593*da0073e9SAndroid Build Coastguard Worker    def test_torch_manual_seed(self):
10594*da0073e9SAndroid Build Coastguard Worker        with freeze_rng_state():
10595*da0073e9SAndroid Build Coastguard Worker            def test():
10596*da0073e9SAndroid Build Coastguard Worker                torch.manual_seed(2)
10597*da0073e9SAndroid Build Coastguard Worker                return torch.rand(1)
10598*da0073e9SAndroid Build Coastguard Worker
10599*da0073e9SAndroid Build Coastguard Worker            script = torch.jit.script(test)
10600*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(test(), script())
10601*da0073e9SAndroid Build Coastguard Worker            graph = script.graph_for()
10602*da0073e9SAndroid Build Coastguard Worker            FileCheck().check("aten::manual_seed").run(graph)
10603*da0073e9SAndroid Build Coastguard Worker
10604*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
10605*da0073e9SAndroid Build Coastguard Worker    def test_index_select_shape_prop(self):
10606*da0073e9SAndroid Build Coastguard Worker
10607*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
10608*da0073e9SAndroid Build Coastguard Worker        def foo(x, y):
10609*da0073e9SAndroid Build Coastguard Worker            return torch.index_select(x, index=y, dim=1)
10610*da0073e9SAndroid Build Coastguard Worker
10611*da0073e9SAndroid Build Coastguard Worker        a = torch.zeros(2, 2)
10612*da0073e9SAndroid Build Coastguard Worker        b = torch.zeros(4, dtype=torch.long)
10613*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_complete_shape_analysis(foo.graph, (a, b), False)
10614*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("Float(2, 4, strides=[4, 1], requires_grad=0, device=cpu)").run(str(foo.graph))
10615*da0073e9SAndroid Build Coastguard Worker
10616*da0073e9SAndroid Build Coastguard Worker    def test_shape_analysis_loop(self):
10617*da0073e9SAndroid Build Coastguard Worker        def foo(a, b, x):
10618*da0073e9SAndroid Build Coastguard Worker            c = a
10619*da0073e9SAndroid Build Coastguard Worker            # on the first iteration of the loop it appears that
10620*da0073e9SAndroid Build Coastguard Worker            # c should have a expand to the size of b
10621*da0073e9SAndroid Build Coastguard Worker            # but on the second+ iterations, there is no broadcast and the
10622*da0073e9SAndroid Build Coastguard Worker            # sizes are different.
10623*da0073e9SAndroid Build Coastguard Worker            # previously this would cause the compiler to (1) enter an infinite
10624*da0073e9SAndroid Build Coastguard Worker            # loop trying to compute the shape, and (2) insert invalid
10625*da0073e9SAndroid Build Coastguard Worker            # broadcasts.
10626*da0073e9SAndroid Build Coastguard Worker            # this test ensure we don't regress on these issues
10627*da0073e9SAndroid Build Coastguard Worker            for _ in range(2):
10628*da0073e9SAndroid Build Coastguard Worker                a = c + b
10629*da0073e9SAndroid Build Coastguard Worker                c = x
10630*da0073e9SAndroid Build Coastguard Worker                b = x
10631*da0073e9SAndroid Build Coastguard Worker            return a
10632*da0073e9SAndroid Build Coastguard Worker
10633*da0073e9SAndroid Build Coastguard Worker        self.checkScript(foo, (torch.zeros(1), torch.zeros(4), torch.zeros(5)), optimize=False)
10634*da0073e9SAndroid Build Coastguard Worker
10635*da0073e9SAndroid Build Coastguard Worker    def test_intlist_args(self):
10636*da0073e9SAndroid Build Coastguard Worker        def func_1(x):
10637*da0073e9SAndroid Build Coastguard Worker            return torch.nn.functional.adaptive_avg_pool1d(x, 1)
10638*da0073e9SAndroid Build Coastguard Worker
10639*da0073e9SAndroid Build Coastguard Worker        def func_2(x):
10640*da0073e9SAndroid Build Coastguard Worker            return torch.nn.functional.adaptive_avg_pool1d(x, output_size=1)
10641*da0073e9SAndroid Build Coastguard Worker
10642*da0073e9SAndroid Build Coastguard Worker        def func_3(x):
10643*da0073e9SAndroid Build Coastguard Worker            return torch.nn.functional.adaptive_avg_pool1d(x, output_size=[1])
10644*da0073e9SAndroid Build Coastguard Worker
10645*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(8, 8, 8)
10646*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func_1, [x], optimize=True)
10647*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func_2, [x], optimize=True)
10648*da0073e9SAndroid Build Coastguard Worker        self.checkScript(func_3, [x], optimize=True)
10649*da0073e9SAndroid Build Coastguard Worker
10650*da0073e9SAndroid Build Coastguard Worker    def test_wrong_implicit_expand(self):
10651*da0073e9SAndroid Build Coastguard Worker
10652*da0073e9SAndroid Build Coastguard Worker        @_trace(torch.zeros(3), torch.zeros(1))
10653*da0073e9SAndroid Build Coastguard Worker        def foo(a, b):
10654*da0073e9SAndroid Build Coastguard Worker            return a + b
10655*da0073e9SAndroid Build Coastguard Worker
10656*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(4)
10657*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(4)
10658*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a + b, foo(a, b))
10659*da0073e9SAndroid Build Coastguard Worker
10660*da0073e9SAndroid Build Coastguard Worker    def test_builtin_args_fails(self):
10661*da0073e9SAndroid Build Coastguard Worker
10662*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'Argument self not provided'):
10663*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
10664*da0073e9SAndroid Build Coastguard Worker            def f1(a):
10665*da0073e9SAndroid Build Coastguard Worker                torch.sum(foo=4)
10666*da0073e9SAndroid Build Coastguard Worker
10667*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'specified twice'):
10668*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
10669*da0073e9SAndroid Build Coastguard Worker            def f2(a):
10670*da0073e9SAndroid Build Coastguard Worker                torch.sum(a, self=a)
10671*da0073e9SAndroid Build Coastguard Worker
10672*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'not provided'):
10673*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
10674*da0073e9SAndroid Build Coastguard Worker            def f3(a):
10675*da0073e9SAndroid Build Coastguard Worker                torch.sum(dim=4)
10676*da0073e9SAndroid Build Coastguard Worker
10677*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'for argument \'tensors\' but instead found type \'Tensor'):
10678*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
10679*da0073e9SAndroid Build Coastguard Worker            def f4(a):
10680*da0073e9SAndroid Build Coastguard Worker                torch.cat(a)
10681*da0073e9SAndroid Build Coastguard Worker
10682*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r'argument \'tensors\' but instead found type \'List\[int\]'):
10683*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
10684*da0073e9SAndroid Build Coastguard Worker            def f5(a):
10685*da0073e9SAndroid Build Coastguard Worker                torch.cat([3])
10686*da0073e9SAndroid Build Coastguard Worker
10687*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r'Expected a value of'
10688*da0073e9SAndroid Build Coastguard Worker                                    r' type \'List\[int\]\' for argument'
10689*da0073e9SAndroid Build Coastguard Worker                                    r' \'size\' but instead found type '
10690*da0073e9SAndroid Build Coastguard Worker                                    r'\'List\[Union\[List\[int\], int\]\]'):
10691*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
10692*da0073e9SAndroid Build Coastguard Worker            def f6(a):
10693*da0073e9SAndroid Build Coastguard Worker                a.expand(size=[3, [4]])
10694*da0073e9SAndroid Build Coastguard Worker
10695*da0073e9SAndroid Build Coastguard Worker    def test_builtin_args(self):
10696*da0073e9SAndroid Build Coastguard Worker
10697*da0073e9SAndroid Build Coastguard Worker        def t0(a):
10698*da0073e9SAndroid Build Coastguard Worker            # default arg dim
10699*da0073e9SAndroid Build Coastguard Worker            return torch.cat([a, a])
10700*da0073e9SAndroid Build Coastguard Worker
10701*da0073e9SAndroid Build Coastguard Worker        self.checkScript(t0, (torch.zeros(1, 1),))
10702*da0073e9SAndroid Build Coastguard Worker
10703*da0073e9SAndroid Build Coastguard Worker        def t1(a):
10704*da0073e9SAndroid Build Coastguard Worker            # keywords out of order
10705*da0073e9SAndroid Build Coastguard Worker            return torch.cat(dim=1, tensors=[a, a])
10706*da0073e9SAndroid Build Coastguard Worker
10707*da0073e9SAndroid Build Coastguard Worker        self.checkScript(t1, (torch.zeros(1, 1, 2),))
10708*da0073e9SAndroid Build Coastguard Worker
10709*da0073e9SAndroid Build Coastguard Worker        def t2(a):
10710*da0073e9SAndroid Build Coastguard Worker            # mix const/non-const attributes
10711*da0073e9SAndroid Build Coastguard Worker            if 1 == 1:
10712*da0073e9SAndroid Build Coastguard Worker                b = 1
10713*da0073e9SAndroid Build Coastguard Worker            else:
10714*da0073e9SAndroid Build Coastguard Worker                b = 0
10715*da0073e9SAndroid Build Coastguard Worker            return torch.sum(a, dim=b, keepdim=False)
10716*da0073e9SAndroid Build Coastguard Worker
10717*da0073e9SAndroid Build Coastguard Worker        self.checkScript(t2, (torch.zeros(1, 1, 2),))
10718*da0073e9SAndroid Build Coastguard Worker
10719*da0073e9SAndroid Build Coastguard Worker    def test_parser_type_annotations(self):
10720*da0073e9SAndroid Build Coastguard Worker        cu = torch.jit.CompilationUnit('''
10721*da0073e9SAndroid Build Coastguard Worker            def foo(x : Tensor, y : Tuple[Tuple[Tensor, Tensor], Tensor]) -> Tuple[Tensor, Tensor]:
10722*da0073e9SAndroid Build Coastguard Worker                return x, x
10723*da0073e9SAndroid Build Coastguard Worker        ''')
10724*da0073e9SAndroid Build Coastguard Worker
10725*da0073e9SAndroid Build Coastguard Worker        self.assertExpected(str(cu.foo.schema))
10726*da0073e9SAndroid Build Coastguard Worker
10727*da0073e9SAndroid Build Coastguard Worker    def test_parser_type_annotations_comment(self):
10728*da0073e9SAndroid Build Coastguard Worker        cu = torch.jit.CompilationUnit('''
10729*da0073e9SAndroid Build Coastguard Worker            def foo(x, y):
10730*da0073e9SAndroid Build Coastguard Worker                # type: (Tensor, Tuple[Tuple[Tensor, Tensor], Tensor]) -> Tuple[Tensor, Tensor]
10731*da0073e9SAndroid Build Coastguard Worker                return x, x
10732*da0073e9SAndroid Build Coastguard Worker        ''')
10733*da0073e9SAndroid Build Coastguard Worker
10734*da0073e9SAndroid Build Coastguard Worker        self.assertExpected(str(cu.foo.schema))
10735*da0073e9SAndroid Build Coastguard Worker
10736*da0073e9SAndroid Build Coastguard Worker    def test_parser_type_annotations_unknown_type(self):
10737*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Unknown type name 'Foo'"):
10738*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit('''
10739*da0073e9SAndroid Build Coastguard Worker                def foo(x : Tensor, y : Tuple[Tuple[Foo, Tensor], Tensor]) -> Tuple[Tensor, Tensor]:
10740*da0073e9SAndroid Build Coastguard Worker                    return x, x
10741*da0073e9SAndroid Build Coastguard Worker            ''')
10742*da0073e9SAndroid Build Coastguard Worker
10743*da0073e9SAndroid Build Coastguard Worker    def test_parser_type_annotations_subscript_non_ident(self):
10744*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r'Subscripted type must be a type identifier'):
10745*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit('''
10746*da0073e9SAndroid Build Coastguard Worker                def foo(x : Tensor, y : Tuple[Tensor, Tensor][Tensor]) -> Tuple[Tensor, Tensor]:
10747*da0073e9SAndroid Build Coastguard Worker                    return x, x
10748*da0073e9SAndroid Build Coastguard Worker            ''')
10749*da0073e9SAndroid Build Coastguard Worker
10750*da0073e9SAndroid Build Coastguard Worker    def test_parser_type_annotations_subscript_tensor(self):
10751*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r'Unknown type constructor Tensor'):
10752*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit('''
10753*da0073e9SAndroid Build Coastguard Worker                def foo(x : Tensor, y : Tensor[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
10754*da0073e9SAndroid Build Coastguard Worker                    return x, x
10755*da0073e9SAndroid Build Coastguard Worker            ''')
10756*da0073e9SAndroid Build Coastguard Worker
10757*da0073e9SAndroid Build Coastguard Worker    def test_parser_type_annotations_incompatible_expression(self):
10758*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r'Expression of type \+ cannot be used in a type expression'):
10759*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit('''
10760*da0073e9SAndroid Build Coastguard Worker                def foo(x : Tensor, y : Tuple[3 + 4, Tensor]) -> Tuple[Tensor, Tensor]:
10761*da0073e9SAndroid Build Coastguard Worker                    return x, x
10762*da0073e9SAndroid Build Coastguard Worker            ''')
10763*da0073e9SAndroid Build Coastguard Worker
10764*da0073e9SAndroid Build Coastguard Worker    def test_gather_dynamic_index(self):
10765*da0073e9SAndroid Build Coastguard Worker        def t(x):
10766*da0073e9SAndroid Build Coastguard Worker            gather1 = x[0]
10767*da0073e9SAndroid Build Coastguard Worker            idx = 0 + 1
10768*da0073e9SAndroid Build Coastguard Worker            gather2 = x[idx]
10769*da0073e9SAndroid Build Coastguard Worker            return gather1 + gather2
10770*da0073e9SAndroid Build Coastguard Worker
10771*da0073e9SAndroid Build Coastguard Worker        self.checkScript(t, (torch.zeros(3, 2, 3),))
10772*da0073e9SAndroid Build Coastguard Worker
10773*da0073e9SAndroid Build Coastguard Worker    def test_torch_ignore_conversion_to_none(self):
10774*da0073e9SAndroid Build Coastguard Worker        class A(torch.nn.Module):
10775*da0073e9SAndroid Build Coastguard Worker            @torch.jit.ignore
10776*da0073e9SAndroid Build Coastguard Worker            def ignored(self, a: int) -> None:
10777*da0073e9SAndroid Build Coastguard Worker                l: int = len([2 for i in range(a) if i > 2])
10778*da0073e9SAndroid Build Coastguard Worker                return
10779*da0073e9SAndroid Build Coastguard Worker
10780*da0073e9SAndroid Build Coastguard Worker            def forward(self) -> int:
10781*da0073e9SAndroid Build Coastguard Worker                a: int = 4
10782*da0073e9SAndroid Build Coastguard Worker                b: int = 5
10783*da0073e9SAndroid Build Coastguard Worker                self.ignored(a)
10784*da0073e9SAndroid Build Coastguard Worker                return a + b
10785*da0073e9SAndroid Build Coastguard Worker
10786*da0073e9SAndroid Build Coastguard Worker        class B(torch.nn.Module):
10787*da0073e9SAndroid Build Coastguard Worker            @torch.jit.ignore
10788*da0073e9SAndroid Build Coastguard Worker            def ignored(self, a: int):
10789*da0073e9SAndroid Build Coastguard Worker                l: int = len([2 for i in range(a) if i > 2])
10790*da0073e9SAndroid Build Coastguard Worker                return
10791*da0073e9SAndroid Build Coastguard Worker
10792*da0073e9SAndroid Build Coastguard Worker            def forward(self) -> int:
10793*da0073e9SAndroid Build Coastguard Worker                a: int = 4
10794*da0073e9SAndroid Build Coastguard Worker                b: int = 5
10795*da0073e9SAndroid Build Coastguard Worker                self.ignored(a)
10796*da0073e9SAndroid Build Coastguard Worker                return a + b
10797*da0073e9SAndroid Build Coastguard Worker
10798*da0073e9SAndroid Build Coastguard Worker        modelA = torch.jit.script(A())
10799*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(modelA(), 9)
10800*da0073e9SAndroid Build Coastguard Worker
10801*da0073e9SAndroid Build Coastguard Worker        modelB = torch.jit.script(B())
10802*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(modelB(), 9)
10803*da0073e9SAndroid Build Coastguard Worker
10804*da0073e9SAndroid Build Coastguard Worker    def test_addmm_grad(self):
10805*da0073e9SAndroid Build Coastguard Worker        """ This test checks several things:
10806*da0073e9SAndroid Build Coastguard Worker            1. An expand node was inserted before the addmm operating on the
10807*da0073e9SAndroid Build Coastguard Worker               bias term.
10808*da0073e9SAndroid Build Coastguard Worker            2. The fused form of addmm appears in the ultimate graph that's
10809*da0073e9SAndroid Build Coastguard Worker               executed.
10810*da0073e9SAndroid Build Coastguard Worker            3. A sum op was emitted for accumulating gradients along the 0th
10811*da0073e9SAndroid Build Coastguard Worker               (expanded) dimension of the bias term.
10812*da0073e9SAndroid Build Coastguard Worker            4. The correct symbolic representation for the backward pass of the
10813*da0073e9SAndroid Build Coastguard Worker               mm operator was emitted (x.t() -> mm)
10814*da0073e9SAndroid Build Coastguard Worker
10815*da0073e9SAndroid Build Coastguard Worker            TODO: we should actually check these conditions once we have a way
10816*da0073e9SAndroid Build Coastguard Worker            to dump the GraphExecutor state. Namely the processed forward graph
10817*da0073e9SAndroid Build Coastguard Worker            and the backward graph.
10818*da0073e9SAndroid Build Coastguard Worker        """
10819*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
10820*da0073e9SAndroid Build Coastguard Worker        def addmm_grad_test(b, x, w):
10821*da0073e9SAndroid Build Coastguard Worker            return torch.addmm(b, x, w)
10822*da0073e9SAndroid Build Coastguard Worker
10823*da0073e9SAndroid Build Coastguard Worker        # Initialize param and input values
10824*da0073e9SAndroid Build Coastguard Worker        w_init = torch.rand(2, 5)
10825*da0073e9SAndroid Build Coastguard Worker        b_init = torch.rand(5)
10826*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 2)
10827*da0073e9SAndroid Build Coastguard Worker
10828*da0073e9SAndroid Build Coastguard Worker        # Clone trainable params
10829*da0073e9SAndroid Build Coastguard Worker        b = b_init.clone()
10830*da0073e9SAndroid Build Coastguard Worker        b.requires_grad_()
10831*da0073e9SAndroid Build Coastguard Worker        w = w_init.clone()
10832*da0073e9SAndroid Build Coastguard Worker        w.requires_grad_()
10833*da0073e9SAndroid Build Coastguard Worker
10834*da0073e9SAndroid Build Coastguard Worker        # Test symbolic differentiation
10835*da0073e9SAndroid Build Coastguard Worker        y = addmm_grad_test(b, x, w)
10836*da0073e9SAndroid Build Coastguard Worker        y.sum().backward()
10837*da0073e9SAndroid Build Coastguard Worker
10838*da0073e9SAndroid Build Coastguard Worker        # clone params for autograd reference
10839*da0073e9SAndroid Build Coastguard Worker        b_ref = b_init.clone()
10840*da0073e9SAndroid Build Coastguard Worker        b_ref.requires_grad_()
10841*da0073e9SAndroid Build Coastguard Worker        w_ref = w_init.clone()
10842*da0073e9SAndroid Build Coastguard Worker        w_ref.requires_grad_()
10843*da0073e9SAndroid Build Coastguard Worker        y_ref = torch.addmm(b_ref, x, w_ref)
10844*da0073e9SAndroid Build Coastguard Worker        y_ref.sum().backward()
10845*da0073e9SAndroid Build Coastguard Worker
10846*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(w.grad, w_ref.grad)
10847*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b.grad, b_ref.grad)
10848*da0073e9SAndroid Build Coastguard Worker
10849*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "running tests on cuda to verify cudnn fix")
10850*da0073e9SAndroid Build Coastguard Worker    def test_batch_norm_inference_backward_cuda(self):
10851*da0073e9SAndroid Build Coastguard Worker        with enable_profiling_mode_for_profiling_tests():
10852*da0073e9SAndroid Build Coastguard Worker            class MyBatchNorm(torch.nn.Module):
10853*da0073e9SAndroid Build Coastguard Worker                def __init__(self, num_features, affine, track_running_stats):
10854*da0073e9SAndroid Build Coastguard Worker                    super().__init__()
10855*da0073e9SAndroid Build Coastguard Worker                    self.bn = torch.nn.BatchNorm2d(
10856*da0073e9SAndroid Build Coastguard Worker                        num_features, 1e-5, affine=affine, track_running_stats=track_running_stats).float()
10857*da0073e9SAndroid Build Coastguard Worker
10858*da0073e9SAndroid Build Coastguard Worker                def forward(self, x: torch.Tensor):
10859*da0073e9SAndroid Build Coastguard Worker                    o = self.bn(x)
10860*da0073e9SAndroid Build Coastguard Worker                    o = torch.nn.functional.relu(o)
10861*da0073e9SAndroid Build Coastguard Worker                    return o
10862*da0073e9SAndroid Build Coastguard Worker
10863*da0073e9SAndroid Build Coastguard Worker            batch = 4
10864*da0073e9SAndroid Build Coastguard Worker            c = 2
10865*da0073e9SAndroid Build Coastguard Worker            hw = 3
10866*da0073e9SAndroid Build Coastguard Worker            # Initialize param and input values
10867*da0073e9SAndroid Build Coastguard Worker            x_init = torch.randn(batch, c, hw, hw, dtype=torch.float).cuda()
10868*da0073e9SAndroid Build Coastguard Worker            grad = torch.randn(batch, c, hw, hw, dtype=torch.float).cuda()
10869*da0073e9SAndroid Build Coastguard Worker
10870*da0073e9SAndroid Build Coastguard Worker            training = False
10871*da0073e9SAndroid Build Coastguard Worker            affine = True
10872*da0073e9SAndroid Build Coastguard Worker            track_running_stats = True
10873*da0073e9SAndroid Build Coastguard Worker
10874*da0073e9SAndroid Build Coastguard Worker            module = torch.jit.script(MyBatchNorm(c, affine, track_running_stats)).cuda()
10875*da0073e9SAndroid Build Coastguard Worker            ref_module = MyBatchNorm(c, affine, track_running_stats).cuda()
10876*da0073e9SAndroid Build Coastguard Worker            module.eval()
10877*da0073e9SAndroid Build Coastguard Worker            ref_module.eval()
10878*da0073e9SAndroid Build Coastguard Worker
10879*da0073e9SAndroid Build Coastguard Worker            jit_module = torch.jit.script(module)
10880*da0073e9SAndroid Build Coastguard Worker            ref_module.load_state_dict(module.state_dict())
10881*da0073e9SAndroid Build Coastguard Worker
10882*da0073e9SAndroid Build Coastguard Worker            x = x_init.detach().clone()
10883*da0073e9SAndroid Build Coastguard Worker            x.requires_grad_()
10884*da0073e9SAndroid Build Coastguard Worker            x_ref = x_init.detach().clone()
10885*da0073e9SAndroid Build Coastguard Worker            x_ref.requires_grad_()
10886*da0073e9SAndroid Build Coastguard Worker
10887*da0073e9SAndroid Build Coastguard Worker            # Test symbolic differentiation
10888*da0073e9SAndroid Build Coastguard Worker            # Run Forward and Backward thrice to trigger autodiff graph
10889*da0073e9SAndroid Build Coastguard Worker            for i in range(0, 3):
10890*da0073e9SAndroid Build Coastguard Worker                y = jit_module(x)
10891*da0073e9SAndroid Build Coastguard Worker                y.backward(grad)
10892*da0073e9SAndroid Build Coastguard Worker            x.grad.zero_()
10893*da0073e9SAndroid Build Coastguard Worker
10894*da0073e9SAndroid Build Coastguard Worker            module.bn.running_mean.zero_()
10895*da0073e9SAndroid Build Coastguard Worker            module.bn.running_var.fill_(1.0)
10896*da0073e9SAndroid Build Coastguard Worker            ref_module.bn.running_mean.zero_()
10897*da0073e9SAndroid Build Coastguard Worker            ref_module.bn.running_var.fill_(1.0)
10898*da0073e9SAndroid Build Coastguard Worker
10899*da0073e9SAndroid Build Coastguard Worker            # run jitted module
10900*da0073e9SAndroid Build Coastguard Worker            y = jit_module(x)
10901*da0073e9SAndroid Build Coastguard Worker            y.backward(grad)
10902*da0073e9SAndroid Build Coastguard Worker            # reference computation
10903*da0073e9SAndroid Build Coastguard Worker            y_ref = ref_module(x_ref)
10904*da0073e9SAndroid Build Coastguard Worker            y_ref.backward(grad)
10905*da0073e9SAndroid Build Coastguard Worker
10906*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_ref, y)
10907*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, x_ref.grad)
10908*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(module.bn.running_mean, ref_module.bn.running_mean)
10909*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(module.bn.running_var, ref_module.bn.running_var)
10910*da0073e9SAndroid Build Coastguard Worker
10911*da0073e9SAndroid Build Coastguard Worker    def test_zeros(self):
10912*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
10913*da0073e9SAndroid Build Coastguard Worker            __constants__ = ['d']
10914*da0073e9SAndroid Build Coastguard Worker
10915*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
10916*da0073e9SAndroid Build Coastguard Worker                super().__init__()
10917*da0073e9SAndroid Build Coastguard Worker                self.d = torch.device('cpu')
10918*da0073e9SAndroid Build Coastguard Worker
10919*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
10920*da0073e9SAndroid Build Coastguard Worker            def create(self):
10921*da0073e9SAndroid Build Coastguard Worker                return torch.zeros([1, 1, 2], dtype=torch.float, device=self.d, layout=torch.strided)
10922*da0073e9SAndroid Build Coastguard Worker
10923*da0073e9SAndroid Build Coastguard Worker        r = M().create()
10924*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r.dtype, torch.float)
10925*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.zeros([1, 1, 2], dtype=torch.float), r)
10926*da0073e9SAndroid Build Coastguard Worker
10927*da0073e9SAndroid Build Coastguard Worker        def fn():
10928*da0073e9SAndroid Build Coastguard Worker            return torch.zeros((1, 2, 3))
10929*da0073e9SAndroid Build Coastguard Worker
10930*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ())
10931*da0073e9SAndroid Build Coastguard Worker
10932*da0073e9SAndroid Build Coastguard Worker    def test_vararg_zeros(self):
10933*da0073e9SAndroid Build Coastguard Worker        def foo():
10934*da0073e9SAndroid Build Coastguard Worker            return torch.zeros(3, 4, 5, dtype=torch.int)
10935*da0073e9SAndroid Build Coastguard Worker
10936*da0073e9SAndroid Build Coastguard Worker        self.checkScript(foo, ())
10937*da0073e9SAndroid Build Coastguard Worker
10938*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "the original version of test_rand")
10939*da0073e9SAndroid Build Coastguard Worker    def test_rand(self):
10940*da0073e9SAndroid Build Coastguard Worker        def test_rand():
10941*da0073e9SAndroid Build Coastguard Worker            a = torch.rand([3, 4])
10942*da0073e9SAndroid Build Coastguard Worker            return a + 1.0 - a
10943*da0073e9SAndroid Build Coastguard Worker
10944*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_rand, ())
10945*da0073e9SAndroid Build Coastguard Worker        fn = torch.jit.script(test_rand)
10946*da0073e9SAndroid Build Coastguard Worker        out = fn()
10947*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.dtype, torch.get_default_dtype())
10948*da0073e9SAndroid Build Coastguard Worker        g = fn.graph_for()
10949*da0073e9SAndroid Build Coastguard Worker        # Testing shape analysis correctly setting type
10950*da0073e9SAndroid Build Coastguard Worker        if GRAPH_EXECUTOR != ProfilingMode.SIMPLE:
10951*da0073e9SAndroid Build Coastguard Worker            FileCheck().check("Double(*, *, requires_grad=0, device=cpu)") \
10952*da0073e9SAndroid Build Coastguard Worker                       .check_not("Float(*, *, requires_grad=0, device=cpu)").run(g)
10953*da0073e9SAndroid Build Coastguard Worker
10954*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
10955*da0073e9SAndroid Build Coastguard Worker        def randint():
10956*da0073e9SAndroid Build Coastguard Worker            return torch.randint(0, 5, [1, 2])
10957*da0073e9SAndroid Build Coastguard Worker        out = randint()
10958*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.dtype, torch.int64)
10959*da0073e9SAndroid Build Coastguard Worker        if GRAPH_EXECUTOR != ProfilingMode.SIMPLE:
10960*da0073e9SAndroid Build Coastguard Worker            FileCheck().check("Long(*, *, requires_grad=0, device=cpu)") \
10961*da0073e9SAndroid Build Coastguard Worker                       .check_not("Float(*, *, requires_grad=0, device=cpu)") \
10962*da0073e9SAndroid Build Coastguard Worker                       .check_not("Double(*, *, requires_grad=0, device=cpu)") \
10963*da0073e9SAndroid Build Coastguard Worker                       .run(randint.graph_for())
10964*da0073e9SAndroid Build Coastguard Worker
10965*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "no CUDA")
10966*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "skip if profiling isn't enabled")
10967*da0073e9SAndroid Build Coastguard Worker    def test_autodiff_complex(self):
10968*da0073e9SAndroid Build Coastguard Worker        def foo(x: torch.Tensor, y: torch.Tensor, W: torch.Tensor):
10969*da0073e9SAndroid Build Coastguard Worker            return torch.exp(torch.mm(torch.complex(x, y), W.cfloat()))
10970*da0073e9SAndroid Build Coastguard Worker
10971*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
10972*da0073e9SAndroid Build Coastguard Worker        def jitted_foo(x: torch.Tensor, y: torch.Tensor, W: torch.Tensor):
10973*da0073e9SAndroid Build Coastguard Worker            return torch.exp(torch.mm(torch.complex(x, y), W.cfloat()))
10974*da0073e9SAndroid Build Coastguard Worker
10975*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(128, 16, dtype=torch.float32, device='cuda:0')
10976*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(128, 16, dtype=torch.float32, device='cuda:0')
10977*da0073e9SAndroid Build Coastguard Worker        W = torch.randn(16, 1, dtype=torch.float32, device='cuda:0', requires_grad=True)
10978*da0073e9SAndroid Build Coastguard Worker        W.data /= 4
10979*da0073e9SAndroid Build Coastguard Worker
10980*da0073e9SAndroid Build Coastguard Worker        with enable_profiling_mode_for_profiling_tests():
10981*da0073e9SAndroid Build Coastguard Worker            for i in range(4):
10982*da0073e9SAndroid Build Coastguard Worker                self.assertTrue((foo(x, y, W).grad_fn is None) == (jitted_foo(x, y, W).grad_fn is None))
10983*da0073e9SAndroid Build Coastguard Worker
10984*da0073e9SAndroid Build Coastguard Worker
    def test_linear_grad(self):
        """Compare a scripted F.linear with eager mode, with and without a bias.

        Outputs and gradients w.r.t. x and w (and b, when a bias is passed)
        of the scripted function must match the eager reference.
        """
        with enable_profiling_mode_for_profiling_tests():
            def t(x: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor]):
                return torch.nn.functional.linear(x, w, b)

            x_init = torch.randn(4, 2)
            w_init = torch.randn(3, 2)
            b_init = torch.randn(3)
            grad = torch.randn(4, 3)

            with disable_autodiff_subgraph_inlining():
                # scripted function (t is a plain function, not a module)
                jit_t = torch.jit.script(t)

                # Separate leaf tensors for the scripted and the eager path, so
                # their gradients accumulate independently.
                x = x_init.detach().requires_grad_()
                w = w_init.detach().requires_grad_()
                b = b_init.detach().requires_grad_()
                x_ref = x_init.detach().requires_grad_()
                w_ref = w_init.detach().requires_grad_()
                b_ref = b_init.detach().requires_grad_()

                # profiling/optimization runs
                jit_o = jit_t(x, w, b)
                jit_o.backward(grad)
                jit_o = jit_t(x, w, b)
                jit_o.backward(grad)

                # Drop the gradients accumulated by the warm-up runs before the
                # compared run.
                x.grad.zero_()
                w.grad.zero_()
                b.grad.zero_()
                jit_o = jit_t(x, w, b)
                jit_o.backward(grad)
                o = t(x_ref, w_ref, b_ref)
                o.backward(grad)

                self.assertEqual(jit_o, o)
                self.assertEqual(x.grad, x_ref.grad)
                self.assertEqual(w.grad, w_ref.grad)
                self.assertEqual(b.grad, b_ref.grad)

                # Second comparison with bias=None (the Optional argument left
                # empty). b / b_ref are unused here, so their grads are neither
                # reset nor compared.
                x.grad.zero_()
                w.grad.zero_()
                x_ref.grad.zero_()
                w_ref.grad.zero_()
                jit_o = jit_t(x, w, None)
                jit_o.backward(grad)
                o = t(x_ref, w_ref, None)
                o.backward(grad)

                self.assertEqual(jit_o, o)
                self.assertEqual(x.grad, x_ref.grad)
                self.assertEqual(w.grad, w_ref.grad)
11037*da0073e9SAndroid Build Coastguard Worker
11038*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("TorchDynamo doesn't support profile")
11039*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "the profiling version of test_rand")
11040*da0073e9SAndroid Build Coastguard Worker    def test_rand_profiling(self):
11041*da0073e9SAndroid Build Coastguard Worker        def test_rand():
11042*da0073e9SAndroid Build Coastguard Worker            a = torch.rand([3, 4])
11043*da0073e9SAndroid Build Coastguard Worker            return a + 1.0 - a
11044*da0073e9SAndroid Build Coastguard Worker
11045*da0073e9SAndroid Build Coastguard Worker        # Testing shape analysis correctly setting type
11046*da0073e9SAndroid Build Coastguard Worker        with enable_profiling_mode_for_profiling_tests():
11047*da0073e9SAndroid Build Coastguard Worker            with num_profiled_runs(1):
11048*da0073e9SAndroid Build Coastguard Worker                fn = torch.jit.script(test_rand)
11049*da0073e9SAndroid Build Coastguard Worker                out = fn()
11050*da0073e9SAndroid Build Coastguard Worker                graph_str = torch.jit.last_executed_optimized_graph()
11051*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out.dtype, torch.float)
11052*da0073e9SAndroid Build Coastguard Worker                FileCheck().check("Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)") \
11053*da0073e9SAndroid Build Coastguard Worker                           .check_not("Double(3, 4, strides=[4, 1], requires_grad=0, device=cpu)").run(graph_str)
11054*da0073e9SAndroid Build Coastguard Worker
11055*da0073e9SAndroid Build Coastguard Worker            # fn = self.checkScript(test_rand, ())
11056*da0073e9SAndroid Build Coastguard Worker            # out = fn()
11057*da0073e9SAndroid Build Coastguard Worker            # self.assertEqual(out.dtype, torch.float)
11058*da0073e9SAndroid Build Coastguard Worker
11059*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
11060*da0073e9SAndroid Build Coastguard Worker        def randint():
11061*da0073e9SAndroid Build Coastguard Worker            return torch.randint(0, 5, [1, 2])
11062*da0073e9SAndroid Build Coastguard Worker
11063*da0073e9SAndroid Build Coastguard Worker        with enable_profiling_mode_for_profiling_tests():
11064*da0073e9SAndroid Build Coastguard Worker            with num_profiled_runs(1):
11065*da0073e9SAndroid Build Coastguard Worker                out = randint()
11066*da0073e9SAndroid Build Coastguard Worker                graph_str = torch.jit.last_executed_optimized_graph()
11067*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out.dtype, torch.int64)
11068*da0073e9SAndroid Build Coastguard Worker                FileCheck().check("profiled_type=Long(1, 2, strides=[2, 1], requires_grad=0, device=cpu)").run(graph_str)
11069*da0073e9SAndroid Build Coastguard Worker
11070*da0073e9SAndroid Build Coastguard Worker
11071*da0073e9SAndroid Build Coastguard Worker    def test_erase_number_types(self):
11072*da0073e9SAndroid Build Coastguard Worker        def func(a):
11073*da0073e9SAndroid Build Coastguard Worker            b = 7 + 1 + 3
11074*da0073e9SAndroid Build Coastguard Worker            c = a + b
11075*da0073e9SAndroid Build Coastguard Worker            c += b
11076*da0073e9SAndroid Build Coastguard Worker            return c
11077*da0073e9SAndroid Build Coastguard Worker
11078*da0073e9SAndroid Build Coastguard Worker        graph = torch.jit.script(func).graph
11079*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("int = prim::Constant").check("aten::add_").run(str(graph))
11080*da0073e9SAndroid Build Coastguard Worker        self.run_pass("erase_number_types", graph)
11081*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("int = prim::Constant").run(str(graph))
11082*da0073e9SAndroid Build Coastguard Worker
11083*da0073e9SAndroid Build Coastguard Worker    def test_refine_tuple_types(self):
11084*da0073e9SAndroid Build Coastguard Worker        # TupleConstruct output type is not correct here.
11085*da0073e9SAndroid Build Coastguard Worker        graph_str = """
11086*da0073e9SAndroid Build Coastguard Worker        graph(%a : Float(123), %b : Float(4, 5, 6)):
11087*da0073e9SAndroid Build Coastguard Worker          %c : (Tensor, Tensor) = prim::TupleConstruct(%a, %b)
11088*da0073e9SAndroid Build Coastguard Worker          return (%c)
11089*da0073e9SAndroid Build Coastguard Worker        """
11090*da0073e9SAndroid Build Coastguard Worker        graph = parse_ir(graph_str)
11091*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_refine_tuple_types(graph)
11092*da0073e9SAndroid Build Coastguard Worker
11093*da0073e9SAndroid Build Coastguard Worker        # After the pass, the output type should've been updated.
11094*da0073e9SAndroid Build Coastguard Worker        self.assertTrue('(Float(123), Float(4, 5, 6))' in str(graph.findNode('prim::TupleConstruct').output()))
11095*da0073e9SAndroid Build Coastguard Worker
11096*da0073e9SAndroid Build Coastguard Worker    # TODO(henrytu): Add test for RefineTypes for NamedTuple when it's supported by IR parser.
11097*da0073e9SAndroid Build Coastguard Worker
11098*da0073e9SAndroid Build Coastguard Worker    def test_remove_dropout(self):
11099*da0073e9SAndroid Build Coastguard Worker        weight_0_shape = (20, 5)
11100*da0073e9SAndroid Build Coastguard Worker        weight_1_shape = (20, 20)
11101*da0073e9SAndroid Build Coastguard Worker        input_shape = (10, 5)
11102*da0073e9SAndroid Build Coastguard Worker
11103*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
11104*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
11105*da0073e9SAndroid Build Coastguard Worker                super().__init__()
11106*da0073e9SAndroid Build Coastguard Worker                self.weight_0 = torch.nn.Parameter(torch.rand(weight_0_shape))
11107*da0073e9SAndroid Build Coastguard Worker                self.weight_1 = torch.nn.Parameter(torch.rand(weight_1_shape))
11108*da0073e9SAndroid Build Coastguard Worker
11109*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
11110*da0073e9SAndroid Build Coastguard Worker                o = F.linear(x, self.weight_0)
11111*da0073e9SAndroid Build Coastguard Worker                o = F.dropout(o, training=self.training)
11112*da0073e9SAndroid Build Coastguard Worker                o = F.linear(o, self.weight_1)
11113*da0073e9SAndroid Build Coastguard Worker                return o
11114*da0073e9SAndroid Build Coastguard Worker
11115*da0073e9SAndroid Build Coastguard Worker        data = torch.rand(input_shape)
11116*da0073e9SAndroid Build Coastguard Worker        m = M()
11117*da0073e9SAndroid Build Coastguard Worker        m = torch.jit.script(m)
11118*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r'Dropout removal module in training mode is not yet supported'):
11119*da0073e9SAndroid Build Coastguard Worker            torch._C._jit_pass_remove_dropout(m._c)
11120*da0073e9SAndroid Build Coastguard Worker        m.eval()
11121*da0073e9SAndroid Build Coastguard Worker        ref_res = m(data)
11122*da0073e9SAndroid Build Coastguard Worker        # Need to inline otherwise we see instances of Function.
11123*da0073e9SAndroid Build Coastguard Worker        # We would have to use torch.linear/dropout to get around it otherwise.
11124*da0073e9SAndroid Build Coastguard Worker        from torch.jit._recursive import wrap_cpp_module
11125*da0073e9SAndroid Build Coastguard Worker        m = wrap_cpp_module(torch._C._freeze_module(m._c))
11126*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_remove_dropout(m._c)
11127*da0073e9SAndroid Build Coastguard Worker        res = m(data)
11128*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("aten::dropout").run(str(m.graph))
11129*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(ref_res, res, rtol=1e-2, atol=1e-3)
11130*da0073e9SAndroid Build Coastguard Worker
11131*da0073e9SAndroid Build Coastguard Worker    def test_unfold_zero_dim(self):
11132*da0073e9SAndroid Build Coastguard Worker        def fn(x):
11133*da0073e9SAndroid Build Coastguard Worker            return x.unfold(0, 1, 1)
11134*da0073e9SAndroid Build Coastguard Worker
11135*da0073e9SAndroid Build Coastguard Worker        graph = torch.jit.script(fn).graph
11136*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_complete_shape_analysis(graph, (torch.tensor(0.39),), False)
11137*da0073e9SAndroid Build Coastguard Worker        out_dims = fn(torch.tensor(0.3923)).ndim
11138*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(graph.findNode("aten::unfold").output().type().dim(), out_dims)
11139*da0073e9SAndroid Build Coastguard Worker
11140*da0073e9SAndroid Build Coastguard Worker    def test_mm_batching(self):
11141*da0073e9SAndroid Build Coastguard Worker
11142*da0073e9SAndroid Build Coastguard Worker        with enable_profiling_mode_for_profiling_tests():
11143*da0073e9SAndroid Build Coastguard Worker            lstm_cell = torch.jit.script(LSTMCellS)
11144*da0073e9SAndroid Build Coastguard Worker
11145*da0073e9SAndroid Build Coastguard Worker            def lstm(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
11146*da0073e9SAndroid Build Coastguard Worker                for i in range(x.size(0)):
11147*da0073e9SAndroid Build Coastguard Worker                    hx, cx = lstm_cell(x[i], hx, cx, w_ih, w_hh, b_ih, b_hh)
11148*da0073e9SAndroid Build Coastguard Worker                return hx
11149*da0073e9SAndroid Build Coastguard Worker
11150*da0073e9SAndroid Build Coastguard Worker            slstm = torch.jit.script(lstm)
11151*da0073e9SAndroid Build Coastguard Worker
11152*da0073e9SAndroid Build Coastguard Worker            inputs = get_lstm_inputs('cpu', training=True, seq_length=10)
11153*da0073e9SAndroid Build Coastguard Worker            slstm(*inputs, profile_and_replay=True).sum().backward(retain_graph=True)
11154*da0073e9SAndroid Build Coastguard Worker            if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
11155*da0073e9SAndroid Build Coastguard Worker                slstm(*inputs, profile_and_replay=True).sum().backward()
11156*da0073e9SAndroid Build Coastguard Worker
11157*da0073e9SAndroid Build Coastguard Worker            fw_graph = slstm.graph_for(*inputs)
11158*da0073e9SAndroid Build Coastguard Worker            if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
11159*da0073e9SAndroid Build Coastguard Worker                bw_graph = backward_graph(slstm, diff_graph_idx=0)
11160*da0073e9SAndroid Build Coastguard Worker                self.assertTrue('prim::MMBatchSide' in str(fw_graph))
11161*da0073e9SAndroid Build Coastguard Worker                self.assertTrue('prim::MMTreeReduce' in str(bw_graph))
11162*da0073e9SAndroid Build Coastguard Worker
11163*da0073e9SAndroid Build Coastguard Worker            sout = slstm(*inputs)
11164*da0073e9SAndroid Build Coastguard Worker            out = lstm(*inputs)
11165*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(sout, out)
11166*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.autograd.grad(sout.sum(), inputs),
11167*da0073e9SAndroid Build Coastguard Worker                             torch.autograd.grad(out.sum(), inputs))
11168*da0073e9SAndroid Build Coastguard Worker
11169*da0073e9SAndroid Build Coastguard Worker    def test_loop_unrolling(self):
11170*da0073e9SAndroid Build Coastguard Worker        def fn(x):
11171*da0073e9SAndroid Build Coastguard Worker            y = 0
11172*da0073e9SAndroid Build Coastguard Worker            for i in range(int(x)):
11173*da0073e9SAndroid Build Coastguard Worker                y -= i
11174*da0073e9SAndroid Build Coastguard Worker            return y
11175*da0073e9SAndroid Build Coastguard Worker
11176*da0073e9SAndroid Build Coastguard Worker        graph = torch.jit.script(fn).graph
11177*da0073e9SAndroid Build Coastguard Worker        self.run_pass('loop_unrolling', graph)
11178*da0073e9SAndroid Build Coastguard Worker        unroll_factor = 8
11179*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("prim::Loop").check_count("aten::sub", unroll_factor) \
11180*da0073e9SAndroid Build Coastguard Worker            .check("prim::Loop").check("aten::sub").run(str(graph))
11181*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (torch.tensor(10),))
11182*da0073e9SAndroid Build Coastguard Worker
11183*da0073e9SAndroid Build Coastguard Worker    def test_loop_unrolling_const(self):
11184*da0073e9SAndroid Build Coastguard Worker        def fn():
11185*da0073e9SAndroid Build Coastguard Worker            y = 0
11186*da0073e9SAndroid Build Coastguard Worker            for _ in range(10):
11187*da0073e9SAndroid Build Coastguard Worker                y -= 1
11188*da0073e9SAndroid Build Coastguard Worker            return y
11189*da0073e9SAndroid Build Coastguard Worker
11190*da0073e9SAndroid Build Coastguard Worker        def fn2():
11191*da0073e9SAndroid Build Coastguard Worker            y = 0
11192*da0073e9SAndroid Build Coastguard Worker            for i in range(10):
11193*da0073e9SAndroid Build Coastguard Worker                y -= i
11194*da0073e9SAndroid Build Coastguard Worker            return y
11195*da0073e9SAndroid Build Coastguard Worker
11196*da0073e9SAndroid Build Coastguard Worker        def check(fn, name):
11197*da0073e9SAndroid Build Coastguard Worker            graph = torch.jit.script(fn).graph
11198*da0073e9SAndroid Build Coastguard Worker            self.run_pass('loop_unrolling', graph)
11199*da0073e9SAndroid Build Coastguard Worker            # entirely unrolled
11200*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_not("prim::Loop'").run(str(graph))
11201*da0073e9SAndroid Build Coastguard Worker            self.checkScript(fn, ())
11202*da0073e9SAndroid Build Coastguard Worker
11203*da0073e9SAndroid Build Coastguard Worker        check(fn, 'add_const')
11204*da0073e9SAndroid Build Coastguard Worker        check(fn2, 'add_iter')
11205*da0073e9SAndroid Build Coastguard Worker
11206*da0073e9SAndroid Build Coastguard Worker    def test_loop_unrolling_nested(self):
11207*da0073e9SAndroid Build Coastguard Worker        def fn(x):
11208*da0073e9SAndroid Build Coastguard Worker            y = 0
11209*da0073e9SAndroid Build Coastguard Worker            for _ in range(10):
11210*da0073e9SAndroid Build Coastguard Worker                for j in range(int(x)):
11211*da0073e9SAndroid Build Coastguard Worker                    y -= j
11212*da0073e9SAndroid Build Coastguard Worker            return y
11213*da0073e9SAndroid Build Coastguard Worker
11214*da0073e9SAndroid Build Coastguard Worker        graph = torch.jit.script(fn).graph
11215*da0073e9SAndroid Build Coastguard Worker        self.run_pass('loop_unrolling', graph)
11216*da0073e9SAndroid Build Coastguard Worker        # inner loop with 8 subs followed by loop epilogue
11217*da0073e9SAndroid Build Coastguard Worker        unroll_factor = 8
11218*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("prim::Loop").check("prim::Loop").check_count('aten::sub', unroll_factor) \
11219*da0073e9SAndroid Build Coastguard Worker            .check("prim::Loop").check("aten::sub").run(str(graph))
11220*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (torch.tensor(10),))
11221*da0073e9SAndroid Build Coastguard Worker
11222*da0073e9SAndroid Build Coastguard Worker    def test_loop_unroll_unused_counter(self):
11223*da0073e9SAndroid Build Coastguard Worker        def fn(x):
11224*da0073e9SAndroid Build Coastguard Worker            y = 0
11225*da0073e9SAndroid Build Coastguard Worker            for _ in range(int(x)):
11226*da0073e9SAndroid Build Coastguard Worker                y -= 1
11227*da0073e9SAndroid Build Coastguard Worker            return y
11228*da0073e9SAndroid Build Coastguard Worker
11229*da0073e9SAndroid Build Coastguard Worker        graph = torch.jit.script(fn).graph
11230*da0073e9SAndroid Build Coastguard Worker        self.run_pass('loop_unrolling', graph)
11231*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("prim::Loop").check_not("aten::add").check("return") \
11232*da0073e9SAndroid Build Coastguard Worker            .run(str(graph))
11233*da0073e9SAndroid Build Coastguard Worker
11234*da0073e9SAndroid Build Coastguard Worker    def test_loop_unroll_negative(self):
11235*da0073e9SAndroid Build Coastguard Worker        def fn(x):
11236*da0073e9SAndroid Build Coastguard Worker            y = 0
11237*da0073e9SAndroid Build Coastguard Worker            for _ in range(int(x)):
11238*da0073e9SAndroid Build Coastguard Worker                y += 1
11239*da0073e9SAndroid Build Coastguard Worker            return y
11240*da0073e9SAndroid Build Coastguard Worker
11241*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (torch.tensor(-20),))
11242*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (torch.tensor(-2),))
11243*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (torch.tensor(-1),))
11244*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (torch.tensor(0),))
11245*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (torch.tensor(1),))
11246*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (torch.tensor(2),))
11247*da0073e9SAndroid Build Coastguard Worker
11248*da0073e9SAndroid Build Coastguard Worker    def test_where(self):
11249*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
11250*da0073e9SAndroid Build Coastguard Worker            return torch.where(x > 0.0, x, y)
11251*da0073e9SAndroid Build Coastguard Worker
11252*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (torch.randn(3, 2, dtype=torch.float), torch.ones(3, 2, dtype=torch.float)))
11253*da0073e9SAndroid Build Coastguard Worker
11254*da0073e9SAndroid Build Coastguard Worker    def test_where_method(self):
11255*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
11256*da0073e9SAndroid Build Coastguard Worker            return x.where(x > 0.0, y)
11257*da0073e9SAndroid Build Coastguard Worker
11258*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (torch.randn(3, 2, dtype=torch.float), torch.ones(3, 2, dtype=torch.float)))
11259*da0073e9SAndroid Build Coastguard Worker
11260*da0073e9SAndroid Build Coastguard Worker    def test_union_to_number(self):
11261*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
11262*da0073e9SAndroid Build Coastguard Worker        def fn(x: Union[int, complex, float], y: Union[int, complex, float]):
11263*da0073e9SAndroid Build Coastguard Worker            return x + y
11264*da0073e9SAndroid Build Coastguard Worker        FileCheck().check(": Scalar):").run(fn.graph)
11265*da0073e9SAndroid Build Coastguard Worker
11266*da0073e9SAndroid Build Coastguard Worker    def test_reassign_module_lhs(self):
11267*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'Cannot re-assign \'self\''):
11268*da0073e9SAndroid Build Coastguard Worker            class ReassignSelfLHS(torch.jit.ScriptModule):
11269*da0073e9SAndroid Build Coastguard Worker                @torch.jit.script_method
11270*da0073e9SAndroid Build Coastguard Worker                def forward(self, x):
11271*da0073e9SAndroid Build Coastguard Worker                    for _ in range(20):
11272*da0073e9SAndroid Build Coastguard Worker                        self = x
11273*da0073e9SAndroid Build Coastguard Worker                    return self
11274*da0073e9SAndroid Build Coastguard Worker
11275*da0073e9SAndroid Build Coastguard Worker            ReassignSelfLHS()
11276*da0073e9SAndroid Build Coastguard Worker
11277*da0073e9SAndroid Build Coastguard Worker    def test_reassign_module_rhs(self):
11278*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'Cannot re-assign \'x\' to a value of type module'):
11279*da0073e9SAndroid Build Coastguard Worker            class ReassignSelfRHS(torch.jit.ScriptModule):
11280*da0073e9SAndroid Build Coastguard Worker                @torch.jit.script_method
11281*da0073e9SAndroid Build Coastguard Worker                def forward(self, x):
11282*da0073e9SAndroid Build Coastguard Worker                    for _ in range(20):
11283*da0073e9SAndroid Build Coastguard Worker                        x = self
11284*da0073e9SAndroid Build Coastguard Worker                    return self
11285*da0073e9SAndroid Build Coastguard Worker
11286*da0073e9SAndroid Build Coastguard Worker            ReassignSelfRHS()
11287*da0073e9SAndroid Build Coastguard Worker
11288*da0073e9SAndroid Build Coastguard Worker    def test_unknown_builtin(self):
11289*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'object has no attribute or method'):
11290*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
11291*da0073e9SAndroid Build Coastguard Worker            def unknown_builtin(x):
11292*da0073e9SAndroid Build Coastguard Worker                return x.splork(3)
11293*da0073e9SAndroid Build Coastguard Worker
11294*da0073e9SAndroid Build Coastguard Worker    def test_return_tuple(self):
11295*da0073e9SAndroid Build Coastguard Worker        def return_tuple(x):
11296*da0073e9SAndroid Build Coastguard Worker            a = (x, x)
11297*da0073e9SAndroid Build Coastguard Worker            return a, x
11298*da0073e9SAndroid Build Coastguard Worker        self.checkScript(return_tuple, (torch.rand(4),))
11299*da0073e9SAndroid Build Coastguard Worker
11300*da0073e9SAndroid Build Coastguard Worker    def test_add_tuple_optional(self):
11301*da0073e9SAndroid Build Coastguard Worker        def foo(input: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]) -> Optional[torch.Tensor]:
11302*da0073e9SAndroid Build Coastguard Worker            changed_input = input[0] + 1
11303*da0073e9SAndroid Build Coastguard Worker            value: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]] = (changed_input,) + input[1:]
11304*da0073e9SAndroid Build Coastguard Worker            return value[2]
11305*da0073e9SAndroid Build Coastguard Worker        inp: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]] = (torch.rand(4), None, None)
11306*da0073e9SAndroid Build Coastguard Worker        self.checkScript(foo, (inp,))
11307*da0073e9SAndroid Build Coastguard Worker
11308*da0073e9SAndroid Build Coastguard Worker    def test_add_tuple_non_optional(self):
11309*da0073e9SAndroid Build Coastguard Worker        def foo(input: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor:
11310*da0073e9SAndroid Build Coastguard Worker            changed_input = input[0] + 1
11311*da0073e9SAndroid Build Coastguard Worker            value: Tuple[torch.Tensor, torch.Tensor, torch.Tensor] = (changed_input,) + input[1:]
11312*da0073e9SAndroid Build Coastguard Worker            return torch.sum(value[2]) + 4
11313*da0073e9SAndroid Build Coastguard Worker        inp: Tuple[torch.Tensor, torch.Tensor, torch.Tensor] = (torch.rand(4), torch.rand(4), torch.rand(4))
11314*da0073e9SAndroid Build Coastguard Worker        self.checkScript(foo, (inp,))
11315*da0073e9SAndroid Build Coastguard Worker
11316*da0073e9SAndroid Build Coastguard Worker    def test_add_tuple_different_types(self):
11317*da0073e9SAndroid Build Coastguard Worker        def foo(a: Tuple[int, float], b: Tuple[int]) -> int:
11318*da0073e9SAndroid Build Coastguard Worker            c: Tuple[int, float, int] = a + b
11319*da0073e9SAndroid Build Coastguard Worker            d: Tuple[int, float, int, int] = c + b
11320*da0073e9SAndroid Build Coastguard Worker            return d[3] + 1
11321*da0073e9SAndroid Build Coastguard Worker        a = (1, 2.0)
11322*da0073e9SAndroid Build Coastguard Worker        b = (3,)
11323*da0073e9SAndroid Build Coastguard Worker        self.checkScript(foo, (a, b))
11324*da0073e9SAndroid Build Coastguard Worker
11325*da0073e9SAndroid Build Coastguard Worker    def test_add_tuple_same_types(self):
11326*da0073e9SAndroid Build Coastguard Worker        def foo(a: Tuple[int, int], b: Tuple[int, int, int]) -> int:
11327*da0073e9SAndroid Build Coastguard Worker            c: Tuple[int, int, int, int, int] = a + b
11328*da0073e9SAndroid Build Coastguard Worker            d: Tuple[int, int, int, int, int, int, int, int] = c + b
11329*da0073e9SAndroid Build Coastguard Worker            return d[6] - 2
11330*da0073e9SAndroid Build Coastguard Worker        a = (1, 2)
11331*da0073e9SAndroid Build Coastguard Worker        b = (3, 4, 5)
11332*da0073e9SAndroid Build Coastguard Worker        self.checkScript(foo, (a, b))
11333*da0073e9SAndroid Build Coastguard Worker
11334*da0073e9SAndroid Build Coastguard Worker    def test_method_no_self(self):
11335*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'methods must have a self argument'):
11336*da0073e9SAndroid Build Coastguard Worker            class MethodNoSelf(torch.jit.ScriptModule):
11337*da0073e9SAndroid Build Coastguard Worker                @torch.jit.script_method  # noqa: B902
11338*da0073e9SAndroid Build Coastguard Worker                def forward():  # noqa: B902
11339*da0073e9SAndroid Build Coastguard Worker                    return torch.zeros(3, 4)
11340*da0073e9SAndroid Build Coastguard Worker
11341*da0073e9SAndroid Build Coastguard Worker            MethodNoSelf()
11342*da0073e9SAndroid Build Coastguard Worker
11343*da0073e9SAndroid Build Coastguard Worker    def test_return_stmt_not_at_end(self):
11344*da0073e9SAndroid Build Coastguard Worker        def return_stmt(x):
11345*da0073e9SAndroid Build Coastguard Worker            if bool(x > 3):
11346*da0073e9SAndroid Build Coastguard Worker                return x + 3
11347*da0073e9SAndroid Build Coastguard Worker            else:
11348*da0073e9SAndroid Build Coastguard Worker                return x
11349*da0073e9SAndroid Build Coastguard Worker        self.checkScript(return_stmt, (torch.rand(1),))
11350*da0073e9SAndroid Build Coastguard Worker
11351*da0073e9SAndroid Build Coastguard Worker    def test_for_in_range(self):
11352*da0073e9SAndroid Build Coastguard Worker        def fn():
11353*da0073e9SAndroid Build Coastguard Worker            c = 0
11354*da0073e9SAndroid Build Coastguard Worker            for i in range(100):
11355*da0073e9SAndroid Build Coastguard Worker                c += i
11356*da0073e9SAndroid Build Coastguard Worker            return c
11357*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ())
11358*da0073e9SAndroid Build Coastguard Worker
11359*da0073e9SAndroid Build Coastguard Worker    def test_for_in_range_dynamic(self):
11360*da0073e9SAndroid Build Coastguard Worker        def fn():
11361*da0073e9SAndroid Build Coastguard Worker            c = 0
11362*da0073e9SAndroid Build Coastguard Worker            for i in range(100):
11363*da0073e9SAndroid Build Coastguard Worker                acc = 0
11364*da0073e9SAndroid Build Coastguard Worker                for j in range(i):
11365*da0073e9SAndroid Build Coastguard Worker                    acc += j
11366*da0073e9SAndroid Build Coastguard Worker                c += acc
11367*da0073e9SAndroid Build Coastguard Worker            return c
11368*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (), optimize=False)
11369*da0073e9SAndroid Build Coastguard Worker
11370*da0073e9SAndroid Build Coastguard Worker    def test_for_in_range_ast(self):
11371*da0073e9SAndroid Build Coastguard Worker        def test_script_for_in_range_ast():
11372*da0073e9SAndroid Build Coastguard Worker            c = 0
11373*da0073e9SAndroid Build Coastguard Worker            for i in range(100):
11374*da0073e9SAndroid Build Coastguard Worker                acc = 0
11375*da0073e9SAndroid Build Coastguard Worker                for j in range(i):
11376*da0073e9SAndroid Build Coastguard Worker                    acc += j
11377*da0073e9SAndroid Build Coastguard Worker                c += acc
11378*da0073e9SAndroid Build Coastguard Worker            return c
11379*da0073e9SAndroid Build Coastguard Worker
11380*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_script_for_in_range_ast, ())
11381*da0073e9SAndroid Build Coastguard Worker
11382*da0073e9SAndroid Build Coastguard Worker    def test_for_in_range_if_ast(self):
11383*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
11384*da0073e9SAndroid Build Coastguard Worker        def test_script_for_in_range_if_ast(x):
11385*da0073e9SAndroid Build Coastguard Worker            output = x
11386*da0073e9SAndroid Build Coastguard Worker            for i in range(20):
11387*da0073e9SAndroid Build Coastguard Worker                if i == 0:
11388*da0073e9SAndroid Build Coastguard Worker                    output = x.unsqueeze(0)
11389*da0073e9SAndroid Build Coastguard Worker                else:
11390*da0073e9SAndroid Build Coastguard Worker                    output = torch.cat((output, x.unsqueeze(0)), dim=0)
11391*da0073e9SAndroid Build Coastguard Worker            return output
11392*da0073e9SAndroid Build Coastguard Worker        inputs = self._make_scalar_vars([0], torch.int64)
11393*da0073e9SAndroid Build Coastguard Worker
11394*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(test_script_for_in_range_if_ast(*inputs).shape[0], 20)
11395*da0073e9SAndroid Build Coastguard Worker
11396*da0073e9SAndroid Build Coastguard Worker    def test_for_in_range_start_end(self):
11397*da0073e9SAndroid Build Coastguard Worker        def fn():
11398*da0073e9SAndroid Build Coastguard Worker            x = 0
11399*da0073e9SAndroid Build Coastguard Worker            for i in range(7, 100):
11400*da0073e9SAndroid Build Coastguard Worker                x += i
11401*da0073e9SAndroid Build Coastguard Worker            return x
11402*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ())
11403*da0073e9SAndroid Build Coastguard Worker
11404*da0073e9SAndroid Build Coastguard Worker    def test_for_in_range_start_end_step(self):
11405*da0073e9SAndroid Build Coastguard Worker        def fn(start, end, step):
11406*da0073e9SAndroid Build Coastguard Worker            # type: (int, int, int) -> int
11407*da0073e9SAndroid Build Coastguard Worker            x = 0
11408*da0073e9SAndroid Build Coastguard Worker            for i in range(start, end, step):
11409*da0073e9SAndroid Build Coastguard Worker                x += i
11410*da0073e9SAndroid Build Coastguard Worker            return x
11411*da0073e9SAndroid Build Coastguard Worker
11412*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (7, 100, 7))
11413*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (7, 100, -7))
11414*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (2, -11, -3))
11415*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (2, -11, 3))
11416*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (2, 10, 3))
11417*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (-2, -10, -10))
11418*da0073e9SAndroid Build Coastguard Worker
11419*da0073e9SAndroid Build Coastguard Worker    def test_for_in_range_zero_step(self):
11420*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
11421*da0073e9SAndroid Build Coastguard Worker        def fn():
11422*da0073e9SAndroid Build Coastguard Worker            x = 0
11423*da0073e9SAndroid Build Coastguard Worker            for i in range(2, -11, 0):
11424*da0073e9SAndroid Build Coastguard Worker                x += i
11425*da0073e9SAndroid Build Coastguard Worker            return x
11426*da0073e9SAndroid Build Coastguard Worker
11427*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "must not be zero"):
11428*da0073e9SAndroid Build Coastguard Worker            fn()
11429*da0073e9SAndroid Build Coastguard Worker
11430*da0073e9SAndroid Build Coastguard Worker    def test_range_args(self):
11431*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r'range expected at least 1 arguments, got 0'):
11432*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
11433*da0073e9SAndroid Build Coastguard Worker            def range_no_arg(x):
11434*da0073e9SAndroid Build Coastguard Worker                for _ in range():
11435*da0073e9SAndroid Build Coastguard Worker                    x += 1
11436*da0073e9SAndroid Build Coastguard Worker                return x
11437*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r'found float'):
11438*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
11439*da0073e9SAndroid Build Coastguard Worker            def range_non_float():
11440*da0073e9SAndroid Build Coastguard Worker                for i in range(.5):
11441*da0073e9SAndroid Build Coastguard Worker                    print(i)
11442*da0073e9SAndroid Build Coastguard Worker
11443*da0073e9SAndroid Build Coastguard Worker    def test_parse_empty_tuple_annotation(self):
11444*da0073e9SAndroid Build Coastguard Worker        cu = torch.jit.CompilationUnit('''
11445*da0073e9SAndroid Build Coastguard Worker            def foo(x : Tuple[()]) -> Tuple[()]:
11446*da0073e9SAndroid Build Coastguard Worker                return x
11447*da0073e9SAndroid Build Coastguard Worker        ''')
11448*da0073e9SAndroid Build Coastguard Worker
11449*da0073e9SAndroid Build Coastguard Worker        foo_code = cu.find_function('foo').code
11450*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("Tuple[()]").check("Tuple[()]").run(foo_code)
11451*da0073e9SAndroid Build Coastguard Worker
11452*da0073e9SAndroid Build Coastguard Worker    def test_parse_empty_tuple_annotation_element_error(self):
11453*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
11454*da0073e9SAndroid Build Coastguard Worker                RuntimeError, 'Tuple literal in Tuple type annotation must not have any elements'):
11455*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit('''
11456*da0073e9SAndroid Build Coastguard Worker                def foo(x : Tuple[(int,)]) -> Tuple[(int,)]:
11457*da0073e9SAndroid Build Coastguard Worker                    return x
11458*da0073e9SAndroid Build Coastguard Worker            ''')
11459*da0073e9SAndroid Build Coastguard Worker
11460*da0073e9SAndroid Build Coastguard Worker    def test_parse_none_type_annotation(self):
11461*da0073e9SAndroid Build Coastguard Worker        cu = torch.jit.CompilationUnit('''
11462*da0073e9SAndroid Build Coastguard Worker            def foo(x : NoneType) -> NoneType:
11463*da0073e9SAndroid Build Coastguard Worker                return x
11464*da0073e9SAndroid Build Coastguard Worker        ''')
11465*da0073e9SAndroid Build Coastguard Worker
11466*da0073e9SAndroid Build Coastguard Worker        foo_code = cu.find_function('foo').code
11467*da0073e9SAndroid Build Coastguard Worker        FileCheck().check(": NoneType").check("-> NoneType").run(foo_code)
11468*da0073e9SAndroid Build Coastguard Worker
11469*da0073e9SAndroid Build Coastguard Worker    def test_empty_tuple_str(self):
11470*da0073e9SAndroid Build Coastguard Worker        empty_tuple_type = torch._C.TupleType([])
11471*da0073e9SAndroid Build Coastguard Worker        g = {'Tuple' : typing.Tuple}
11472*da0073e9SAndroid Build Coastguard Worker        python_type = eval(empty_tuple_type.annotation_str, g)
11473*da0073e9SAndroid Build Coastguard Worker        assert python_type is typing.Tuple[()]
11474*da0073e9SAndroid Build Coastguard Worker
11475*da0073e9SAndroid Build Coastguard Worker    def test_tuple_str(self):
11476*da0073e9SAndroid Build Coastguard Worker        tuple1_type = torch._C.TupleType([torch._C.StringType.get()])
11477*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tuple1_type.annotation_str, "Tuple[str]")
11478*da0073e9SAndroid Build Coastguard Worker        tuple2_type = torch._C.TupleType([torch._C.StringType.get(), torch._C.StringType.get()])
11479*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tuple2_type.annotation_str, "Tuple[str, str]")
11480*da0073e9SAndroid Build Coastguard Worker
11481*da0073e9SAndroid Build Coastguard Worker    def test_dict_str(self):
11482*da0073e9SAndroid Build Coastguard Worker        dict_type = torch._C.DictType(torch._C.StringType.get(), torch._C.StringType.get())
11483*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dict_type.annotation_str, "Dict[str, str]")
11484*da0073e9SAndroid Build Coastguard Worker
11485*da0073e9SAndroid Build Coastguard Worker    def test_none_type_str(self):
11486*da0073e9SAndroid Build Coastguard Worker        none_type = torch._C.NoneType.get()
11487*da0073e9SAndroid Build Coastguard Worker        g = {'NoneType' : type(None)}
11488*da0073e9SAndroid Build Coastguard Worker        python_type = eval(none_type.annotation_str, g)
11489*da0073e9SAndroid Build Coastguard Worker        assert python_type is type(None)
11490*da0073e9SAndroid Build Coastguard Worker
11491*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
11492*da0073e9SAndroid Build Coastguard Worker    def test_zip_enumerate_modulelist(self):
11493*da0073e9SAndroid Build Coastguard Worker        class Sub(torch.nn.Module):
11494*da0073e9SAndroid Build Coastguard Worker            def forward(self, thing):
11495*da0073e9SAndroid Build Coastguard Worker                return thing - 2
11496*da0073e9SAndroid Build Coastguard Worker
11497*da0073e9SAndroid Build Coastguard Worker        class Double(torch.nn.Module):
11498*da0073e9SAndroid Build Coastguard Worker            def forward(self, thing):
11499*da0073e9SAndroid Build Coastguard Worker                return thing * 2
11500*da0073e9SAndroid Build Coastguard Worker
11501*da0073e9SAndroid Build Coastguard Worker        # zipping over two
11502*da0073e9SAndroid Build Coastguard Worker        class ZipModLists(torch.nn.Module):
11503*da0073e9SAndroid Build Coastguard Worker            def __init__(self, mods, mods2):
11504*da0073e9SAndroid Build Coastguard Worker                super().__init__()
11505*da0073e9SAndroid Build Coastguard Worker                self.mods = mods
11506*da0073e9SAndroid Build Coastguard Worker                self.mods2 = mods2
11507*da0073e9SAndroid Build Coastguard Worker
11508*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
11509*da0073e9SAndroid Build Coastguard Worker                iter = 0
11510*da0073e9SAndroid Build Coastguard Worker                for mod1, mod2 in zip(self.mods, self.mods2):
11511*da0073e9SAndroid Build Coastguard Worker                    x = mod2(mod1(x))
11512*da0073e9SAndroid Build Coastguard Worker                    iter += 1
11513*da0073e9SAndroid Build Coastguard Worker                return x, iter
11514*da0073e9SAndroid Build Coastguard Worker
11515*da0073e9SAndroid Build Coastguard Worker        class ZipWithValues(torch.nn.Module):
11516*da0073e9SAndroid Build Coastguard Worker            __constants__ = ['tup_larger', 'tup_smaller']
11517*da0073e9SAndroid Build Coastguard Worker
11518*da0073e9SAndroid Build Coastguard Worker            def __init__(self, mods, mods2):
11519*da0073e9SAndroid Build Coastguard Worker                super().__init__()
11520*da0073e9SAndroid Build Coastguard Worker                self.mods = mods
11521*da0073e9SAndroid Build Coastguard Worker                self.mods2 = mods2
11522*da0073e9SAndroid Build Coastguard Worker                self.tup_larger = list(range(len(mods2) + 1))
11523*da0073e9SAndroid Build Coastguard Worker                self.tup_smaller = list(range(max(len(mods2) + 1, 1)))
11524*da0073e9SAndroid Build Coastguard Worker
11525*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
11526*da0073e9SAndroid Build Coastguard Worker                iter = 0
11527*da0073e9SAndroid Build Coastguard Worker                x2 = x
11528*da0073e9SAndroid Build Coastguard Worker                for val, mod1, mod2 in zip(self.tup_larger, self.mods, self.mods2):
11529*da0073e9SAndroid Build Coastguard Worker                    x = mod2(mod1(x)) + val
11530*da0073e9SAndroid Build Coastguard Worker                    iter += 1
11531*da0073e9SAndroid Build Coastguard Worker                for val, mod1, mod2 in zip(self.tup_smaller, self.mods, self.mods2):
11532*da0073e9SAndroid Build Coastguard Worker                    x2 = mod2(mod1(x2)) + val
11533*da0073e9SAndroid Build Coastguard Worker                    iter += 1
11534*da0073e9SAndroid Build Coastguard Worker                return x, iter
11535*da0073e9SAndroid Build Coastguard Worker
11536*da0073e9SAndroid Build Coastguard Worker        mods = nn.ModuleList([Double()]), nn.ModuleList([Double(), Sub(), Sub()]), nn.ModuleList([Sub(), Double()])
11537*da0073e9SAndroid Build Coastguard Worker        for i in range(len(mods)):
11538*da0073e9SAndroid Build Coastguard Worker            for j in range(len(mods)):
11539*da0073e9SAndroid Build Coastguard Worker                mod = ZipModLists(mods[i], mods[j])
11540*da0073e9SAndroid Build Coastguard Worker                self.checkModule(mod, (torch.tensor(.5),))
11541*da0073e9SAndroid Build Coastguard Worker                mod2 = ZipWithValues(mods[i], mods[j])
11542*da0073e9SAndroid Build Coastguard Worker                self.checkModule(mod2, (torch.tensor(.5),))
11543*da0073e9SAndroid Build Coastguard Worker
11544*da0073e9SAndroid Build Coastguard Worker
11545*da0073e9SAndroid Build Coastguard Worker    def test_enumerate_modlist_range(self):
11546*da0073e9SAndroid Build Coastguard Worker        class Double(torch.nn.Module):
11547*da0073e9SAndroid Build Coastguard Worker            def forward(self, thing):
11548*da0073e9SAndroid Build Coastguard Worker                return thing * 2
11549*da0073e9SAndroid Build Coastguard Worker
11550*da0073e9SAndroid Build Coastguard Worker        class Mod(torch.nn.Module):
11551*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
11552*da0073e9SAndroid Build Coastguard Worker                super().__init__()
11553*da0073e9SAndroid Build Coastguard Worker                self.mods = nn.ModuleList([Double(), Double()])
11554*da0073e9SAndroid Build Coastguard Worker
11555*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
11556*da0073e9SAndroid Build Coastguard Worker                x2 = x
11557*da0073e9SAndroid Build Coastguard Worker                iter = 0
11558*da0073e9SAndroid Build Coastguard Worker                for val, mod in enumerate(self.mods):
11559*da0073e9SAndroid Build Coastguard Worker                    x2 = mod(x2) * val
11560*da0073e9SAndroid Build Coastguard Worker                    iter += 1
11561*da0073e9SAndroid Build Coastguard Worker                return iter, x, x2
11562*da0073e9SAndroid Build Coastguard Worker
11563*da0073e9SAndroid Build Coastguard Worker        self.checkModule(Mod(), (torch.tensor(.5),))
11564*da0073e9SAndroid Build Coastguard Worker
11565*da0073e9SAndroid Build Coastguard Worker        # variable length, modulelist
11566*da0073e9SAndroid Build Coastguard Worker        class Mod2(Mod):
11567*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
11568*da0073e9SAndroid Build Coastguard Worker                for val, mod in zip(range(int(x)), self.mods):
11569*da0073e9SAndroid Build Coastguard Worker                    x = mod(x) * val
11570*da0073e9SAndroid Build Coastguard Worker                return x
11571*da0073e9SAndroid Build Coastguard Worker
11572*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(Exception, "that does not have a statically determinable length"):
11573*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(Mod2())
11574*da0073e9SAndroid Build Coastguard Worker
11575*da0073e9SAndroid Build Coastguard Worker        # modulelist, variable length
11576*da0073e9SAndroid Build Coastguard Worker        class Mod3(Mod):
11577*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
11578*da0073e9SAndroid Build Coastguard Worker                for val, mod in zip(self.mods, range(int(x))):
11579*da0073e9SAndroid Build Coastguard Worker                    x = mod(x) * val
11580*da0073e9SAndroid Build Coastguard Worker                return x
11581*da0073e9SAndroid Build Coastguard Worker
11582*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(Exception, "that does not have a statically determinable length"):
11583*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(Mod3())
11584*da0073e9SAndroid Build Coastguard Worker
11585*da0073e9SAndroid Build Coastguard Worker    def test_for_in_enumerate(self):
11586*da0073e9SAndroid Build Coastguard Worker        def fn(x):
11587*da0073e9SAndroid Build Coastguard Worker            # type: (List[int]) -> int
11588*da0073e9SAndroid Build Coastguard Worker            sum = 0
11589*da0073e9SAndroid Build Coastguard Worker            for (i, v) in enumerate(x):
11590*da0073e9SAndroid Build Coastguard Worker                sum += i * v
11591*da0073e9SAndroid Build Coastguard Worker
11592*da0073e9SAndroid Build Coastguard Worker            return sum
11593*da0073e9SAndroid Build Coastguard Worker
11594*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ([1, 2, 3, 4, 5],))
11595*da0073e9SAndroid Build Coastguard Worker
11596*da0073e9SAndroid Build Coastguard Worker        def fn_enumerate_start_arg(x):
11597*da0073e9SAndroid Build Coastguard Worker            # type: (List[int]) -> int
11598*da0073e9SAndroid Build Coastguard Worker            sum = 0
11599*da0073e9SAndroid Build Coastguard Worker            for (i, v) in enumerate(x, 1):
11600*da0073e9SAndroid Build Coastguard Worker                sum += i * v
11601*da0073e9SAndroid Build Coastguard Worker
11602*da0073e9SAndroid Build Coastguard Worker            return sum
11603*da0073e9SAndroid Build Coastguard Worker
11604*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn_enumerate_start_arg, ([1, 2, 3, 4, 5],))
11605*da0073e9SAndroid Build Coastguard Worker
11606*da0073e9SAndroid Build Coastguard Worker        def fn_enumerate_start_kwarg(x):
11607*da0073e9SAndroid Build Coastguard Worker            # type: (List[int]) -> int
11608*da0073e9SAndroid Build Coastguard Worker            sum = 0
11609*da0073e9SAndroid Build Coastguard Worker            for (i, v) in enumerate(x, start=1):
11610*da0073e9SAndroid Build Coastguard Worker                sum += i * v
11611*da0073e9SAndroid Build Coastguard Worker
11612*da0073e9SAndroid Build Coastguard Worker            return sum
11613*da0073e9SAndroid Build Coastguard Worker
11614*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn_enumerate_start_kwarg, ([1, 2, 3, 4, 5],))
11615*da0073e9SAndroid Build Coastguard Worker
11616*da0073e9SAndroid Build Coastguard Worker        def fn_nested_enumerate(x):
11617*da0073e9SAndroid Build Coastguard Worker            # type: (List[int]) -> int
11618*da0073e9SAndroid Build Coastguard Worker            sum = 0
11619*da0073e9SAndroid Build Coastguard Worker            for (i, (j, v)) in enumerate(enumerate(x)):
11620*da0073e9SAndroid Build Coastguard Worker                sum += i * j * v
11621*da0073e9SAndroid Build Coastguard Worker
11622*da0073e9SAndroid Build Coastguard Worker            return sum
11623*da0073e9SAndroid Build Coastguard Worker
11624*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn_nested_enumerate, ([1, 2, 3, 4, 5],))
11625*da0073e9SAndroid Build Coastguard Worker
11626*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r'enumerate expected at least 1 arguments, got 0'):
11627*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
11628*da0073e9SAndroid Build Coastguard Worker            def enumerate_no_arg(x):
11629*da0073e9SAndroid Build Coastguard Worker                # type: (List[int]) -> int
11630*da0073e9SAndroid Build Coastguard Worker                sum = 0
11631*da0073e9SAndroid Build Coastguard Worker                for _ in enumerate():
11632*da0073e9SAndroid Build Coastguard Worker                    sum += 1
11633*da0073e9SAndroid Build Coastguard Worker
11634*da0073e9SAndroid Build Coastguard Worker                return sum
11635*da0073e9SAndroid Build Coastguard Worker
11636*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r'enumerate expected at most 2 arguments, got 3'):
11637*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
11638*da0073e9SAndroid Build Coastguard Worker            def enumerate_too_many_args(x):
11639*da0073e9SAndroid Build Coastguard Worker                # type: (List[int]) -> int
11640*da0073e9SAndroid Build Coastguard Worker                sum = 0
11641*da0073e9SAndroid Build Coastguard Worker                for _ in enumerate(x, x, x):
11642*da0073e9SAndroid Build Coastguard Worker                    sum += 1
11643*da0073e9SAndroid Build Coastguard Worker
11644*da0073e9SAndroid Build Coastguard Worker                return sum
11645*da0073e9SAndroid Build Coastguard Worker
11646*da0073e9SAndroid Build Coastguard Worker    def test_list_comprehension_modulelist(self):
11647*da0073e9SAndroid Build Coastguard Worker        class Inner(torch.nn.Module):
11648*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
11649*da0073e9SAndroid Build Coastguard Worker                return x + 10
11650*da0073e9SAndroid Build Coastguard Worker
11651*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
11652*da0073e9SAndroid Build Coastguard Worker            def __init__(self, mod_list):
11653*da0073e9SAndroid Build Coastguard Worker                super().__init__()
11654*da0073e9SAndroid Build Coastguard Worker                self.module_list = mod_list
11655*da0073e9SAndroid Build Coastguard Worker
11656*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
11657*da0073e9SAndroid Build Coastguard Worker                out = torch.jit.annotate(List[Tensor], [mod(x) for mod in self.module_list])
11658*da0073e9SAndroid Build Coastguard Worker                return out
11659*da0073e9SAndroid Build Coastguard Worker
11660*da0073e9SAndroid Build Coastguard Worker        mod = M(nn.ModuleList([Inner(), Inner()]))
11661*da0073e9SAndroid Build Coastguard Worker        self.checkModule(mod, (torch.tensor(3),))
11662*da0073e9SAndroid Build Coastguard Worker
11663*da0073e9SAndroid Build Coastguard Worker        mod = M(nn.ModuleList([]))
11664*da0073e9SAndroid Build Coastguard Worker        torch.jit.script(mod)
11665*da0073e9SAndroid Build Coastguard Worker
11666*da0073e9SAndroid Build Coastguard Worker        class M2(M):
11667*da0073e9SAndroid Build Coastguard Worker            def __init__(self, mod_list):
11668*da0073e9SAndroid Build Coastguard Worker                super().__init__(mod_list)
11669*da0073e9SAndroid Build Coastguard Worker
11670*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
11671*da0073e9SAndroid Build Coastguard Worker                out = [mod(x) for mod in self.module_list]
11672*da0073e9SAndroid Build Coastguard Worker                return out
11673*da0073e9SAndroid Build Coastguard Worker
11674*da0073e9SAndroid Build Coastguard Worker        mod = M2(nn.ModuleList([Inner(), Inner()]))
11675*da0073e9SAndroid Build Coastguard Worker        self.checkModule(mod, (torch.tensor(3),))
11676*da0073e9SAndroid Build Coastguard Worker
11677*da0073e9SAndroid Build Coastguard Worker        mod = M2(nn.ModuleList([]))
11678*da0073e9SAndroid Build Coastguard Worker        # defaults to List of Tensor for empty modulelist
11679*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.jit.script(mod)(torch.tensor(.5)), [])
11680*da0073e9SAndroid Build Coastguard Worker
11681*da0073e9SAndroid Build Coastguard Worker        def bad_type_annotation():
11682*da0073e9SAndroid Build Coastguard Worker            out = torch.jit.annotate(int, [x for x in [1, 2, 3]])  # noqa: C416
11683*da0073e9SAndroid Build Coastguard Worker            return out
11684*da0073e9SAndroid Build Coastguard Worker
11685*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(Exception, "Expected an annotation"
11686*da0073e9SAndroid Build Coastguard Worker                                    " of type List"):
11687*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(bad_type_annotation)
11688*da0073e9SAndroid Build Coastguard Worker
11689*da0073e9SAndroid Build Coastguard Worker    def test_list_comprehension_variable_write(self):
11690*da0073e9SAndroid Build Coastguard Worker        # i in comprehension doesn't write to function scope
11691*da0073e9SAndroid Build Coastguard Worker        def foo():
11692*da0073e9SAndroid Build Coastguard Worker            i = 1
11693*da0073e9SAndroid Build Coastguard Worker            x = [i if i != 5 else 3 for i in range(7)]  # noqa: C416
11694*da0073e9SAndroid Build Coastguard Worker            return i, x
11695*da0073e9SAndroid Build Coastguard Worker
11696*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(), torch.jit.script(foo)())
11697*da0073e9SAndroid Build Coastguard Worker
11698*da0073e9SAndroid Build Coastguard Worker    def test_for_in_zip(self):
11699*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
11700*da0073e9SAndroid Build Coastguard Worker            # type: (List[int], List[int]) -> int
11701*da0073e9SAndroid Build Coastguard Worker            sum = 0
11702*da0073e9SAndroid Build Coastguard Worker            for (i, j) in zip(x, y):
11703*da0073e9SAndroid Build Coastguard Worker                sum += i * j
11704*da0073e9SAndroid Build Coastguard Worker
11705*da0073e9SAndroid Build Coastguard Worker            return sum
11706*da0073e9SAndroid Build Coastguard Worker
11707*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ([1, 2, 3, 4, 5], [2, 3, 4, 5, 6]))
11708*da0073e9SAndroid Build Coastguard Worker
11709*da0073e9SAndroid Build Coastguard Worker        def fn_multi_inputs(x, y, z):
11710*da0073e9SAndroid Build Coastguard Worker            # type: (List[int], List[int], List[int]) -> int
11711*da0073e9SAndroid Build Coastguard Worker            sum = 0
11712*da0073e9SAndroid Build Coastguard Worker            for (i, j, k) in zip(x, y, z):
11713*da0073e9SAndroid Build Coastguard Worker                sum += i * j * k
11714*da0073e9SAndroid Build Coastguard Worker
11715*da0073e9SAndroid Build Coastguard Worker            return sum
11716*da0073e9SAndroid Build Coastguard Worker
11717*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn_multi_inputs, ([1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6]))
11718*da0073e9SAndroid Build Coastguard Worker
11719*da0073e9SAndroid Build Coastguard Worker        def fn_nested_zip(x, y, z):
11720*da0073e9SAndroid Build Coastguard Worker            # type: (List[int], List[int], List[int]) -> int
11721*da0073e9SAndroid Build Coastguard Worker            sum = 0
11722*da0073e9SAndroid Build Coastguard Worker            for (i, (j, k)) in zip(x, zip(y, z)):
11723*da0073e9SAndroid Build Coastguard Worker                sum += i * j * k
11724*da0073e9SAndroid Build Coastguard Worker
11725*da0073e9SAndroid Build Coastguard Worker            return sum
11726*da0073e9SAndroid Build Coastguard Worker
11727*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn_multi_inputs, ([1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6]))
11728*da0073e9SAndroid Build Coastguard Worker
11729*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r'zip expected at least 1 arguments, got 0'):
11730*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
11731*da0073e9SAndroid Build Coastguard Worker            def zip_no_arg(x):
11732*da0073e9SAndroid Build Coastguard Worker                # type: (List[int]) -> int
11733*da0073e9SAndroid Build Coastguard Worker                sum = 0
11734*da0073e9SAndroid Build Coastguard Worker                for _ in zip():
11735*da0073e9SAndroid Build Coastguard Worker                    sum += 1
11736*da0073e9SAndroid Build Coastguard Worker
11737*da0073e9SAndroid Build Coastguard Worker                return sum
11738*da0073e9SAndroid Build Coastguard Worker
11739*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r'too many values to unpack: need 2 but found 3'):
11740*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
11741*da0073e9SAndroid Build Coastguard Worker            def fn_nested_zip_wrong_target_assign(x, y, z):
11742*da0073e9SAndroid Build Coastguard Worker                # type: (List[int], List[int], List[int]) -> int
11743*da0073e9SAndroid Build Coastguard Worker                sum = 0
11744*da0073e9SAndroid Build Coastguard Worker                for (i, (j, k)) in zip(x, y, z):
11745*da0073e9SAndroid Build Coastguard Worker                    sum += i * j * k
11746*da0073e9SAndroid Build Coastguard Worker
11747*da0073e9SAndroid Build Coastguard Worker                return sum
11748*da0073e9SAndroid Build Coastguard Worker
11749*da0073e9SAndroid Build Coastguard Worker    def test_for_in_zip_enumerate(self):
11750*da0073e9SAndroid Build Coastguard Worker        def fn_zip_enumerate(x, y):
11751*da0073e9SAndroid Build Coastguard Worker            # type: (List[int], List[int]) -> int
11752*da0073e9SAndroid Build Coastguard Worker            sum = 0
11753*da0073e9SAndroid Build Coastguard Worker            for (i, (j, v), k) in zip(x, enumerate(y), range(0, 100)):
11754*da0073e9SAndroid Build Coastguard Worker                sum += i * j * v * k
11755*da0073e9SAndroid Build Coastguard Worker
11756*da0073e9SAndroid Build Coastguard Worker            return sum
11757*da0073e9SAndroid Build Coastguard Worker
11758*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn_zip_enumerate, ([1, 2, 3, 4], [2, 3, 4, 5]))
11759*da0073e9SAndroid Build Coastguard Worker
11760*da0073e9SAndroid Build Coastguard Worker        def fn_enumerate_zip(x, y):
11761*da0073e9SAndroid Build Coastguard Worker            # type: (List[int], List[int]) -> int
11762*da0073e9SAndroid Build Coastguard Worker            sum = 0
11763*da0073e9SAndroid Build Coastguard Worker            for (i, (j, v)) in enumerate(zip(x, y)):
11764*da0073e9SAndroid Build Coastguard Worker                sum += i * j * v
11765*da0073e9SAndroid Build Coastguard Worker
11766*da0073e9SAndroid Build Coastguard Worker            return sum
11767*da0073e9SAndroid Build Coastguard Worker
11768*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn_enumerate_zip, ([1, 2, 3, 4], [2, 3, 4, 5]))
11769*da0073e9SAndroid Build Coastguard Worker
11770*da0073e9SAndroid Build Coastguard Worker    def test_for_in_tensors(self):
11771*da0073e9SAndroid Build Coastguard Worker        def test_sizes(x):
11772*da0073e9SAndroid Build Coastguard Worker            sumz = 0
11773*da0073e9SAndroid Build Coastguard Worker            for s in x:
11774*da0073e9SAndroid Build Coastguard Worker                sumz += 1
11775*da0073e9SAndroid Build Coastguard Worker            return sumz
11776*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_sizes, (torch.rand(5, 4, 3, 2, 1),))
11777*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_sizes, (torch.rand(777),))
11778*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_sizes, (torch.rand(0),))
11779*da0073e9SAndroid Build Coastguard Worker
11780*da0073e9SAndroid Build Coastguard Worker    def test_for_in_tensors_rank0(self):
11781*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "of a 0-d tensor"):
11782*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
11783*da0073e9SAndroid Build Coastguard Worker            def test_sizes(x):
11784*da0073e9SAndroid Build Coastguard Worker                sumz = 0
11785*da0073e9SAndroid Build Coastguard Worker                for s in x:
11786*da0073e9SAndroid Build Coastguard Worker                    sumz += 1
11787*da0073e9SAndroid Build Coastguard Worker                return sumz
11788*da0073e9SAndroid Build Coastguard Worker
11789*da0073e9SAndroid Build Coastguard Worker            test_sizes(torch.tensor(1))
11790*da0073e9SAndroid Build Coastguard Worker
11791*da0073e9SAndroid Build Coastguard Worker    def test_for_in_tensors_fail_scalar(self):
11792*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "'float' object is not iterable"):
11793*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
11794*da0073e9SAndroid Build Coastguard Worker            def test_sizes(x):
11795*da0073e9SAndroid Build Coastguard Worker                # type: (float) -> int
11796*da0073e9SAndroid Build Coastguard Worker                sumz = 0
11797*da0073e9SAndroid Build Coastguard Worker                for s in x:
11798*da0073e9SAndroid Build Coastguard Worker                    sumz += 1
11799*da0073e9SAndroid Build Coastguard Worker                return sumz
11800*da0073e9SAndroid Build Coastguard Worker
11801*da0073e9SAndroid Build Coastguard Worker            test_sizes(0.0)
11802*da0073e9SAndroid Build Coastguard Worker
11803*da0073e9SAndroid Build Coastguard Worker    def test_for_in_tensors_nested(self):
11804*da0073e9SAndroid Build Coastguard Worker        def test_sizes(x):
11805*da0073e9SAndroid Build Coastguard Worker            sumz = 0
11806*da0073e9SAndroid Build Coastguard Worker            for n in x:
11807*da0073e9SAndroid Build Coastguard Worker                for t in n:
11808*da0073e9SAndroid Build Coastguard Worker                    sumz += 1
11809*da0073e9SAndroid Build Coastguard Worker            return sumz
11810*da0073e9SAndroid Build Coastguard Worker
11811*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_sizes, (torch.rand(5, 4, 3, 2, 1),))
11812*da0073e9SAndroid Build Coastguard Worker
11813*da0073e9SAndroid Build Coastguard Worker    # to avoid defining sum_list in multiple tests
11814*da0073e9SAndroid Build Coastguard Worker    def get_sum_list_fn(self):
11815*da0073e9SAndroid Build Coastguard Worker        def sum_list(a):
11816*da0073e9SAndroid Build Coastguard Worker            # type: (List[int]) -> int
11817*da0073e9SAndroid Build Coastguard Worker            sum = 0
11818*da0073e9SAndroid Build Coastguard Worker            for i in a:
11819*da0073e9SAndroid Build Coastguard Worker                sum += i
11820*da0073e9SAndroid Build Coastguard Worker
11821*da0073e9SAndroid Build Coastguard Worker            return sum
11822*da0073e9SAndroid Build Coastguard Worker
11823*da0073e9SAndroid Build Coastguard Worker        return sum_list
11824*da0073e9SAndroid Build Coastguard Worker
11825*da0073e9SAndroid Build Coastguard Worker    def test_sum_list_diff_elms(self):
11826*da0073e9SAndroid Build Coastguard Worker        self.checkScript(self.get_sum_list_fn(), ([1, 2, 3, 4, 5],))
11827*da0073e9SAndroid Build Coastguard Worker
11828*da0073e9SAndroid Build Coastguard Worker    def test_sum_list_empty(self):
11829*da0073e9SAndroid Build Coastguard Worker        self.checkScript(self.get_sum_list_fn(), ([],))
11830*da0073e9SAndroid Build Coastguard Worker
11831*da0073e9SAndroid Build Coastguard Worker    def test_sum_list_one(self):
11832*da0073e9SAndroid Build Coastguard Worker        self.checkScript(self.get_sum_list_fn(), ([1],))
11833*da0073e9SAndroid Build Coastguard Worker
11834*da0073e9SAndroid Build Coastguard Worker    def test_sum_list_literal(self):
11835*da0073e9SAndroid Build Coastguard Worker
11836*da0073e9SAndroid Build Coastguard Worker        def sum_list():
11837*da0073e9SAndroid Build Coastguard Worker            # type: () -> int
11838*da0073e9SAndroid Build Coastguard Worker            sum = 0
11839*da0073e9SAndroid Build Coastguard Worker            for i in [1, 2, 3, 4, 5]:
11840*da0073e9SAndroid Build Coastguard Worker                sum += i
11841*da0073e9SAndroid Build Coastguard Worker
11842*da0073e9SAndroid Build Coastguard Worker            return sum
11843*da0073e9SAndroid Build Coastguard Worker
11844*da0073e9SAndroid Build Coastguard Worker        self.checkScript(sum_list, ())
11845*da0073e9SAndroid Build Coastguard Worker
11846*da0073e9SAndroid Build Coastguard Worker    def test_sum_list_wrong_type(self):
11847*da0073e9SAndroid Build Coastguard Worker
11848*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"):
11849*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
11850*da0073e9SAndroid Build Coastguard Worker            def sum_list(a):
11851*da0073e9SAndroid Build Coastguard Worker                # type: (int) -> int
11852*da0073e9SAndroid Build Coastguard Worker                sum = 0
11853*da0073e9SAndroid Build Coastguard Worker                for i in a:  # noqa: T484
11854*da0073e9SAndroid Build Coastguard Worker                    sum += i
11855*da0073e9SAndroid Build Coastguard Worker
11856*da0073e9SAndroid Build Coastguard Worker                return sum
11857*da0073e9SAndroid Build Coastguard Worker
11858*da0073e9SAndroid Build Coastguard Worker            sum_list(1)
11859*da0073e9SAndroid Build Coastguard Worker
11860*da0073e9SAndroid Build Coastguard Worker    def test_list_iterables(self):
11861*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'List of iterables is not supported currently'):
11862*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit('''
11863*da0073e9SAndroid Build Coastguard Worker            def list_iterables(x):
11864*da0073e9SAndroid Build Coastguard Worker                for i, j in [2, 3, 4], [5, 6, 7]:
11865*da0073e9SAndroid Build Coastguard Worker                    x += i
11866*da0073e9SAndroid Build Coastguard Worker                    x += j
11867*da0073e9SAndroid Build Coastguard Worker                return x
11868*da0073e9SAndroid Build Coastguard Worker            ''')
11869*da0073e9SAndroid Build Coastguard Worker
11870*da0073e9SAndroid Build Coastguard Worker    def test_for_in_string(self):
11871*da0073e9SAndroid Build Coastguard Worker        def test_strings(x):
11872*da0073e9SAndroid Build Coastguard Worker            # type: (str) -> str
11873*da0073e9SAndroid Build Coastguard Worker            reverse = ""
11874*da0073e9SAndroid Build Coastguard Worker            for c in x:
11875*da0073e9SAndroid Build Coastguard Worker                reverse = c + reverse
11876*da0073e9SAndroid Build Coastguard Worker            return reverse
11877*da0073e9SAndroid Build Coastguard Worker
11878*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_strings, ("hello",))
11879*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_strings, ("",))
11880*da0073e9SAndroid Build Coastguard Worker
11881*da0073e9SAndroid Build Coastguard Worker        def test_list_strings(x):
11882*da0073e9SAndroid Build Coastguard Worker            # type: (List[str]) -> str
11883*da0073e9SAndroid Build Coastguard Worker            result = ""
11884*da0073e9SAndroid Build Coastguard Worker            for sub_str in x:
11885*da0073e9SAndroid Build Coastguard Worker                result += sub_str
11886*da0073e9SAndroid Build Coastguard Worker            return result
11887*da0073e9SAndroid Build Coastguard Worker
11888*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_list_strings, (["hello", "world"],))
11889*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_list_strings, (["hello", " ", "world", ""],))
11890*da0073e9SAndroid Build Coastguard Worker
11891*da0073e9SAndroid Build Coastguard Worker    def test_for_in_dict(self):
11892*da0073e9SAndroid Build Coastguard Worker        def test_dicts(x):
11893*da0073e9SAndroid Build Coastguard Worker            # type: (Dict[str, int]) -> int
11894*da0073e9SAndroid Build Coastguard Worker            sum = 0
11895*da0073e9SAndroid Build Coastguard Worker            for key in x:
11896*da0073e9SAndroid Build Coastguard Worker                sum += x[key]
11897*da0073e9SAndroid Build Coastguard Worker            return sum
11898*da0073e9SAndroid Build Coastguard Worker
11899*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},))
11900*da0073e9SAndroid Build Coastguard Worker
11901*da0073e9SAndroid Build Coastguard Worker        def test_dict_keys_values(x):
11902*da0073e9SAndroid Build Coastguard Worker            # type: (Dict[str, int]) -> Tuple[str, int]
11903*da0073e9SAndroid Build Coastguard Worker            key_str = ""
11904*da0073e9SAndroid Build Coastguard Worker            sum = 0
11905*da0073e9SAndroid Build Coastguard Worker            for key in x.keys():
11906*da0073e9SAndroid Build Coastguard Worker                key_str += key
11907*da0073e9SAndroid Build Coastguard Worker            for val in x.values():
11908*da0073e9SAndroid Build Coastguard Worker                sum += val
11909*da0073e9SAndroid Build Coastguard Worker            return key_str, sum
11910*da0073e9SAndroid Build Coastguard Worker
11911*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},))
11912*da0073e9SAndroid Build Coastguard Worker
11913*da0073e9SAndroid Build Coastguard Worker    def test_for_tuple_unpack(self):
11914*da0073e9SAndroid Build Coastguard Worker        def for_tuple_unpack(x, y):
11915*da0073e9SAndroid Build Coastguard Worker            for i, j in [[3, 4], [5, 6], [7, 8]]:
11916*da0073e9SAndroid Build Coastguard Worker                x += i
11917*da0073e9SAndroid Build Coastguard Worker                y += j
11918*da0073e9SAndroid Build Coastguard Worker            return x, y
11919*da0073e9SAndroid Build Coastguard Worker
11920*da0073e9SAndroid Build Coastguard Worker        self.checkScript(for_tuple_unpack, (torch.tensor(3), torch.tensor(5)))
11921*da0073e9SAndroid Build Coastguard Worker
11922*da0073e9SAndroid Build Coastguard Worker        def nested_tuple_unpack(x, y):
11923*da0073e9SAndroid Build Coastguard Worker            # type: (List[int], List[int]) -> int
11924*da0073e9SAndroid Build Coastguard Worker            sum = 0
11925*da0073e9SAndroid Build Coastguard Worker            for i, (j, k), v in zip(x, enumerate(x), y):
11926*da0073e9SAndroid Build Coastguard Worker                sum += i + j + k + v
11927*da0073e9SAndroid Build Coastguard Worker            return sum
11928*da0073e9SAndroid Build Coastguard Worker
11929*da0073e9SAndroid Build Coastguard Worker        self.checkScript(nested_tuple_unpack, ([1, 3, 5], [2, 4, 6]))
11930*da0073e9SAndroid Build Coastguard Worker
11931*da0073e9SAndroid Build Coastguard Worker    def test_for_tuple_assign(self):
11932*da0073e9SAndroid Build Coastguard Worker        def test_simple_assign(x):
11933*da0073e9SAndroid Build Coastguard Worker            # type: (Tuple[int, float]) -> float
11934*da0073e9SAndroid Build Coastguard Worker            sum = 0.0
11935*da0073e9SAndroid Build Coastguard Worker            for a in x:
11936*da0073e9SAndroid Build Coastguard Worker                sum += float(a)
11937*da0073e9SAndroid Build Coastguard Worker            return sum
11938*da0073e9SAndroid Build Coastguard Worker
11939*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_simple_assign, ((1, 2.5),))
11940*da0073e9SAndroid Build Coastguard Worker
11941*da0073e9SAndroid Build Coastguard Worker        def test_tuple_assign(x):
11942*da0073e9SAndroid Build Coastguard Worker            # type: (Tuple[Tuple[int, int], Tuple[int, int]]) -> int
11943*da0073e9SAndroid Build Coastguard Worker            sum = 0
11944*da0073e9SAndroid Build Coastguard Worker            for a in x:
11945*da0073e9SAndroid Build Coastguard Worker                sum += a[0]
11946*da0073e9SAndroid Build Coastguard Worker                sum += a[1]
11947*da0073e9SAndroid Build Coastguard Worker            return sum
11948*da0073e9SAndroid Build Coastguard Worker
11949*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_tuple_assign, (((1, 2), (4, 7)), ))
11950*da0073e9SAndroid Build Coastguard Worker
11951*da0073e9SAndroid Build Coastguard Worker    def test_single_starred_lhs(self):
11952*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'A Starred expression may only appear on the lhs within the presence'
11953*da0073e9SAndroid Build Coastguard Worker                                                  ' of another non-starred expression'):
11954*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit('''
11955*da0073e9SAndroid Build Coastguard Worker            def single_starred_lhs(x):
11956*da0073e9SAndroid Build Coastguard Worker                a = (x, x, x)
11957*da0073e9SAndroid Build Coastguard Worker                *b, = a
11958*da0073e9SAndroid Build Coastguard Worker                return b
11959*da0073e9SAndroid Build Coastguard Worker            ''')
11960*da0073e9SAndroid Build Coastguard Worker
11961*da0073e9SAndroid Build Coastguard Worker    def test_singleton_tuple_unpack(self):
11962*da0073e9SAndroid Build Coastguard Worker        def foo(a):
11963*da0073e9SAndroid Build Coastguard Worker            b, = (a,)
11964*da0073e9SAndroid Build Coastguard Worker            return b + 1
11965*da0073e9SAndroid Build Coastguard Worker        self.checkScript(foo, (torch.rand(3),))
11966*da0073e9SAndroid Build Coastguard Worker
11967*da0073e9SAndroid Build Coastguard Worker    def test_tuple_assignments(self):
11968*da0073e9SAndroid Build Coastguard Worker        def var_tuple_assign(x, y):
11969*da0073e9SAndroid Build Coastguard Worker            # type: (Tuple[Tensor, Tensor], Tensor) -> Tensor
11970*da0073e9SAndroid Build Coastguard Worker            (a, b), c = x, y
11971*da0073e9SAndroid Build Coastguard Worker            return a + b + c
11972*da0073e9SAndroid Build Coastguard Worker
11973*da0073e9SAndroid Build Coastguard Worker        tuple_inputs = (torch.randn(1, 4), torch.randn(3, 4))
11974*da0073e9SAndroid Build Coastguard Worker        self.checkScript(var_tuple_assign, (tuple_inputs, torch.randn(3, 4)))
11975*da0073e9SAndroid Build Coastguard Worker
11976*da0073e9SAndroid Build Coastguard Worker        def nested_tuple_assign(x, y, z):
11977*da0073e9SAndroid Build Coastguard Worker            # type: (int, Tuple[int, Tuple[int, int]], Tuple[int, int]) -> int
11978*da0073e9SAndroid Build Coastguard Worker            a, (b, (c, d)), (e, f) = x, y, z
11979*da0073e9SAndroid Build Coastguard Worker            return a + b + c + d + e + f
11980*da0073e9SAndroid Build Coastguard Worker
11981*da0073e9SAndroid Build Coastguard Worker        self.checkScript(nested_tuple_assign, ((1, (2, (3, 4)), (5, 6))))
11982*da0073e9SAndroid Build Coastguard Worker
11983*da0073e9SAndroid Build Coastguard Worker        def subscript_tuple_assign(a, x, i):
11984*da0073e9SAndroid Build Coastguard Worker            # type: (List[int], Tensor, int) -> Tuple[int, Tensor, int]
11985*da0073e9SAndroid Build Coastguard Worker            a[i], (x[i], b) = 1, (2, 3)
11986*da0073e9SAndroid Build Coastguard Worker            return a[i] + 1, x + 5, b
11987*da0073e9SAndroid Build Coastguard Worker
11988*da0073e9SAndroid Build Coastguard Worker        self.checkScript(subscript_tuple_assign, ([12, 7, 9, 11], torch.tensor((3, 13, 17)), 0))
11989*da0073e9SAndroid Build Coastguard Worker
11990*da0073e9SAndroid Build Coastguard Worker        def star_tuple_assign():
11991*da0073e9SAndroid Build Coastguard Worker            # type: () -> Tuple[int, int, Tuple[int, int], Tuple[int, int]]
11992*da0073e9SAndroid Build Coastguard Worker            a, (b, *c), *d = 1, (2, 3, 4), 5, 6
11993*da0073e9SAndroid Build Coastguard Worker            return a, b, c, d
11994*da0073e9SAndroid Build Coastguard Worker
11995*da0073e9SAndroid Build Coastguard Worker        self.checkScript(star_tuple_assign, ())
11996*da0073e9SAndroid Build Coastguard Worker
11997*da0073e9SAndroid Build Coastguard Worker        def subscript_tuple_augmented_assign(a):
11998*da0073e9SAndroid Build Coastguard Worker            # type: (Tuple[int, int]) -> Tuple[int, int]
11999*da0073e9SAndroid Build Coastguard Worker            a[0] += 1
12000*da0073e9SAndroid Build Coastguard Worker            return a
12001*da0073e9SAndroid Build Coastguard Worker
12002*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'does not support augmented assign'):
12003*da0073e9SAndroid Build Coastguard Worker            scripted_aug_assign = torch.jit.script(subscript_tuple_augmented_assign)
12004*da0073e9SAndroid Build Coastguard Worker
12005*da0073e9SAndroid Build Coastguard Worker        class AttrTupleAssignmentTestClass:
12006*da0073e9SAndroid Build Coastguard Worker            def __init__(self, a: int, b: int):
12007*da0073e9SAndroid Build Coastguard Worker                self.a = a
12008*da0073e9SAndroid Build Coastguard Worker                self.b = b
12009*da0073e9SAndroid Build Coastguard Worker
12010*da0073e9SAndroid Build Coastguard Worker            def set_ab(self, a: int, b: int):
12011*da0073e9SAndroid Build Coastguard Worker                self.a, self.b = (a, b)
12012*da0073e9SAndroid Build Coastguard Worker
12013*da0073e9SAndroid Build Coastguard Worker            def get(self) -> Tuple[int, int]:
12014*da0073e9SAndroid Build Coastguard Worker                return (self.a, self.b)
12015*da0073e9SAndroid Build Coastguard Worker
12016*da0073e9SAndroid Build Coastguard Worker        make_global(AttrTupleAssignmentTestClass)
12017*da0073e9SAndroid Build Coastguard Worker
12018*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
12019*da0073e9SAndroid Build Coastguard Worker        def attr_tuple_assignment(o: AttrTupleAssignmentTestClass, a: int, b: int):
12020*da0073e9SAndroid Build Coastguard Worker            o.set_ab(a, b)
12021*da0073e9SAndroid Build Coastguard Worker            return o
12022*da0073e9SAndroid Build Coastguard Worker
12023*da0073e9SAndroid Build Coastguard Worker        o = AttrTupleAssignmentTestClass(1, 2)
12024*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(attr_tuple_assignment(o, 3, 4).get(), (3, 4))
12025*da0073e9SAndroid Build Coastguard Worker
12026*da0073e9SAndroid Build Coastguard Worker    def test_multiple_assign(self):
12027*da0073e9SAndroid Build Coastguard Worker        def test():
12028*da0073e9SAndroid Build Coastguard Worker            a = b, c = d, f = (1, 1)
12029*da0073e9SAndroid Build Coastguard Worker
12030*da0073e9SAndroid Build Coastguard Worker            # side effect
12031*da0073e9SAndroid Build Coastguard Worker            ten = torch.tensor(1)
12032*da0073e9SAndroid Build Coastguard Worker            ten1 = ten2 = ten.add_(1)
12033*da0073e9SAndroid Build Coastguard Worker
12034*da0073e9SAndroid Build Coastguard Worker            # ordering
12035*da0073e9SAndroid Build Coastguard Worker            x = 1
12036*da0073e9SAndroid Build Coastguard Worker            y = 3
12037*da0073e9SAndroid Build Coastguard Worker            x, y = y, x + y
12038*da0073e9SAndroid Build Coastguard Worker
12039*da0073e9SAndroid Build Coastguard Worker            return a, b, c, d, f, ten, ten1, ten2, x, y
12040*da0073e9SAndroid Build Coastguard Worker
12041*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test, ())
12042*da0073e9SAndroid Build Coastguard Worker
12043*da0073e9SAndroid Build Coastguard Worker    def test_multi_reduction(self):
12044*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
12045*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
12046*da0073e9SAndroid Build Coastguard Worker                'augmented assignment can only have one LHS expression'):
12047*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit('''
12048*da0073e9SAndroid Build Coastguard Worker            def multi_reduction(x):
12049*da0073e9SAndroid Build Coastguard Worker                a, b += x
12050*da0073e9SAndroid Build Coastguard Worker                return a, b
12051*da0073e9SAndroid Build Coastguard Worker            ''')
12052*da0073e9SAndroid Build Coastguard Worker
12053*da0073e9SAndroid Build Coastguard Worker    def test_invalid_call_arguments(self):
12054*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'but instead found type '):
12055*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
12056*da0073e9SAndroid Build Coastguard Worker            def invalid_call_arguments(x):
12057*da0073e9SAndroid Build Coastguard Worker                return torch.unsqueeze(3, 4, 5, 6, 7, 8)
12058*da0073e9SAndroid Build Coastguard Worker
12059*da0073e9SAndroid Build Coastguard Worker    def test_invalid_lhs_assignment(self):
12060*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'unexpected expression'):
12061*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit('''
12062*da0073e9SAndroid Build Coastguard Worker            def invalid_lhs_assignment(x):
12063*da0073e9SAndroid Build Coastguard Worker                x + 1 = x
12064*da0073e9SAndroid Build Coastguard Worker                return x
12065*da0073e9SAndroid Build Coastguard Worker            ''')
12066*da0073e9SAndroid Build Coastguard Worker
12067*da0073e9SAndroid Build Coastguard Worker    def test_multi_starred_expr_lhs(self):
12068*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'Only one starred expression is allowed on the lhs'):
12069*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit('''
12070*da0073e9SAndroid Build Coastguard Worker            def multi_starred_expr_lhs():
12071*da0073e9SAndroid Build Coastguard Worker                a, *b, *c = [1, 2, 3, 4, 5, 6]
12072*da0073e9SAndroid Build Coastguard Worker                return a
12073*da0073e9SAndroid Build Coastguard Worker            ''')
12074*da0073e9SAndroid Build Coastguard Worker
12075*da0073e9SAndroid Build Coastguard Worker    def test_pack_tuple_into_non_var(self):
12076*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'Cannot pack a tuple into a non-variable'):
12077*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit('''
12078*da0073e9SAndroid Build Coastguard Worker            def pack_tuple_into_non_var(x):
12079*da0073e9SAndroid Build Coastguard Worker                a, *1 = (3, 4, 5)
12080*da0073e9SAndroid Build Coastguard Worker                return x
12081*da0073e9SAndroid Build Coastguard Worker            ''')
12082*da0073e9SAndroid Build Coastguard Worker
12083*da0073e9SAndroid Build Coastguard Worker    def test_print_kwargs(self):
12084*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'print doesn\'t accept any keyword arguments'):
12085*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit('''
12086*da0073e9SAndroid Build Coastguard Worker            def print_kwargs(x):
12087*da0073e9SAndroid Build Coastguard Worker                print(x, flush=True)
12088*da0073e9SAndroid Build Coastguard Worker                return x
12089*da0073e9SAndroid Build Coastguard Worker            ''')
12090*da0073e9SAndroid Build Coastguard Worker
12091*da0073e9SAndroid Build Coastguard Worker    def test_builtin_use_as_value(self):
12092*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'builtin cannot be used as a value'):
12093*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
12094*da0073e9SAndroid Build Coastguard Worker            def builtin_use_as_value(x):
12095*da0073e9SAndroid Build Coastguard Worker                return x.unsqueeze
12096*da0073e9SAndroid Build Coastguard Worker
12097*da0073e9SAndroid Build Coastguard Worker    def test_wrong_use_as_tuple(self):
12098*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'cannot be used as a tuple'):
12099*da0073e9SAndroid Build Coastguard Worker            def test_fn():
12100*da0073e9SAndroid Build Coastguard Worker                return 3
12101*da0073e9SAndroid Build Coastguard Worker
12102*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
12103*da0073e9SAndroid Build Coastguard Worker            def wrong_use_as_tuple(self):
12104*da0073e9SAndroid Build Coastguard Worker                a, b = test_fn
12105*da0073e9SAndroid Build Coastguard Worker                return a
12106*da0073e9SAndroid Build Coastguard Worker
12107*da0073e9SAndroid Build Coastguard Worker    def test_wrong_attr_lookup(self):
12108*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'attribute lookup is not defined on builtin'):
12109*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
12110*da0073e9SAndroid Build Coastguard Worker            def wrong_attr_lookup(self, x):
12111*da0073e9SAndroid Build Coastguard Worker                a = x.unsqueeze.myattr
12112*da0073e9SAndroid Build Coastguard Worker                return a
12113*da0073e9SAndroid Build Coastguard Worker
12114*da0073e9SAndroid Build Coastguard Worker    def test_wrong_use_as_callable(self):
12115*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'cannot call a value'):
12116*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
12117*da0073e9SAndroid Build Coastguard Worker            def wrong_use_as_callable(x):
12118*da0073e9SAndroid Build Coastguard Worker                return x(3, 4, 5)
12119*da0073e9SAndroid Build Coastguard Worker
12120*da0073e9SAndroid Build Coastguard Worker    def test_python_val_doesnt_have_attr(self):
12121*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'object has no attribute abcd'):
12122*da0073e9SAndroid Build Coastguard Worker
12123*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
12124*da0073e9SAndroid Build Coastguard Worker            def python_val_doesnt_have_attr():
12125*da0073e9SAndroid Build Coastguard Worker                # this has to be a module otherwise attr lookup would not be
12126*da0073e9SAndroid Build Coastguard Worker                # allowed in the first place
12127*da0073e9SAndroid Build Coastguard Worker                return shutil.abcd
12128*da0073e9SAndroid Build Coastguard Worker
12129*da0073e9SAndroid Build Coastguard Worker    def test_wrong_module_attr_lookup(self):
12130*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'python value of type \'type\' cannot be used as a value'):
12131*da0073e9SAndroid Build Coastguard Worker            import io
12132*da0073e9SAndroid Build Coastguard Worker
12133*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
12134*da0073e9SAndroid Build Coastguard Worker            def wrong_module_attr_lookup():
12135*da0073e9SAndroid Build Coastguard Worker                return io.BytesIO
12136*da0073e9SAndroid Build Coastguard Worker
12137*da0073e9SAndroid Build Coastguard Worker    def test_wrong_method_call_inputs(self):
12138*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'Argument y not provided'):
12139*da0073e9SAndroid Build Coastguard Worker            class SomeModule(torch.jit.ScriptModule):
12140*da0073e9SAndroid Build Coastguard Worker
12141*da0073e9SAndroid Build Coastguard Worker                @torch.jit.script_method
12142*da0073e9SAndroid Build Coastguard Worker                def foo(self, x, y):
12143*da0073e9SAndroid Build Coastguard Worker                    return x
12144*da0073e9SAndroid Build Coastguard Worker
12145*da0073e9SAndroid Build Coastguard Worker                @torch.jit.script_method
12146*da0073e9SAndroid Build Coastguard Worker                def forward(self, x, y):
12147*da0073e9SAndroid Build Coastguard Worker                    return self.foo(x)
12148*da0073e9SAndroid Build Coastguard Worker            SomeModule()
12149*da0073e9SAndroid Build Coastguard Worker
12150*da0073e9SAndroid Build Coastguard Worker    def test_single_starred_expr_for_loop(self):
12151*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'A Starred expression may only appear'):
12152*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit('''
12153*da0073e9SAndroid Build Coastguard Worker            def test():
12154*da0073e9SAndroid Build Coastguard Worker                x = 0
12155*da0073e9SAndroid Build Coastguard Worker                for *a in [1, 2, 3]:
12156*da0073e9SAndroid Build Coastguard Worker                    x = x + 1
12157*da0073e9SAndroid Build Coastguard Worker                return x
12158*da0073e9SAndroid Build Coastguard Worker            ''')
12159*da0073e9SAndroid Build Coastguard Worker
12160*da0073e9SAndroid Build Coastguard Worker    def test_call_ge(self):
12161*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'Expected at most 1 arguments but found 3'):
12162*da0073e9SAndroid Build Coastguard Worker            @_trace(torch.zeros(1, 2, 3))
12163*da0073e9SAndroid Build Coastguard Worker            def foo(x):
12164*da0073e9SAndroid Build Coastguard Worker                return x
12165*da0073e9SAndroid Build Coastguard Worker
12166*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
12167*da0073e9SAndroid Build Coastguard Worker            def test_fn():
12168*da0073e9SAndroid Build Coastguard Worker                return foo(torch.full([1], 1), torch.full([1], 2), torch.full([1], 3))
12169*da0073e9SAndroid Build Coastguard Worker
12170*da0073e9SAndroid Build Coastguard Worker    def test_wrong_return_type(self):
12171*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'but instead got value of type tuple'):
12172*da0073e9SAndroid Build Coastguard Worker            @torch.jit.ignore
12173*da0073e9SAndroid Build Coastguard Worker            def somefunc():
12174*da0073e9SAndroid Build Coastguard Worker                # type: () -> Tuple[Tuple[Tensor, Tensor]]
12175*da0073e9SAndroid Build Coastguard Worker                return torch.zeros(3, 4), torch.zeros(4, 5)  # noqa: T484
12176*da0073e9SAndroid Build Coastguard Worker
12177*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
12178*da0073e9SAndroid Build Coastguard Worker            def wrong_return_type():
12179*da0073e9SAndroid Build Coastguard Worker                return somefunc()
12180*da0073e9SAndroid Build Coastguard Worker            wrong_return_type()
12181*da0073e9SAndroid Build Coastguard Worker
12182*da0073e9SAndroid Build Coastguard Worker    # Tests for calling between different front-end modes
12183*da0073e9SAndroid Build Coastguard Worker    def test_call_python_fn_from_tracing_fn(self):
12184*da0073e9SAndroid Build Coastguard Worker        def python_fn(x):
12185*da0073e9SAndroid Build Coastguard Worker            return torch.neg(x)
12186*da0073e9SAndroid Build Coastguard Worker
12187*da0073e9SAndroid Build Coastguard Worker        @_trace(torch.rand(3, 4))
12188*da0073e9SAndroid Build Coastguard Worker        def traced_fn(x):
12189*da0073e9SAndroid Build Coastguard Worker            return python_fn(x) + 1
12190*da0073e9SAndroid Build Coastguard Worker
12191*da0073e9SAndroid Build Coastguard Worker        # The neg op in the python function should be properly inlined to the
12192*da0073e9SAndroid Build Coastguard Worker        # graph
12193*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::neg").run(str(traced_fn.graph))
12194*da0073e9SAndroid Build Coastguard Worker
12195*da0073e9SAndroid Build Coastguard Worker    def test_call_python_mod_from_tracing_fn(self):
12196*da0073e9SAndroid Build Coastguard Worker        class PythonMod(torch.nn.Module):
12197*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
12198*da0073e9SAndroid Build Coastguard Worker                super().__init__()
12199*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(4, 3), requires_grad=False)
12200*da0073e9SAndroid Build Coastguard Worker
12201*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
12202*da0073e9SAndroid Build Coastguard Worker                return torch.mm(x, self.param)
12203*da0073e9SAndroid Build Coastguard Worker
12204*da0073e9SAndroid Build Coastguard Worker        pm = PythonMod()
12205*da0073e9SAndroid Build Coastguard Worker
12206*da0073e9SAndroid Build Coastguard Worker        @_trace(torch.rand(3, 4))
12207*da0073e9SAndroid Build Coastguard Worker        def traced_fn(x):
12208*da0073e9SAndroid Build Coastguard Worker            return pm(x) + 1.0
12209*da0073e9SAndroid Build Coastguard Worker
12210*da0073e9SAndroid Build Coastguard Worker        # Note: the parameter self.param from the Python module is inlined
12211*da0073e9SAndroid Build Coastguard Worker        # into the graph
12212*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(list(traced_fn.graph.inputs())) == 1)
12213*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::mm").check("aten::add").run(str(traced_fn.graph))
12214*da0073e9SAndroid Build Coastguard Worker
12215*da0073e9SAndroid Build Coastguard Worker    @_tmp_donotuse_dont_inline_everything
12216*da0073e9SAndroid Build Coastguard Worker    def test_call_traced_fn_from_tracing_fn(self):
12217*da0073e9SAndroid Build Coastguard Worker        @_trace(torch.rand(3, 4))
12218*da0073e9SAndroid Build Coastguard Worker        def traced_fn1(x):
12219*da0073e9SAndroid Build Coastguard Worker            return torch.neg(x)
12220*da0073e9SAndroid Build Coastguard Worker
12221*da0073e9SAndroid Build Coastguard Worker        @_trace(torch.rand(3, 4))
12222*da0073e9SAndroid Build Coastguard Worker        def traced_fn(x):
12223*da0073e9SAndroid Build Coastguard Worker            return traced_fn1(x) + 1
12224*da0073e9SAndroid Build Coastguard Worker
12225*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("traced_fn").check("prim::CallFunction").check("aten::add") \
12226*da0073e9SAndroid Build Coastguard Worker            .run(str(traced_fn.graph))
12227*da0073e9SAndroid Build Coastguard Worker
12228*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("error in first class mode")
12229*da0073e9SAndroid Build Coastguard Worker    def test_call_traced_mod_from_tracing_fn(self):
12230*da0073e9SAndroid Build Coastguard Worker        class TracedModule(torch.nn.Module):
12231*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
12232*da0073e9SAndroid Build Coastguard Worker                super().__init__()
12233*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(4, 3), requires_grad=False)
12234*da0073e9SAndroid Build Coastguard Worker
12235*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
12236*da0073e9SAndroid Build Coastguard Worker                return torch.mm(x, self.param)
12237*da0073e9SAndroid Build Coastguard Worker
12238*da0073e9SAndroid Build Coastguard Worker        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
12239*da0073e9SAndroid Build Coastguard Worker
12240*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "must be registered as submodules"):
12241*da0073e9SAndroid Build Coastguard Worker            @_trace(torch.rand(3, 4))
12242*da0073e9SAndroid Build Coastguard Worker            def traced_fn(x):
12243*da0073e9SAndroid Build Coastguard Worker                return tm(x) + 1.0
12244*da0073e9SAndroid Build Coastguard Worker
12245*da0073e9SAndroid Build Coastguard Worker    @_tmp_donotuse_dont_inline_everything
12246*da0073e9SAndroid Build Coastguard Worker    def test_call_script_fn_from_tracing_fn(self):
12247*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
12248*da0073e9SAndroid Build Coastguard Worker        def script_fn(x):
12249*da0073e9SAndroid Build Coastguard Worker            return torch.neg(x)
12250*da0073e9SAndroid Build Coastguard Worker
12251*da0073e9SAndroid Build Coastguard Worker        @_trace(torch.rand(3, 4))
12252*da0073e9SAndroid Build Coastguard Worker        def traced_fn(x):
12253*da0073e9SAndroid Build Coastguard Worker            return script_fn(x) + 1
12254*da0073e9SAndroid Build Coastguard Worker
12255*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("prim::CallFunction").check("aten::add").run(str(traced_fn.graph))
12256*da0073e9SAndroid Build Coastguard Worker
12257*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("error in first class mode")
12258*da0073e9SAndroid Build Coastguard Worker    def test_call_script_mod_from_tracing_fn(self):
12259*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "must be registered as submodules"):
12260*da0073e9SAndroid Build Coastguard Worker            class ScriptMod(torch.jit.ScriptModule):
12261*da0073e9SAndroid Build Coastguard Worker                def __init__(self) -> None:
12262*da0073e9SAndroid Build Coastguard Worker                    super().__init__()
12263*da0073e9SAndroid Build Coastguard Worker                    self.param = torch.nn.Parameter(torch.rand(3, 4), requires_grad=False)
12264*da0073e9SAndroid Build Coastguard Worker
12265*da0073e9SAndroid Build Coastguard Worker                @torch.jit.script_method
12266*da0073e9SAndroid Build Coastguard Worker                def forward(self, x):
12267*da0073e9SAndroid Build Coastguard Worker                    for _i in range(4):
12268*da0073e9SAndroid Build Coastguard Worker                        x += self.param
12269*da0073e9SAndroid Build Coastguard Worker                    return x
12270*da0073e9SAndroid Build Coastguard Worker
12271*da0073e9SAndroid Build Coastguard Worker            sm = ScriptMod()
12272*da0073e9SAndroid Build Coastguard Worker
12273*da0073e9SAndroid Build Coastguard Worker            @_trace(torch.rand(3, 4))
12274*da0073e9SAndroid Build Coastguard Worker            def traced_fn(x):
12275*da0073e9SAndroid Build Coastguard Worker                return sm(x) + 1.0
12276*da0073e9SAndroid Build Coastguard Worker
12277*da0073e9SAndroid Build Coastguard Worker
12278*da0073e9SAndroid Build Coastguard Worker    def test_call_python_fn_from_traced_module(self):
12279*da0073e9SAndroid Build Coastguard Worker        def python_fn(x):
12280*da0073e9SAndroid Build Coastguard Worker            return torch.neg(x)
12281*da0073e9SAndroid Build Coastguard Worker
12282*da0073e9SAndroid Build Coastguard Worker        class TracedModule(torch.nn.Module):
12283*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
12284*da0073e9SAndroid Build Coastguard Worker                super().__init__()
12285*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(4, 3))
12286*da0073e9SAndroid Build Coastguard Worker
12287*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
12288*da0073e9SAndroid Build Coastguard Worker                return torch.mm(python_fn(x), self.param)
12289*da0073e9SAndroid Build Coastguard Worker
12290*da0073e9SAndroid Build Coastguard Worker        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
12291*da0073e9SAndroid Build Coastguard Worker
12292*da0073e9SAndroid Build Coastguard Worker        # Note: parameter self.param from the traced module should appear as
12293*da0073e9SAndroid Build Coastguard Worker        # an input to the graph and the neg op from the Python function should
12294*da0073e9SAndroid Build Coastguard Worker        # be properly inlined
12295*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(list(tm.graph.inputs())) == 2)
12296*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::neg").check("aten::mm").run(str(tm.graph))
12297*da0073e9SAndroid Build Coastguard Worker
12298*da0073e9SAndroid Build Coastguard Worker    def test_call_python_mod_from_traced_module(self):
12299*da0073e9SAndroid Build Coastguard Worker        class PythonModule(torch.nn.Module):
12300*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
12301*da0073e9SAndroid Build Coastguard Worker                super().__init__()
12302*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(5, 7))
12303*da0073e9SAndroid Build Coastguard Worker
12304*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
12305*da0073e9SAndroid Build Coastguard Worker                return torch.mm(x, self.param)
12306*da0073e9SAndroid Build Coastguard Worker
12307*da0073e9SAndroid Build Coastguard Worker        class TracedModule(torch.nn.Module):
12308*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
12309*da0073e9SAndroid Build Coastguard Worker                super().__init__()
12310*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(4, 5))
12311*da0073e9SAndroid Build Coastguard Worker                self.mod = PythonModule()
12312*da0073e9SAndroid Build Coastguard Worker
12313*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
12314*da0073e9SAndroid Build Coastguard Worker                return self.mod(torch.mm(x, self.param)) + 1.0
12315*da0073e9SAndroid Build Coastguard Worker
12316*da0073e9SAndroid Build Coastguard Worker        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
12317*da0073e9SAndroid Build Coastguard Worker
12318*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("value=<Tensor>").check("aten::mm")\
12319*da0073e9SAndroid Build Coastguard Worker            .check('prim::CallMethod[name="forward"]').check("aten::add") \
12320*da0073e9SAndroid Build Coastguard Worker            .run(str(tm.graph))
12321*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::mm").run(str(tm.mod.graph))
12322*da0073e9SAndroid Build Coastguard Worker
12323*da0073e9SAndroid Build Coastguard Worker    def test_op_dtype(self):
12324*da0073e9SAndroid Build Coastguard Worker
12325*da0073e9SAndroid Build Coastguard Worker        def check_equal_and_dtype(a, b):
12326*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a, b)
12327*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a.dtype, b.dtype)
12328*da0073e9SAndroid Build Coastguard Worker
12329*da0073e9SAndroid Build Coastguard Worker        def fn():
12330*da0073e9SAndroid Build Coastguard Worker            a = torch.arange(10)
12331*da0073e9SAndroid Build Coastguard Worker            b = torch.arange(10, dtype=torch.float)
12332*da0073e9SAndroid Build Coastguard Worker            c = torch.arange(1, 10, 2)
12333*da0073e9SAndroid Build Coastguard Worker            d = torch.arange(1, 10, 2, dtype=torch.float)
12334*da0073e9SAndroid Build Coastguard Worker            e = torch.arange(1, 10., 2)
12335*da0073e9SAndroid Build Coastguard Worker            f = torch.arange(1, 10., 2, dtype=torch.float)
12336*da0073e9SAndroid Build Coastguard Worker            return a, b, c, d, e, f
12337*da0073e9SAndroid Build Coastguard Worker
12338*da0073e9SAndroid Build Coastguard Worker        scripted_fn = torch.jit.script(fn)
12339*da0073e9SAndroid Build Coastguard Worker        eager_out = fn()
12340*da0073e9SAndroid Build Coastguard Worker        script_out = scripted_fn()
12341*da0073e9SAndroid Build Coastguard Worker        for a, b in zip(eager_out, script_out):
12342*da0073e9SAndroid Build Coastguard Worker            check_equal_and_dtype(a, b)
12343*da0073e9SAndroid Build Coastguard Worker
12344*da0073e9SAndroid Build Coastguard Worker    def test_floor_div(self):
12345*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
12346*da0073e9SAndroid Build Coastguard Worker        def foo(a, b):
12347*da0073e9SAndroid Build Coastguard Worker            # type: (int, int) -> int
12348*da0073e9SAndroid Build Coastguard Worker            return a // b
12349*da0073e9SAndroid Build Coastguard Worker        for i in range(-8, 8):
12350*da0073e9SAndroid Build Coastguard Worker            for j in range(-8, 8):
12351*da0073e9SAndroid Build Coastguard Worker                if j != 0:
12352*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(foo(i, j), i // j)
12353*da0073e9SAndroid Build Coastguard Worker
12354*da0073e9SAndroid Build Coastguard Worker    def test_floordiv(self):
12355*da0073e9SAndroid Build Coastguard Worker        funcs_template = dedent('''
12356*da0073e9SAndroid Build Coastguard Worker        def fn():
12357*da0073e9SAndroid Build Coastguard Worker            ten = {a_construct}
12358*da0073e9SAndroid Build Coastguard Worker            ten_or_scalar = {b_construct}
12359*da0073e9SAndroid Build Coastguard Worker            return ten // ten_or_scalar, torch.floor_divide(ten, ten_or_scalar)
12360*da0073e9SAndroid Build Coastguard Worker        ''')
12361*da0073e9SAndroid Build Coastguard Worker
12362*da0073e9SAndroid Build Coastguard Worker        lhs = ["torch.tensor([5.5, 3.2])", "torch.tensor([2, 2])", "torch.tensor([3, 2])"]
12363*da0073e9SAndroid Build Coastguard Worker        rhs = ["1.5", "2", "4", "1.1"] + lhs
12364*da0073e9SAndroid Build Coastguard Worker        for tensor in lhs:
12365*da0073e9SAndroid Build Coastguard Worker            for tensor_or_scalar in rhs:
12366*da0073e9SAndroid Build Coastguard Worker                funcs_str = funcs_template.format(a_construct=tensor, b_construct=tensor_or_scalar)
12367*da0073e9SAndroid Build Coastguard Worker                scope = {}
12368*da0073e9SAndroid Build Coastguard Worker                execWrapper(funcs_str, globals(), scope)
12369*da0073e9SAndroid Build Coastguard Worker                cu = torch.jit.CompilationUnit(funcs_str)
12370*da0073e9SAndroid Build Coastguard Worker                f_script = cu.fn
12371*da0073e9SAndroid Build Coastguard Worker                f = scope['fn']
12372*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(f_script(), f())
12373*da0073e9SAndroid Build Coastguard Worker
12374*da0073e9SAndroid Build Coastguard Worker    def test_call_python_fn_from_script_fn(self):
12375*da0073e9SAndroid Build Coastguard Worker        @torch.jit.ignore
12376*da0073e9SAndroid Build Coastguard Worker        def python_fn(x):
12377*da0073e9SAndroid Build Coastguard Worker            return torch.neg(x)
12378*da0073e9SAndroid Build Coastguard Worker
12379*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
12380*da0073e9SAndroid Build Coastguard Worker        def script_fn(x):
12381*da0073e9SAndroid Build Coastguard Worker            return python_fn(x) + 1
12382*da0073e9SAndroid Build Coastguard Worker
12383*da0073e9SAndroid Build Coastguard Worker        # Note: the call to python_fn appears as `^python_fn()` and is called
12384*da0073e9SAndroid Build Coastguard Worker        # as a PythonOp in the interpreter
12385*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1)
12386*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(script_fn(a), torch.tensor(0))
12387*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("python_fn").run(str(script_fn.graph))
12388*da0073e9SAndroid Build Coastguard Worker
12389*da0073e9SAndroid Build Coastguard Worker    def test_call_python_mod_from_script_fn(self):
12390*da0073e9SAndroid Build Coastguard Worker        class PythonModule(torch.nn.Module):
12391*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
12392*da0073e9SAndroid Build Coastguard Worker                super().__init__()
12393*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(5, 7))
12394*da0073e9SAndroid Build Coastguard Worker
12395*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
12396*da0073e9SAndroid Build Coastguard Worker                return torch.mm(x, self.param)
12397*da0073e9SAndroid Build Coastguard Worker
12398*da0073e9SAndroid Build Coastguard Worker        pm = PythonModule()
12399*da0073e9SAndroid Build Coastguard Worker
12400*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
12401*da0073e9SAndroid Build Coastguard Worker        def script_fn(x):
12402*da0073e9SAndroid Build Coastguard Worker            return pm(x) + 1
12403*da0073e9SAndroid Build Coastguard Worker
12404*da0073e9SAndroid Build Coastguard Worker        # Note: call to pm(x) appears as ^<python_value>() in the trace.
12405*da0073e9SAndroid Build Coastguard Worker        # Parameters are NOT inlined.
12406*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("python_value").check("aten::add").run(str(script_fn.graph))
12407*da0073e9SAndroid Build Coastguard Worker
12408*da0073e9SAndroid Build Coastguard Worker    @_tmp_donotuse_dont_inline_everything
12409*da0073e9SAndroid Build Coastguard Worker    def test_call_script_fn_from_script_fn(self):
12410*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
12411*da0073e9SAndroid Build Coastguard Worker        def script_fn1(x):
12412*da0073e9SAndroid Build Coastguard Worker            return torch.neg(x)
12413*da0073e9SAndroid Build Coastguard Worker
12414*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
12415*da0073e9SAndroid Build Coastguard Worker        def script_fn(x):
12416*da0073e9SAndroid Build Coastguard Worker            return script_fn1(x) + 1
12417*da0073e9SAndroid Build Coastguard Worker
12418*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("prim::CallFunction").run(str(script_fn.graph))
12419*da0073e9SAndroid Build Coastguard Worker
12420*da0073e9SAndroid Build Coastguard Worker    def test_call_script_mod_from_script_fn(self):
12421*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Cannot call a ScriptModule that is not a submodule of the caller"):
12422*da0073e9SAndroid Build Coastguard Worker            class ScriptMod(torch.jit.ScriptModule):
12423*da0073e9SAndroid Build Coastguard Worker                @torch.jit.script_method
12424*da0073e9SAndroid Build Coastguard Worker                def forward(self, x):
12425*da0073e9SAndroid Build Coastguard Worker                    return torch.mm(x, torch.zeros([4, 3]))
12426*da0073e9SAndroid Build Coastguard Worker
12427*da0073e9SAndroid Build Coastguard Worker            sm = ScriptMod()
12428*da0073e9SAndroid Build Coastguard Worker
12429*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
12430*da0073e9SAndroid Build Coastguard Worker            def script_fn(x):
12431*da0073e9SAndroid Build Coastguard Worker                return sm(x) + 1
12432*da0073e9SAndroid Build Coastguard Worker
12433*da0073e9SAndroid Build Coastguard Worker    def test_call_python_fn_from_script_module(self):
12434*da0073e9SAndroid Build Coastguard Worker        @torch.jit.ignore
12435*da0073e9SAndroid Build Coastguard Worker        def python_fn(x):
12436*da0073e9SAndroid Build Coastguard Worker            return torch.neg(x)
12437*da0073e9SAndroid Build Coastguard Worker
12438*da0073e9SAndroid Build Coastguard Worker        class ScriptMod(torch.jit.ScriptModule):
12439*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
12440*da0073e9SAndroid Build Coastguard Worker                super().__init__()
12441*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(4, 3))
12442*da0073e9SAndroid Build Coastguard Worker
12443*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
12444*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
12445*da0073e9SAndroid Build Coastguard Worker                return python_fn(torch.mm(x, self.param))
12446*da0073e9SAndroid Build Coastguard Worker
12447*da0073e9SAndroid Build Coastguard Worker        sm = ScriptMod()
12448*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::mm").check("python_fn") \
12449*da0073e9SAndroid Build Coastguard Worker            .run(str(sm.forward.graph))
12450*da0073e9SAndroid Build Coastguard Worker
12451*da0073e9SAndroid Build Coastguard Worker    def test_call_python_mod_from_script_module(self):
12452*da0073e9SAndroid Build Coastguard Worker        class PythonMod(torch.nn.Module):
12453*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
12454*da0073e9SAndroid Build Coastguard Worker                super().__init__()
12455*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(3, 5))
12456*da0073e9SAndroid Build Coastguard Worker
12457*da0073e9SAndroid Build Coastguard Worker            @torch.jit.ignore
12458*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
12459*da0073e9SAndroid Build Coastguard Worker                return torch.mm(x, self.param)
12460*da0073e9SAndroid Build Coastguard Worker
12461*da0073e9SAndroid Build Coastguard Worker        class ScriptMod(torch.jit.ScriptModule):
12462*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
12463*da0073e9SAndroid Build Coastguard Worker                super().__init__()
12464*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(4, 3))
12465*da0073e9SAndroid Build Coastguard Worker                self.pm = PythonMod()
12466*da0073e9SAndroid Build Coastguard Worker
12467*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
12468*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
12469*da0073e9SAndroid Build Coastguard Worker                return self.pm(torch.mm(x, self.param))
12470*da0073e9SAndroid Build Coastguard Worker
12471*da0073e9SAndroid Build Coastguard Worker        sm = ScriptMod()
12472*da0073e9SAndroid Build Coastguard Worker        # Note: the call into PythonMod appears as ^forward(). Parameters
12473*da0073e9SAndroid Build Coastguard Worker        # are NOT inlined
12474*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::mm").check("forward").run(str(sm.graph))
12475*da0073e9SAndroid Build Coastguard Worker
12476*da0073e9SAndroid Build Coastguard Worker    @_tmp_donotuse_dont_inline_everything
12477*da0073e9SAndroid Build Coastguard Worker    def test_call_script_fn_from_script_module(self):
12478*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
12479*da0073e9SAndroid Build Coastguard Worker        def script_fn(x):
12480*da0073e9SAndroid Build Coastguard Worker            return torch.neg(x)
12481*da0073e9SAndroid Build Coastguard Worker
12482*da0073e9SAndroid Build Coastguard Worker        class ScriptMod(torch.jit.ScriptModule):
12483*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
12484*da0073e9SAndroid Build Coastguard Worker                super().__init__()
12485*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(4, 3))
12486*da0073e9SAndroid Build Coastguard Worker
12487*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
12488*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
12489*da0073e9SAndroid Build Coastguard Worker                return script_fn(torch.mm(x, self.param))
12490*da0073e9SAndroid Build Coastguard Worker
12491*da0073e9SAndroid Build Coastguard Worker        sm = ScriptMod()
12492*da0073e9SAndroid Build Coastguard Worker        graph = (sm.forward.graph)
12493*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::mm").check("prim::CallFunction").run(str(graph))
12494*da0073e9SAndroid Build Coastguard Worker
12495*da0073e9SAndroid Build Coastguard Worker    @_tmp_donotuse_dont_inline_everything
12496*da0073e9SAndroid Build Coastguard Worker    def test_call_script_mod_from_script_module(self):
12497*da0073e9SAndroid Build Coastguard Worker        class ScriptMod1(torch.jit.ScriptModule):
12498*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
12499*da0073e9SAndroid Build Coastguard Worker                super().__init__()
12500*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(3, 5))
12501*da0073e9SAndroid Build Coastguard Worker
12502*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
12503*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
12504*da0073e9SAndroid Build Coastguard Worker                return torch.mm(x, self.param)
12505*da0073e9SAndroid Build Coastguard Worker
12506*da0073e9SAndroid Build Coastguard Worker        class ScriptMod(torch.jit.ScriptModule):
12507*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
12508*da0073e9SAndroid Build Coastguard Worker                super().__init__()
12509*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(4, 3))
12510*da0073e9SAndroid Build Coastguard Worker                self.tm = ScriptMod1()
12511*da0073e9SAndroid Build Coastguard Worker
12512*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
12513*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
12514*da0073e9SAndroid Build Coastguard Worker                return self.tm(torch.mm(x, self.param))
12515*da0073e9SAndroid Build Coastguard Worker
12516*da0073e9SAndroid Build Coastguard Worker        sm = ScriptMod()
12517*da0073e9SAndroid Build Coastguard Worker        # Note: the parameters from both modules should appear in the flattened
12518*da0073e9SAndroid Build Coastguard Worker        # input list to the graph. The mm op from ScriptMod1 should be properly
12519*da0073e9SAndroid Build Coastguard Worker        # inlined
12520*da0073e9SAndroid Build Coastguard Worker        # 3 % values in graph input lists, two mms in body
12521*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_count('%', 3).check(":").check_count("mm", 1).check("prim::CallMethod").run(str(sm.graph))
12522*da0073e9SAndroid Build Coastguard Worker
12523*da0073e9SAndroid Build Coastguard Worker    def test_module_with_params_called_fails(self):
12524*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Cannot call a ScriptModule that is not a submodule of the caller"):
12525*da0073e9SAndroid Build Coastguard Worker            class ScriptMod(torch.jit.ScriptModule):
12526*da0073e9SAndroid Build Coastguard Worker                def __init__(self) -> None:
12527*da0073e9SAndroid Build Coastguard Worker                    super().__init__()
12528*da0073e9SAndroid Build Coastguard Worker                    self.param = torch.nn.Parameter(torch.rand(3, 3))
12529*da0073e9SAndroid Build Coastguard Worker
12530*da0073e9SAndroid Build Coastguard Worker                @torch.jit.script_method
12531*da0073e9SAndroid Build Coastguard Worker                def forward(self, x):
12532*da0073e9SAndroid Build Coastguard Worker                    return torch.mm(x, self.param)
12533*da0073e9SAndroid Build Coastguard Worker
12534*da0073e9SAndroid Build Coastguard Worker            sm = ScriptMod()
12535*da0073e9SAndroid Build Coastguard Worker
12536*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
12537*da0073e9SAndroid Build Coastguard Worker            def some_func(x):
12538*da0073e9SAndroid Build Coastguard Worker                return sm(x)
12539*da0073e9SAndroid Build Coastguard Worker
12540*da0073e9SAndroid Build Coastguard Worker    def test_tuple_index_to_list(self):
12541*da0073e9SAndroid Build Coastguard Worker        def test_non_constant_input(a):
12542*da0073e9SAndroid Build Coastguard Worker            # type: (bool) -> int
12543*da0073e9SAndroid Build Coastguard Worker            if a:
12544*da0073e9SAndroid Build Coastguard Worker                b = 1
12545*da0073e9SAndroid Build Coastguard Worker            else:
12546*da0073e9SAndroid Build Coastguard Worker                b = 0
12547*da0073e9SAndroid Build Coastguard Worker            c = (0, 1)
12548*da0073e9SAndroid Build Coastguard Worker            return c[b]
12549*da0073e9SAndroid Build Coastguard Worker
12550*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_non_constant_input, (True,))
12551*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_non_constant_input, (False,))
12552*da0073e9SAndroid Build Coastguard Worker
12553*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "because we cannot resolve the output type"):
12554*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
12555*da0073e9SAndroid Build Coastguard Worker            def test_non_constant_input(a):
12556*da0073e9SAndroid Build Coastguard Worker                # type: (bool) -> None
12557*da0073e9SAndroid Build Coastguard Worker                if a:
12558*da0073e9SAndroid Build Coastguard Worker                    b = 1
12559*da0073e9SAndroid Build Coastguard Worker                else:
12560*da0073e9SAndroid Build Coastguard Worker                    b = 0
12561*da0073e9SAndroid Build Coastguard Worker                c = (0, 1.1)
12562*da0073e9SAndroid Build Coastguard Worker                print(c[b])
12563*da0073e9SAndroid Build Coastguard Worker
12564*da0073e9SAndroid Build Coastguard Worker    def test_tuple_indexing(self):
12565*da0073e9SAndroid Build Coastguard Worker        def tuple_index(a):
12566*da0073e9SAndroid Build Coastguard Worker            if bool(a):
12567*da0073e9SAndroid Build Coastguard Worker                b = (1, 2)
12568*da0073e9SAndroid Build Coastguard Worker            else:
12569*da0073e9SAndroid Build Coastguard Worker                b = (0, 2)
12570*da0073e9SAndroid Build Coastguard Worker            return b[-2], b[1]
12571*da0073e9SAndroid Build Coastguard Worker
12572*da0073e9SAndroid Build Coastguard Worker        self.checkScript(tuple_index, (torch.tensor([0]),))
12573*da0073e9SAndroid Build Coastguard Worker        self.checkScript(tuple_index, (torch.tensor([1]),))
12574*da0073e9SAndroid Build Coastguard Worker        self.checkScript(tuple_index, (torch.tensor([1]),), optimize=True)
12575*da0073e9SAndroid Build Coastguard Worker        tuple_comp = torch.jit.script(tuple_index)
12576*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_count("TupleIndex", 2, exactly=True).run(str(tuple_comp.graph))
12577*da0073e9SAndroid Build Coastguard Worker
12578*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "index must be an integer"):
12579*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
12580*da0073e9SAndroid Build Coastguard Worker            def test_indexing_float():
12581*da0073e9SAndroid Build Coastguard Worker                c = (1, 2)
12582*da0073e9SAndroid Build Coastguard Worker                return c[0.1]
12583*da0073e9SAndroid Build Coastguard Worker
12584*da0073e9SAndroid Build Coastguard Worker        def test_indexing_out_of_bounds_pos():
12585*da0073e9SAndroid Build Coastguard Worker            c = (1, 2)
12586*da0073e9SAndroid Build Coastguard Worker            return c[2]
12587*da0073e9SAndroid Build Coastguard Worker
12588*da0073e9SAndroid Build Coastguard Worker        self.checkScriptRaisesRegex(test_indexing_out_of_bounds_pos, (), Exception,
12589*da0073e9SAndroid Build Coastguard Worker                                    "out of range")
12590*da0073e9SAndroid Build Coastguard Worker
12591*da0073e9SAndroid Build Coastguard Worker        def test_indexing_out_of_bounds_neg():
12592*da0073e9SAndroid Build Coastguard Worker            c = (1, 2)
12593*da0073e9SAndroid Build Coastguard Worker            return c[-3]
12594*da0073e9SAndroid Build Coastguard Worker
12595*da0073e9SAndroid Build Coastguard Worker        self.checkScriptRaisesRegex(test_indexing_out_of_bounds_pos, (), Exception,
12596*da0073e9SAndroid Build Coastguard Worker                                    "out of range")
12597*da0073e9SAndroid Build Coastguard Worker
12598*da0073e9SAndroid Build Coastguard Worker        def negative_index():
12599*da0073e9SAndroid Build Coastguard Worker            tup = (1, 2, 3, 4)
12600*da0073e9SAndroid Build Coastguard Worker            return tup[-1]
12601*da0073e9SAndroid Build Coastguard Worker
12602*da0073e9SAndroid Build Coastguard Worker        self.checkScript(negative_index, [])
12603*da0073e9SAndroid Build Coastguard Worker
12604*da0073e9SAndroid Build Coastguard Worker        def really_negative_index():
12605*da0073e9SAndroid Build Coastguard Worker            tup = (1, 2, 3, 4)
12606*da0073e9SAndroid Build Coastguard Worker            return tup[-100]
12607*da0073e9SAndroid Build Coastguard Worker
12608*da0073e9SAndroid Build Coastguard Worker        self.checkScriptRaisesRegex(really_negative_index, [], Exception, "index out of range")
12609*da0073e9SAndroid Build Coastguard Worker
12610*da0073e9SAndroid Build Coastguard Worker        def negative_slice():
12611*da0073e9SAndroid Build Coastguard Worker            tup = (1, 2, 3, 4)
12612*da0073e9SAndroid Build Coastguard Worker            return tup[-3:4]
12613*da0073e9SAndroid Build Coastguard Worker
12614*da0073e9SAndroid Build Coastguard Worker        self.checkScript(negative_slice, [])
12615*da0073e9SAndroid Build Coastguard Worker
12616*da0073e9SAndroid Build Coastguard Worker        def really_slice_out_of_bounds():
12617*da0073e9SAndroid Build Coastguard Worker            tup = (1, 2, 3, 4)
12618*da0073e9SAndroid Build Coastguard Worker            return tup[-300:4000]
12619*da0073e9SAndroid Build Coastguard Worker
12620*da0073e9SAndroid Build Coastguard Worker        self.checkScript(really_slice_out_of_bounds, [])
12621*da0073e9SAndroid Build Coastguard Worker
12622*da0073e9SAndroid Build Coastguard Worker    def test_namedtuple_attr(self):
12623*da0073e9SAndroid Build Coastguard Worker        def f(x):
12624*da0073e9SAndroid Build Coastguard Worker            return x.max(dim=1).indices + torch.max(x, dim=1).indices
12625*da0073e9SAndroid Build Coastguard Worker
12626*da0073e9SAndroid Build Coastguard Worker        self.checkScript(f, (torch.rand(20, 20, 20),), optimize=True)
12627*da0073e9SAndroid Build Coastguard Worker
12628*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "object has no attribute or method"):
12629*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
12630*da0073e9SAndroid Build Coastguard Worker            def g1(x):
12631*da0073e9SAndroid Build Coastguard Worker                return x.max(dim=1).unknown_symbol
12632*da0073e9SAndroid Build Coastguard Worker
12633*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "object has no attribute or method"):
12634*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
12635*da0073e9SAndroid Build Coastguard Worker            def g2(x):
12636*da0073e9SAndroid Build Coastguard Worker                print((x, x, x).__doc__)
12637*da0073e9SAndroid Build Coastguard Worker                return x
12638*da0073e9SAndroid Build Coastguard Worker
12639*da0073e9SAndroid Build Coastguard Worker    def test_tuple_len(self):
12640*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
12641*da0073e9SAndroid Build Coastguard Worker        def foo():
12642*da0073e9SAndroid Build Coastguard Worker            return len((1, "str", None))
12643*da0073e9SAndroid Build Coastguard Worker
12644*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(), 3)
12645*da0073e9SAndroid Build Coastguard Worker
12646*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
12647*da0073e9SAndroid Build Coastguard Worker        def test_indexing_end_out_of_bounds():
12648*da0073e9SAndroid Build Coastguard Worker            c = (1, 2)
12649*da0073e9SAndroid Build Coastguard Worker            return c[2:10]
12650*da0073e9SAndroid Build Coastguard Worker
12651*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(test_indexing_end_out_of_bounds(), ())
12652*da0073e9SAndroid Build Coastguard Worker
12653*da0073e9SAndroid Build Coastguard Worker    def test_lower_nested_tuples(self):
12654*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
12655*da0073e9SAndroid Build Coastguard Worker        def test():
12656*da0073e9SAndroid Build Coastguard Worker            return ((1, 2), 3)
12657*da0073e9SAndroid Build Coastguard Worker
12658*da0073e9SAndroid Build Coastguard Worker        self.run_pass('constant_propagation', test.graph)
12659*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("prim::Constant").check_not("TupleConstruct").run(test.graph)
12660*da0073e9SAndroid Build Coastguard Worker        # fails if a tuple can't be lowered
12661*da0073e9SAndroid Build Coastguard Worker        self.run_pass('lower_all_tuples', test.graph)
12662*da0073e9SAndroid Build Coastguard Worker
12663*da0073e9SAndroid Build Coastguard Worker    def test_unwrap_optional_builtin(self):
12664*da0073e9SAndroid Build Coastguard Worker        def test(x):
12665*da0073e9SAndroid Build Coastguard Worker            # type: (Optional[int]) -> int
12666*da0073e9SAndroid Build Coastguard Worker            x = torch.jit._unwrap_optional(x)
12667*da0073e9SAndroid Build Coastguard Worker            x = x + x  # noqa: T484
12668*da0073e9SAndroid Build Coastguard Worker            return x
12669*da0073e9SAndroid Build Coastguard Worker
12670*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test, (3,))
12671*da0073e9SAndroid Build Coastguard Worker
12672*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(AssertionError, "Unwrapping null optional"):
12673*da0073e9SAndroid Build Coastguard Worker            test(None)
12674*da0073e9SAndroid Build Coastguard Worker
12675*da0073e9SAndroid Build Coastguard Worker        test_script = torch.jit.script(test)
12676*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Unwrapping null optional"):
12677*da0073e9SAndroid Build Coastguard Worker            test_script(None)
12678*da0073e9SAndroid Build Coastguard Worker
12679*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
12680*da0073e9SAndroid Build Coastguard Worker        def test_test():
12681*da0073e9SAndroid Build Coastguard Worker            return torch.jit._unwrap_optional(1)
12682*da0073e9SAndroid Build Coastguard Worker
12683*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"could not be inferred from actual type None"):
12684*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
12685*da0073e9SAndroid Build Coastguard Worker            def test_no_type():
12686*da0073e9SAndroid Build Coastguard Worker                # type: () -> int
12687*da0073e9SAndroid Build Coastguard Worker                return torch.jit._unwrap_optional(None)
12688*da0073e9SAndroid Build Coastguard Worker
12689*da0073e9SAndroid Build Coastguard Worker    def test_indexing_error(self):
12690*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "'int' object is not subscriptable"):
12691*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
12692*da0073e9SAndroid Build Coastguard Worker            def test_wrong_type():
12693*da0073e9SAndroid Build Coastguard Worker                a = 8
12694*da0073e9SAndroid Build Coastguard Worker                return a[0]
12695*da0073e9SAndroid Build Coastguard Worker
12696*da0073e9SAndroid Build Coastguard Worker    def test_unsupported_builtin_error(self):
12697*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError,
12698*da0073e9SAndroid Build Coastguard Worker                                    "Python builtin <built-in function hypot> is currently"):
12699*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
12700*da0073e9SAndroid Build Coastguard Worker            def test_unsupported(a):
12701*da0073e9SAndroid Build Coastguard Worker                return math.hypot(a, 2.0)
12702*da0073e9SAndroid Build Coastguard Worker
12703*da0073e9SAndroid Build Coastguard Worker    def test_annotated_script_fn(self):
12704*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
12705*da0073e9SAndroid Build Coastguard Worker        def foo(x, y, z):
12706*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tuple[Tensor, Tensor]]) -> Tensor
12707*da0073e9SAndroid Build Coastguard Worker            return x
12708*da0073e9SAndroid Build Coastguard Worker
12709*da0073e9SAndroid Build Coastguard Worker        self.assertExpected(str(foo.schema))
12710*da0073e9SAndroid Build Coastguard Worker
12711*da0073e9SAndroid Build Coastguard Worker    def test_annotated_script_method(self):
12712*da0073e9SAndroid Build Coastguard Worker        class SM(torch.jit.ScriptModule):
12713*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
12714*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, y):
12715*da0073e9SAndroid Build Coastguard Worker                # type: (Tuple[Tensor, Tensor], Tensor) -> Tuple[Tensor, Tensor, Tensor]
12716*da0073e9SAndroid Build Coastguard Worker                return y, y, y
12717*da0073e9SAndroid Build Coastguard Worker
12718*da0073e9SAndroid Build Coastguard Worker        sm = SM()
12719*da0073e9SAndroid Build Coastguard Worker
12720*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedStripMangled(str(sm.forward.schema))
12721*da0073e9SAndroid Build Coastguard Worker
12722*da0073e9SAndroid Build Coastguard Worker    def test_annotated_script_fn_return_mismatch(self):
12723*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "but is actually of type"):
12724*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
12725*da0073e9SAndroid Build Coastguard Worker            def return_tup(x):
12726*da0073e9SAndroid Build Coastguard Worker                # type: (Tensor) -> Tuple[Tuple[Tensor, Tensor], Tensor]
12727*da0073e9SAndroid Build Coastguard Worker                return x, x  # noqa: T484
12728*da0073e9SAndroid Build Coastguard Worker
12729*da0073e9SAndroid Build Coastguard Worker    def test_annotated_script_fn_arg_mismatch(self):
12730*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"Arguments for call are not valid"):
12731*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
12732*da0073e9SAndroid Build Coastguard Worker            def tuple_arg(x):
12733*da0073e9SAndroid Build Coastguard Worker                # type: (Tuple[Tensor, Tensor]) -> Tensor
12734*da0073e9SAndroid Build Coastguard Worker                return x + 1  # noqa: T484
12735*da0073e9SAndroid Build Coastguard Worker
12736*da0073e9SAndroid Build Coastguard Worker    def test_script_non_tensor_args_outputs(self):
12737*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
12738*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
12739*da0073e9SAndroid Build Coastguard Worker            # type: (Tensor, float) -> float
12740*da0073e9SAndroid Build Coastguard Worker            return float((x + y).sum())
12741*da0073e9SAndroid Build Coastguard Worker
12742*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(2, 2)
12743*da0073e9SAndroid Build Coastguard Worker        z = fn(x, 1)
12744*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(z, float)
12745*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z, 8.)
12746*da0073e9SAndroid Build Coastguard Worker
12747*da0073e9SAndroid Build Coastguard Worker    @unittest.skip('https://github.com/pytorch/pytorch/issues/9595')
12748*da0073e9SAndroid Build Coastguard Worker    def test_inline_and_run_annotated_script_fn(self):
12749*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
12750*da0073e9SAndroid Build Coastguard Worker        def to_inline(x, y):
12751*da0073e9SAndroid Build Coastguard Worker            # type: (Tuple[Tensor, Tensor], Tensor) -> Tensor
12752*da0073e9SAndroid Build Coastguard Worker            return y
12753*da0073e9SAndroid Build Coastguard Worker
12754*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
12755*da0073e9SAndroid Build Coastguard Worker        def some_func(x):
12756*da0073e9SAndroid Build Coastguard Worker            return to_inline((x, x), x)
12757*da0073e9SAndroid Build Coastguard Worker
12758*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
12759*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(some_func(x), x)
12760*da0073e9SAndroid Build Coastguard Worker
12761*da0073e9SAndroid Build Coastguard Worker    def _make_filereader_test_file(self):
12762*da0073e9SAndroid Build Coastguard Worker        filename = tempfile.mktemp()
12763*da0073e9SAndroid Build Coastguard Worker        writer = torch._C.PyTorchFileWriter(filename)
12764*da0073e9SAndroid Build Coastguard Worker        buffers = [os.urandom(size) for size in [random.randint(1, 100) for i in range(20)]]
12765*da0073e9SAndroid Build Coastguard Worker        offsets = []
12766*da0073e9SAndroid Build Coastguard Worker        for i, buf in enumerate(buffers):
12767*da0073e9SAndroid Build Coastguard Worker            writer.write_record(str(i), buf, len(buf))
12768*da0073e9SAndroid Build Coastguard Worker            offsets.append(i)
12769*da0073e9SAndroid Build Coastguard Worker        serialized_offsets = pickle.dumps(offsets)
12770*da0073e9SAndroid Build Coastguard Worker        writer.write_record("meta", serialized_offsets, len(serialized_offsets))
12771*da0073e9SAndroid Build Coastguard Worker        writer.write_end_of_file()
12772*da0073e9SAndroid Build Coastguard Worker        return filename, buffers, serialized_offsets
12773*da0073e9SAndroid Build Coastguard Worker
12774*da0073e9SAndroid Build Coastguard Worker    def test_file_format_serialization(self):
12775*da0073e9SAndroid Build Coastguard Worker        filename, buffers, serialized_offsets = self._make_filereader_test_file()
12776*da0073e9SAndroid Build Coastguard Worker
12777*da0073e9SAndroid Build Coastguard Worker        reader = torch._C.PyTorchFileReader(filename)
12778*da0073e9SAndroid Build Coastguard Worker        serialized_offsets_read = reader.get_record("meta")
12779*da0073e9SAndroid Build Coastguard Worker        parsed_serialized_offsets = pickle.loads(serialized_offsets)
12780*da0073e9SAndroid Build Coastguard Worker
12781*da0073e9SAndroid Build Coastguard Worker        for i, offset in enumerate(parsed_serialized_offsets):
12782*da0073e9SAndroid Build Coastguard Worker            data = reader.get_record(str(offset))
12783*da0073e9SAndroid Build Coastguard Worker            assert data == buffers[i]
12784*da0073e9SAndroid Build Coastguard Worker
12785*da0073e9SAndroid Build Coastguard Worker    def test_file_reader_no_memory_leak(self):
12786*da0073e9SAndroid Build Coastguard Worker        num_iters = 10000
12787*da0073e9SAndroid Build Coastguard Worker        filename, _, _ = self._make_filereader_test_file()
12788*da0073e9SAndroid Build Coastguard Worker
12789*da0073e9SAndroid Build Coastguard Worker        # Load from filename
12790*da0073e9SAndroid Build Coastguard Worker        tracemalloc.start()
12791*da0073e9SAndroid Build Coastguard Worker        for i in range(num_iters):
12792*da0073e9SAndroid Build Coastguard Worker            torch._C.PyTorchFileReader(filename)
12793*da0073e9SAndroid Build Coastguard Worker        _, peak_from_string = tracemalloc.get_traced_memory()
12794*da0073e9SAndroid Build Coastguard Worker        tracemalloc.stop()
12795*da0073e9SAndroid Build Coastguard Worker
12796*da0073e9SAndroid Build Coastguard Worker        # Load from stream
12797*da0073e9SAndroid Build Coastguard Worker        tracemalloc.start()
12798*da0073e9SAndroid Build Coastguard Worker        with open(filename, 'rb') as f:
12799*da0073e9SAndroid Build Coastguard Worker            for i in range(num_iters):
12800*da0073e9SAndroid Build Coastguard Worker                f.seek(0)
12801*da0073e9SAndroid Build Coastguard Worker                torch._C.PyTorchFileReader(f)
12802*da0073e9SAndroid Build Coastguard Worker        _, peak_from_file = tracemalloc.get_traced_memory()
12803*da0073e9SAndroid Build Coastguard Worker        tracemalloc.stop()
12804*da0073e9SAndroid Build Coastguard Worker
12805*da0073e9SAndroid Build Coastguard Worker        # Check if the peak sizes at most differ by an empirically obtained factor
12806*da0073e9SAndroid Build Coastguard Worker        self.assertLess(peak_from_file, peak_from_string * 500)
12807*da0073e9SAndroid Build Coastguard Worker
12808*da0073e9SAndroid Build Coastguard Worker    # for each type, the input type annotation and corresponding return type annotation
12809*da0073e9SAndroid Build Coastguard Worker    def type_input_return_pairs(self):
12810*da0073e9SAndroid Build Coastguard Worker        return [
12811*da0073e9SAndroid Build Coastguard Worker            ('Tensor', 'Tensor'),
12812*da0073e9SAndroid Build Coastguard Worker            ('torch.Tensor', 'Tensor'),
12813*da0073e9SAndroid Build Coastguard Worker            ('str', 'str'),
12814*da0073e9SAndroid Build Coastguard Worker            ('int', 'int'),
12815*da0073e9SAndroid Build Coastguard Worker            ('bool', 'bool'),
12816*da0073e9SAndroid Build Coastguard Worker            ('BroadcastingList3[float]', 'List[float]'),
12817*da0073e9SAndroid Build Coastguard Worker            ('BroadcastingList2[int]', 'List[int]'),
12818*da0073e9SAndroid Build Coastguard Worker            ('List[int]', 'List[int]'),
12819*da0073e9SAndroid Build Coastguard Worker            ('Optional[int]', 'Optional[int]'),
12820*da0073e9SAndroid Build Coastguard Worker        ]
12821*da0073e9SAndroid Build Coastguard Worker
12822*da0073e9SAndroid Build Coastguard Worker    # replacing code input & return type pair
12823*da0073e9SAndroid Build Coastguard Worker    def format_code(self, code, pair):
12824*da0073e9SAndroid Build Coastguard Worker        return code.format(input=pair[0], output=pair[1])
12825*da0073e9SAndroid Build Coastguard Worker
12826*da0073e9SAndroid Build Coastguard Worker    # ***** Type annotation tests ****
12827*da0073e9SAndroid Build Coastguard Worker    # Test combinations of:
12828*da0073e9SAndroid Build Coastguard Worker    # {String frontend, Python AST Frontend}
12829*da0073e9SAndroid Build Coastguard Worker    # {Python 3-style type annotations, MyPy-style type comments}
12830*da0073e9SAndroid Build Coastguard Worker    # {Script method, Script function}
12831*da0073e9SAndroid Build Coastguard Worker
12832*da0073e9SAndroid Build Coastguard Worker    #  String frontend , Python 3-style type annotations , Script function
12833*da0073e9SAndroid Build Coastguard Worker    def test_annot_string_py3_fn(self):
12834*da0073e9SAndroid Build Coastguard Worker        code = '''
12835*da0073e9SAndroid Build Coastguard Worker            def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
12836*da0073e9SAndroid Build Coastguard Worker                return x, x
12837*da0073e9SAndroid Build Coastguard Worker        '''
12838*da0073e9SAndroid Build Coastguard Worker        test_str = []
12839*da0073e9SAndroid Build Coastguard Worker        for pair in self.type_input_return_pairs():
12840*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit(self.format_code(code, pair))
12841*da0073e9SAndroid Build Coastguard Worker            test_str.append(str(cu.foo.schema))
12842*da0073e9SAndroid Build Coastguard Worker        self.assertExpected("\n".join(test_str) + "\n")
12843*da0073e9SAndroid Build Coastguard Worker
12844*da0073e9SAndroid Build Coastguard Worker    #  String frontend , Python 3-style type annotations , Script method
12845*da0073e9SAndroid Build Coastguard Worker    def test_annot_string_py3_method(self):
12846*da0073e9SAndroid Build Coastguard Worker        class TestModule(torch.jit.ScriptModule):
12847*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
12848*da0073e9SAndroid Build Coastguard Worker                super().__init__()
12849*da0073e9SAndroid Build Coastguard Worker
12850*da0073e9SAndroid Build Coastguard Worker        code = '''
12851*da0073e9SAndroid Build Coastguard Worker            def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
12852*da0073e9SAndroid Build Coastguard Worker                return x, x
12853*da0073e9SAndroid Build Coastguard Worker        '''
12854*da0073e9SAndroid Build Coastguard Worker        test_str = []
12855*da0073e9SAndroid Build Coastguard Worker        for pair in self.type_input_return_pairs():
12856*da0073e9SAndroid Build Coastguard Worker            # clear the class registry as we will be defining foo multiple times
12857*da0073e9SAndroid Build Coastguard Worker            jit_utils.clear_class_registry()
12858*da0073e9SAndroid Build Coastguard Worker            tm = TestModule()
12859*da0073e9SAndroid Build Coastguard Worker            tm.define(self.format_code(code, pair))
12860*da0073e9SAndroid Build Coastguard Worker            test_str.append(str(tm.foo.schema))
12861*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedStripMangled("\n".join(test_str) + "\n")
12862*da0073e9SAndroid Build Coastguard Worker
12863*da0073e9SAndroid Build Coastguard Worker    #  String frontend , MyPy-style type comments , Script function
12864*da0073e9SAndroid Build Coastguard Worker    def test_annot_string_mypy_fn(self):
12865*da0073e9SAndroid Build Coastguard Worker        code = '''
12866*da0073e9SAndroid Build Coastguard Worker            def foo(x, y):
12867*da0073e9SAndroid Build Coastguard Worker                # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
12868*da0073e9SAndroid Build Coastguard Worker                return x, x
12869*da0073e9SAndroid Build Coastguard Worker        '''
12870*da0073e9SAndroid Build Coastguard Worker        test_str = []
12871*da0073e9SAndroid Build Coastguard Worker        for pair in self.type_input_return_pairs():
12872*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit(self.format_code(code, pair))
12873*da0073e9SAndroid Build Coastguard Worker            test_str.append(str(cu.foo.schema))
12874*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedStripMangled("\n".join(test_str) + "\n")
12875*da0073e9SAndroid Build Coastguard Worker
12876*da0073e9SAndroid Build Coastguard Worker    #  String frontend , MyPy-style type comments , Script method
12877*da0073e9SAndroid Build Coastguard Worker    def test_annot_string_mypy_method(self):
12878*da0073e9SAndroid Build Coastguard Worker        class TestModule(torch.jit.ScriptModule):
12879*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
12880*da0073e9SAndroid Build Coastguard Worker                super().__init__()
12881*da0073e9SAndroid Build Coastguard Worker
12882*da0073e9SAndroid Build Coastguard Worker        code = '''
12883*da0073e9SAndroid Build Coastguard Worker        def foo(self, x, y):
12884*da0073e9SAndroid Build Coastguard Worker            # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
12885*da0073e9SAndroid Build Coastguard Worker            return x, x
12886*da0073e9SAndroid Build Coastguard Worker        '''
12887*da0073e9SAndroid Build Coastguard Worker
12888*da0073e9SAndroid Build Coastguard Worker        test_str = []
12889*da0073e9SAndroid Build Coastguard Worker        for pair in self.type_input_return_pairs():
12890*da0073e9SAndroid Build Coastguard Worker            # clear the class registry as we will be defining foo multiple times
12891*da0073e9SAndroid Build Coastguard Worker            jit_utils.clear_class_registry()
12892*da0073e9SAndroid Build Coastguard Worker            tm = TestModule()
12893*da0073e9SAndroid Build Coastguard Worker            tm.define(self.format_code(code, pair))
12894*da0073e9SAndroid Build Coastguard Worker            test_str.append(str(tm.foo.schema))
12895*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedStripMangled("\n".join(test_str) + "\n")
12896*da0073e9SAndroid Build Coastguard Worker
12897*da0073e9SAndroid Build Coastguard Worker    #  Python AST Frontend , Python 3-style type annotations , Script function
12898*da0073e9SAndroid Build Coastguard Worker    def test_annot_ast_py3_fn(self):
12899*da0073e9SAndroid Build Coastguard Worker        code = dedent('''
12900*da0073e9SAndroid Build Coastguard Worker            from typing import Tuple, List, Optional
12901*da0073e9SAndroid Build Coastguard Worker            from torch import Tensor
12902*da0073e9SAndroid Build Coastguard Worker            from torch.jit.annotations import BroadcastingList2, BroadcastingList3
12903*da0073e9SAndroid Build Coastguard Worker            import torch
12904*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
12905*da0073e9SAndroid Build Coastguard Worker            def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
12906*da0073e9SAndroid Build Coastguard Worker                return x, x
12907*da0073e9SAndroid Build Coastguard Worker        ''')
12908*da0073e9SAndroid Build Coastguard Worker        test_str = []
12909*da0073e9SAndroid Build Coastguard Worker        for pair in self.type_input_return_pairs():
12910*da0073e9SAndroid Build Coastguard Worker            fn = jit_utils._get_py3_code(self.format_code(code, pair), 'foo')
12911*da0073e9SAndroid Build Coastguard Worker            test_str.append(str(fn.schema))
12912*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedStripMangled("\n".join(test_str) + "\n")
12913*da0073e9SAndroid Build Coastguard Worker
12914*da0073e9SAndroid Build Coastguard Worker    def test_multiline_annot_ast_py3_fn(self):
12915*da0073e9SAndroid Build Coastguard Worker        code = dedent('''
12916*da0073e9SAndroid Build Coastguard Worker            from typing import Tuple, List, Optional
12917*da0073e9SAndroid Build Coastguard Worker            from torch import Tensor
12918*da0073e9SAndroid Build Coastguard Worker            from torch.jit.annotations import BroadcastingList2, BroadcastingList3
12919*da0073e9SAndroid Build Coastguard Worker            import torch
12920*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
12921*da0073e9SAndroid Build Coastguard Worker            def foo(x,  # type: {input}
12922*da0073e9SAndroid Build Coastguard Worker                    y   # type: Tuple[Tensor, Tensor]
12923*da0073e9SAndroid Build Coastguard Worker                    ):
12924*da0073e9SAndroid Build Coastguard Worker                # type: (...) -> Tuple[{output}, {output}]
12925*da0073e9SAndroid Build Coastguard Worker                return x, x
12926*da0073e9SAndroid Build Coastguard Worker        ''')
12927*da0073e9SAndroid Build Coastguard Worker        test_str = []
12928*da0073e9SAndroid Build Coastguard Worker
12929*da0073e9SAndroid Build Coastguard Worker        for pair in self.type_input_return_pairs():
12930*da0073e9SAndroid Build Coastguard Worker            fn = jit_utils._get_py3_code(self.format_code(code, pair), 'foo')
12931*da0073e9SAndroid Build Coastguard Worker            args = fn.schema.arguments
12932*da0073e9SAndroid Build Coastguard Worker            returns = fn.schema.returns
12933*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(str(args[0].type), pair[1])
12934*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(str(args[1].type), "Tuple[Tensor, Tensor]")
12935*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(str(returns[0].type), f"Tuple[{pair[1]}, {pair[1]}]")
12936*da0073e9SAndroid Build Coastguard Worker
12937*da0073e9SAndroid Build Coastguard Worker    def test_bad_multiline_annotations(self):
12938*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Return type line"):
12939*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
12940*da0073e9SAndroid Build Coastguard Worker            def bad_type_line(a,  # type: Tensor
12941*da0073e9SAndroid Build Coastguard Worker                              b,  # type: Tensor
12942*da0073e9SAndroid Build Coastguard Worker                              c   # type: Tensor
12943*da0073e9SAndroid Build Coastguard Worker                              ):
12944*da0073e9SAndroid Build Coastguard Worker                # type: (int, int, int) -> Tensor
12945*da0073e9SAndroid Build Coastguard Worker                # type: bad type line  # noqa: F723
12946*da0073e9SAndroid Build Coastguard Worker
12947*da0073e9SAndroid Build Coastguard Worker                return a + b + c
12948*da0073e9SAndroid Build Coastguard Worker
12949*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Return type line"):
12950*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
12951*da0073e9SAndroid Build Coastguard Worker            def bad_return_line(a,  # type: Tensor
12952*da0073e9SAndroid Build Coastguard Worker                                b,
12953*da0073e9SAndroid Build Coastguard Worker                                c   # type: Tensor
12954*da0073e9SAndroid Build Coastguard Worker                                ):
12955*da0073e9SAndroid Build Coastguard Worker                # type: (int, int, int) -> Tensor
12956*da0073e9SAndroid Build Coastguard Worker                return a + b + c
12957*da0073e9SAndroid Build Coastguard Worker
12958*da0073e9SAndroid Build Coastguard Worker        # TODO: this should be supported but is difficult to parse
12959*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Number of type annotations"):
12960*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
12961*da0073e9SAndroid Build Coastguard Worker            def missing_type(a,  # type: Tensor
12962*da0073e9SAndroid Build Coastguard Worker                             b,
12963*da0073e9SAndroid Build Coastguard Worker                             c   # type: Tensor
12964*da0073e9SAndroid Build Coastguard Worker                             ):
12965*da0073e9SAndroid Build Coastguard Worker                # type: (...) -> Tensor
12966*da0073e9SAndroid Build Coastguard Worker                return a + b + c
12967*da0073e9SAndroid Build Coastguard Worker
12968*da0073e9SAndroid Build Coastguard Worker    #  Python AST Frontend , Python 3-style type annotations , Script method
12969*da0073e9SAndroid Build Coastguard Worker    def test_annot_ast_py3_method(self):
12970*da0073e9SAndroid Build Coastguard Worker        code = dedent('''
12971*da0073e9SAndroid Build Coastguard Worker            from typing import Tuple, List, Optional
12972*da0073e9SAndroid Build Coastguard Worker            from torch import Tensor
12973*da0073e9SAndroid Build Coastguard Worker            from torch.jit.annotations import BroadcastingList2, \\
12974*da0073e9SAndroid Build Coastguard Worker                BroadcastingList3
12975*da0073e9SAndroid Build Coastguard Worker            import torch
12976*da0073e9SAndroid Build Coastguard Worker            class FooModule(torch.jit.ScriptModule):
12977*da0073e9SAndroid Build Coastguard Worker                @torch.jit.script_method
12978*da0073e9SAndroid Build Coastguard Worker                def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
12979*da0073e9SAndroid Build Coastguard Worker                    return x, x
12980*da0073e9SAndroid Build Coastguard Worker            instance = FooModule()
12981*da0073e9SAndroid Build Coastguard Worker        ''')
12982*da0073e9SAndroid Build Coastguard Worker
12983*da0073e9SAndroid Build Coastguard Worker        test_str = []
12984*da0073e9SAndroid Build Coastguard Worker        for pair in self.type_input_return_pairs():
12985*da0073e9SAndroid Build Coastguard Worker            fn = jit_utils._get_py3_code(self.format_code(code, pair), 'instance')
12986*da0073e9SAndroid Build Coastguard Worker            test_str.append(str(fn.foo.schema))
12987*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedStripMangled("\n".join(test_str) + "\n")
12988*da0073e9SAndroid Build Coastguard Worker
12989*da0073e9SAndroid Build Coastguard Worker    #  Python AST Frontend , MyPy-style type comments , Script function
12990*da0073e9SAndroid Build Coastguard Worker    def test_annot_ast_mypy_fn(self):
12991*da0073e9SAndroid Build Coastguard Worker        code = dedent('''
12992*da0073e9SAndroid Build Coastguard Worker            import torch
12993*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
12994*da0073e9SAndroid Build Coastguard Worker            def foo(x, y):
12995*da0073e9SAndroid Build Coastguard Worker                # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
12996*da0073e9SAndroid Build Coastguard Worker                return x, x
12997*da0073e9SAndroid Build Coastguard Worker        ''')
12998*da0073e9SAndroid Build Coastguard Worker
12999*da0073e9SAndroid Build Coastguard Worker        test_str = []
13000*da0073e9SAndroid Build Coastguard Worker        for pair in self.type_input_return_pairs():
13001*da0073e9SAndroid Build Coastguard Worker            fn = jit_utils._get_py3_code(self.format_code(code, pair), 'foo')
13002*da0073e9SAndroid Build Coastguard Worker            test_str.append(str(fn.schema))
13003*da0073e9SAndroid Build Coastguard Worker        self.assertExpected("\n".join(test_str) + "\n")
13004*da0073e9SAndroid Build Coastguard Worker
13005*da0073e9SAndroid Build Coastguard Worker    #  Python AST Frontend , MyPy-style type comments , Script method
13006*da0073e9SAndroid Build Coastguard Worker    def test_annot_ast_mypy_method(self):
13007*da0073e9SAndroid Build Coastguard Worker        code = dedent('''
13008*da0073e9SAndroid Build Coastguard Worker            import torch
13009*da0073e9SAndroid Build Coastguard Worker            class FooModule(torch.jit.ScriptModule):
13010*da0073e9SAndroid Build Coastguard Worker                @torch.jit.script_method
13011*da0073e9SAndroid Build Coastguard Worker                def foo(self, x, y):
13012*da0073e9SAndroid Build Coastguard Worker                    # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
13013*da0073e9SAndroid Build Coastguard Worker                    return x, x
13014*da0073e9SAndroid Build Coastguard Worker            instance = FooModule()
13015*da0073e9SAndroid Build Coastguard Worker        ''')
13016*da0073e9SAndroid Build Coastguard Worker
13017*da0073e9SAndroid Build Coastguard Worker        test_str = []
13018*da0073e9SAndroid Build Coastguard Worker        for pair in self.type_input_return_pairs():
13019*da0073e9SAndroid Build Coastguard Worker            fn = jit_utils._get_py3_code(self.format_code(code, pair), 'instance')
13020*da0073e9SAndroid Build Coastguard Worker            test_str.append(str(fn.foo.schema))
13021*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedStripMangled("\n".join(test_str) + "\n")
13022*da0073e9SAndroid Build Coastguard Worker
13023*da0073e9SAndroid Build Coastguard Worker    # Tests that "# type: ignore[*]" is supported in type lines and is
13024*da0073e9SAndroid Build Coastguard Worker    # properly ignored.
13025*da0073e9SAndroid Build Coastguard Worker    def test_mypy_type_ignore(self):
13026*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
13027*da0073e9SAndroid Build Coastguard Worker        def foo(x):  # type: ignore
13028*da0073e9SAndroid Build Coastguard Worker            return x
13029*da0073e9SAndroid Build Coastguard Worker
13030*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
13031*da0073e9SAndroid Build Coastguard Worker        def bar(x):  # type: ignore[no-redef]
13032*da0073e9SAndroid Build Coastguard Worker            return x
13033*da0073e9SAndroid Build Coastguard Worker
13034*da0073e9SAndroid Build Coastguard Worker    def test_method_casts_script(self):
13035*da0073e9SAndroid Build Coastguard Worker        cast_types = [
13036*da0073e9SAndroid Build Coastguard Worker            'byte', 'char', 'double', 'float', 'int', 'long', 'short'
13037*da0073e9SAndroid Build Coastguard Worker        ]
13038*da0073e9SAndroid Build Coastguard Worker
13039*da0073e9SAndroid Build Coastguard Worker        for cast_type in cast_types:
13040*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit(f'''
13041*da0073e9SAndroid Build Coastguard Worker            def cast_to(x):
13042*da0073e9SAndroid Build Coastguard Worker                return x.{cast_type}()
13043*da0073e9SAndroid Build Coastguard Worker            ''')
13044*da0073e9SAndroid Build Coastguard Worker
13045*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(3, 4, 5) * 128
13046*da0073e9SAndroid Build Coastguard Worker            cu_result = cu.cast_to(x)
13047*da0073e9SAndroid Build Coastguard Worker            reference = getattr(x, cast_type)()
13048*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cu_result, reference)
13049*da0073e9SAndroid Build Coastguard Worker
    def test_string_frontend_elif(self):
        # Compile an if / elif / elif / else chain from source text (the string
        # frontend) rather than from a Python function object. The literal is
        # indented, so it is dedented before being passed to checkScript with
        # the single argument niter=101.
        code = '''
            def func(niter):
                # type: (int)
                rv = 0
                for i in range(niter):
                    if i % 3 == 0 and i % 5 == 0:
                        rv += 35
                    elif i % 3 == 0:
                        rv += 3
                    elif i % 5 == 0:
                        rv += 5
                    else:
                        rv += i
                return rv
        '''

        self.checkScript(dedent(code), (101,))
13068*da0073e9SAndroid Build Coastguard Worker
    def test_module_parameters_and_buffers(self):
        # A ScriptModule holding two instances of a plain nn.Module (each with
        # two Parameters and one buffer) must match an equivalent eager
        # computation built from nn.Linear, and must pass
        # assertExportImportModule.

        # Fixed weights/biases copied into both the scripted module and the
        # eager reference below, so both compute with identical parameters.
        weights = torch.randn(10, 10)
        bias = torch.randn(10)
        weights2 = torch.randn(10, 10)
        bias2 = torch.randn(10)

        # Hand-written linear layer whose output additionally adds `counter`,
        # a buffer of ones with one entry per output feature.
        class TestLinear(torch.nn.Module):
            def __init__(self, in_features, out_features):
                super().__init__()
                self.in_features = in_features
                self.out_features = out_features
                self.weight = torch.nn.Parameter(torch.empty(out_features, in_features))
                self.bias = torch.nn.Parameter(torch.empty(out_features))
                self.counter = nn.Buffer(torch.ones(out_features))
                self.reset_parameters()

            # Kaiming-uniform weight / uniform bias initialisation. Strong below
            # replaces both Parameters afterwards, so these values never reach
            # the compared outputs.
            def reset_parameters(self):
                torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
                if self.bias is not None:
                    fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
                    bound = 1 / math.sqrt(fan_in)
                    torch.nn.init.uniform_(self.bias, -bound, bound)

            def forward(self, input):
                return F.linear(input, self.weight, self.bias) + self.counter

        # Initialize a ScriptModule that uses the weak module above multiple times
        # (fc1 is applied twice and fc2 once in forward).
        class Strong(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = TestLinear(10, 10)
                self.fc1.weight = torch.nn.Parameter(weights)
                self.fc1.bias = torch.nn.Parameter(bias)
                self.fc2 = TestLinear(10, 10)
                self.fc2.weight = torch.nn.Parameter(weights2)
                self.fc2.bias = torch.nn.Parameter(bias2)

            @torch.jit.script_method
            def forward(self, x):
                return x + self.fc1(x) + self.fc1(x) + self.fc2(x)

        strong_mod = Strong()

        # Run same calculation as module: each TestLinear call equals nn.Linear
        # with the same weights plus torch.ones(10) for the counter buffer, and
        # fc1 contributes twice.
        inp = torch.ones(10)
        lin = torch.nn.Linear(10, 10)
        lin.weight = torch.nn.Parameter(weights)
        lin.bias = torch.nn.Parameter(bias)
        lin2 = torch.nn.Linear(10, 10)
        lin2.weight = torch.nn.Parameter(weights2)
        lin2.bias = torch.nn.Parameter(bias2)
        expected_result = inp + (lin(inp) + torch.ones(10)) * 2 + lin2(inp) + torch.ones(10)

        self.assertEqual(strong_mod(inp), expected_result)
        # Helper defined outside this chunk; presumably round-trips strong_mod
        # through serialization and re-checks it on (inp,).
        self.assertExportImportModule(strong_mod, (inp,))
13124*da0073e9SAndroid Build Coastguard Worker
    def test_module_copying(self):
        # Assigning a plain nn.Module as an attribute of a ScriptModule gives
        # the ScriptModule a scripted counterpart of it. The checks below show
        # that this counterpart shares the original's Parameter and buffer
        # objects (identity), while its submodule is a separate, recursively
        # scripted object.
        class Submodule(torch.nn.Module):
            def forward(self, x):
                return x + 100

        # Plain module with weight/bias Parameters, a buffer and a submodule.
        class Weak(torch.nn.Module):
            def __init__(self, in_features, out_features):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.ones(out_features, in_features))
                self.bias = torch.nn.Parameter(torch.ones(out_features))
                self.buffer = nn.Buffer(torch.ones(out_features))
                self.submodule = Submodule()

            def forward(self, x):
                return F.linear(x, self.weight, self.bias) \
                    + self.buffer + self.submodule(x)

        # ScriptModule that stores the given module and forwards to it.
        class Strong(torch.jit.ScriptModule):
            def __init__(self, weak):
                super().__init__()
                self.weak = weak

            @torch.jit.script_method
            def forward(self, x):
                return self.weak(x)

        inp = torch.ones(5, 5) * 5
        weak_mod = Weak(5, 5)
        strong_mod = Strong(weak_mod)

        # The stored attribute was converted; the original object was not.
        self.assertTrue(isinstance(strong_mod.weak, torch.jit.ScriptModule))
        self.assertFalse(isinstance(weak_mod, torch.jit.ScriptModule))

        # Parameters and buffers are the very same tensor objects.
        self.assertIs(strong_mod.weak.weight, weak_mod.weight)
        self.assertIs(strong_mod.weak.buffer, weak_mod.buffer)
        # strong_mod.weak.submodule has been recursively scripted
        self.assertIsNot(strong_mod.weak.submodule, weak_mod.submodule)

        # Because the Parameter is shared, an in-place update through .data is
        # also visible to the scripted module.
        weak_mod.weight.data += torch.ones(5, 5) * 100
        self.assertTrue(strong_mod(inp).allclose(weak_mod(inp)))

        # Re-assignment is not tracked: binding a new Parameter to
        # weak_mod.weight leaves the scripted copy holding the old one, so the
        # outputs diverge.
        weak_mod.weight = torch.nn.Parameter(torch.ones(5, 5) * 100)
        self.assertFalse(strong_mod(inp).allclose(weak_mod(inp)))
13169*da0073e9SAndroid Build Coastguard Worker
    def test_backend_cudnn_enabled(self):
        # Only test that this compiles: torch.backends.cudnn.enabled is used as
        # an `if` condition inside a scripted function. fn is never called, so
        # neither branch's result is checked.
        @torch.jit.script
        def fn(x):
            if torch.backends.cudnn.enabled:
                x = x + 2
            else:
                x = x + 3
            return x
13179*da0073e9SAndroid Build Coastguard Worker
    def test_inplace_add(self):
        # In-place Tensor.add_ on a freshly computed (non-input) tensor `c`,
        # checked with checkScript on two random 3-element tensors.

        def foo(a, b):
            c = a + b
            c.add_(b)
            return c
        self.checkScript(foo, (torch.rand(3), torch.rand(3)))
13187*da0073e9SAndroid Build Coastguard Worker
    def test_add_out(self):
        # torch.add with an explicit out= tensor: c + b is written into `e`
        # (which previously held 2 * a), and `e` is returned.
        def foo(a, b):
            c = a + b
            e = 2 * a
            torch.add(c, b, out=e)
            return e
        self.checkScript(foo, (torch.rand(3), torch.rand(3)))
13195*da0073e9SAndroid Build Coastguard Worker
13196*da0073e9SAndroid Build Coastguard Worker    def test_tuple_error_msg(self):
13197*da0073e9SAndroid Build Coastguard Worker        def fn(t: Any):
13198*da0073e9SAndroid Build Coastguard Worker            if isinstance(t, tuple):
13199*da0073e9SAndroid Build Coastguard Worker                a, b = t
13200*da0073e9SAndroid Build Coastguard Worker            return a + b
13201*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegexWithHighlight(RuntimeError, "Provided tuple is not fully defined/refined", "t"):
13202*da0073e9SAndroid Build Coastguard Worker            s = torch.jit.script(fn)
13203*da0073e9SAndroid Build Coastguard Worker
13204*da0073e9SAndroid Build Coastguard Worker    def test_augmented_assign(self):
13205*da0073e9SAndroid Build Coastguard Worker        def foo(a, b):
13206*da0073e9SAndroid Build Coastguard Worker            a += b
13207*da0073e9SAndroid Build Coastguard Worker            a -= b
13208*da0073e9SAndroid Build Coastguard Worker            a /= b
13209*da0073e9SAndroid Build Coastguard Worker            a *= b
13210*da0073e9SAndroid Build Coastguard Worker            return a, b
13211*da0073e9SAndroid Build Coastguard Worker        self.checkScript(foo, (torch.rand(3), torch.rand(3)))
13212*da0073e9SAndroid Build Coastguard Worker
13213*da0073e9SAndroid Build Coastguard Worker    def test_ignored_props(self):
13214*da0073e9SAndroid Build Coastguard Worker        class A(nn.Module):
13215*da0073e9SAndroid Build Coastguard Worker            __jit_ignored_attributes__ = ["ignored", "ignored_return_val"]
13216*da0073e9SAndroid Build Coastguard Worker
13217*da0073e9SAndroid Build Coastguard Worker            @property
13218*da0073e9SAndroid Build Coastguard Worker            def ignored(self):
13219*da0073e9SAndroid Build Coastguard Worker                raise ValueError("shouldn't be called")
13220*da0073e9SAndroid Build Coastguard Worker
13221*da0073e9SAndroid Build Coastguard Worker            @property
13222*da0073e9SAndroid Build Coastguard Worker            def ignored_return_val(self):
13223*da0073e9SAndroid Build Coastguard Worker                return 1
13224*da0073e9SAndroid Build Coastguard Worker
13225*da0073e9SAndroid Build Coastguard Worker            @torch.jit.ignore
13226*da0073e9SAndroid Build Coastguard Worker            def call(self):
13227*da0073e9SAndroid Build Coastguard Worker                return self.ignored_return_val
13228*da0073e9SAndroid Build Coastguard Worker
13229*da0073e9SAndroid Build Coastguard Worker        f = torch.jit.script(A())
13230*da0073e9SAndroid Build Coastguard Worker        # jank way to test if there is no error
13231*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(f, torch.jit.ScriptModule))
13232*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(f.call(), property))
13233*da0073e9SAndroid Build Coastguard Worker
13234*da0073e9SAndroid Build Coastguard Worker
    def test_pass(self):
        # `pass` as the body of a for loop and of both if/else branches must
        # compile; foo always returns 3 regardless of x.
        def foo(x):
            # type: (bool) -> int
            for _i in range(3):
                pass
            if x:
                pass
            else:
                pass
            return 3

        self.checkScript(foo, (True,))
13247*da0073e9SAndroid Build Coastguard Worker
    def test_lhs_indexing(self):
        # Subscript assignment a[0] = b on a clone of the input: b, of shape
        # (3,), overwrites row 0 of a, of shape (2, 3).
        def foo(a, b):
            a = a.clone()
            a[0] = b
            return a
        self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
13254*da0073e9SAndroid Build Coastguard Worker
    def test_lhs_advanced_indexing_assignment(self):
        # Boolean-mask indexing on the left-hand side: `b` is the mask x == 1
        # (all True for the all-ones inputs used here), and the masked entries
        # of exp(x) are overwritten with the corresponding entries of y.
        def foo(x, y):
            a = torch.exp(x)
            b = x == 1
            a[b] = y[b]
            return a
        self.checkScript(foo, (torch.ones(4, 3), torch.ones(4, 3)))
13262*da0073e9SAndroid Build Coastguard Worker
    def test_lhs_advanced_indexing_augmented_assignment(self):
        # Augmented assignment through a boolean mask, a[b] += y[b], where
        # b = (x == 1) is all True for the all-ones inputs. `a` is a new tensor
        # (exp(x)), so the inputs themselves are not mutated.
        def foo(x, y):
            a = torch.exp(x)
            b = x == 1
            a[b] += y[b]
            return a
        self.checkScript(foo, (torch.ones(4, 3), torch.ones(4, 3)))
13270*da0073e9SAndroid Build Coastguard Worker
    def test_lhs_indexing_list(self):
        # Item assignment on a list inside TorchScript: the list's only
        # element (a) is replaced by b, and the list is returned.
        def foo(a, b):
            ls = [a]
            ls[0] = b
            return ls
        self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
13277*da0073e9SAndroid Build Coastguard Worker
    def test_inplace_copy_script(self):
        # Tensor.copy_ into a locally created tensor. `a` and x are both
        # (3, 4), so a's random initial contents are fully overwritten and the
        # result does not depend on torch.rand.
        def foo(x):
            a = torch.rand(3, 4)
            a.copy_(x)
            return a
        self.checkScript(foo, (torch.rand(3, 4),))
13284*da0073e9SAndroid Build Coastguard Worker
13285*da0073e9SAndroid Build Coastguard Worker    def test_lhs_indexing_increment(self):
13286*da0073e9SAndroid Build Coastguard Worker        def foo(a, b):
13287*da0073e9SAndroid Build Coastguard Worker            a[0] += b
13288*da0073e9SAndroid Build Coastguard Worker            return a
13289*da0073e9SAndroid Build Coastguard Worker        self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
13290*da0073e9SAndroid Build Coastguard Worker
    def test_lhs_indexing_increment_list(self):
        # Augmented assignment through a list subscript (ls[0] += b) where the
        # element is a tensor. `a` is cloned first so the input is not mutated;
        # b, of shape (3,), is broadcast over a, of shape (2, 3).
        def foo(a, b):
            a = a.clone()
            ls = [a, b]
            ls[0] += b
            return ls
        self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
13298*da0073e9SAndroid Build Coastguard Worker
    def test_lhs_indexing_increment_list_prim(self):
        # Augmented assignment through a list subscript on a list of ints
        # (ls[0] += 5); foo takes no inputs.
        def foo():
            ls = [1, 2, 3]
            ls[0] += 5
            return ls
        self.checkScript(foo, ())
13305*da0073e9SAndroid Build Coastguard Worker
    def test_lhs_indexing_multi(self):
        # Tuple-unpacking assignment whose middle target is a subscript (a[0]).
        # The other targets are plain locals; `foo` shadows the function's own
        # name inside its body.
        def foo(a, b):
            a = a.clone()
            foo, a[0], bar = (1, b, 3)
            return foo, a, bar
        self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
13312*da0073e9SAndroid Build Coastguard Worker
13313*da0073e9SAndroid Build Coastguard Worker    def test_bool_dispatch(self):
13314*da0073e9SAndroid Build Coastguard Worker        with torch._jit_internal._disable_emit_hooks():  # TODO: Python print broadcasting list
13315*da0073e9SAndroid Build Coastguard Worker            def kwarg_false(x):
13316*da0073e9SAndroid Build Coastguard Worker                # type: (Tensor) -> Tensor
13317*da0073e9SAndroid Build Coastguard Worker                return F.max_pool1d(x, 1, 1, return_indices=False)
13318*da0073e9SAndroid Build Coastguard Worker            self.checkScript(kwarg_false, (torch.randn(3, 3, 3),))
13319*da0073e9SAndroid Build Coastguard Worker
13320*da0073e9SAndroid Build Coastguard Worker            def kwarg_true(x):
13321*da0073e9SAndroid Build Coastguard Worker                # type: (Tensor) -> Tuple[Tensor, Tensor]
13322*da0073e9SAndroid Build Coastguard Worker                return F.max_pool1d(x, 1, 1, return_indices=True)
13323*da0073e9SAndroid Build Coastguard Worker            self.checkScript(kwarg_true, (torch.randn(3, 3, 3),))
13324*da0073e9SAndroid Build Coastguard Worker
13325*da0073e9SAndroid Build Coastguard Worker            def full_kwarg_false(x):
13326*da0073e9SAndroid Build Coastguard Worker                # type: (Tensor) -> Tensor
13327*da0073e9SAndroid Build Coastguard Worker                return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=False)
13328*da0073e9SAndroid Build Coastguard Worker            self.checkScript(full_kwarg_false, (torch.randn(3, 3, 3),))
13329*da0073e9SAndroid Build Coastguard Worker
13330*da0073e9SAndroid Build Coastguard Worker            def full_kwarg_true(x):
13331*da0073e9SAndroid Build Coastguard Worker                # type: (Tensor) -> Tuple[Tensor, Tensor]
13332*da0073e9SAndroid Build Coastguard Worker                return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=True)
13333*da0073e9SAndroid Build Coastguard Worker            self.checkScript(full_kwarg_true, (torch.randn(3, 3, 3),))
13334*da0073e9SAndroid Build Coastguard Worker
13335*da0073e9SAndroid Build Coastguard Worker            def use_default(x):
13336*da0073e9SAndroid Build Coastguard Worker                # type: (Tensor) -> Tensor
13337*da0073e9SAndroid Build Coastguard Worker                return F.max_pool1d(x, 1, 1)
13338*da0073e9SAndroid Build Coastguard Worker            self.checkScript(use_default, (torch.randn(3, 3, 3),))
13339*da0073e9SAndroid Build Coastguard Worker
13340*da0073e9SAndroid Build Coastguard Worker            def arg_false(x):
13341*da0073e9SAndroid Build Coastguard Worker                # type: (Tensor) -> Tensor
13342*da0073e9SAndroid Build Coastguard Worker                return F.max_pool1d(x, 1, 1, 0, 1, False, False)
13343*da0073e9SAndroid Build Coastguard Worker            self.checkScript(arg_false, (torch.randn(3, 3, 3),))
13344*da0073e9SAndroid Build Coastguard Worker
13345*da0073e9SAndroid Build Coastguard Worker            def arg_true(x):
13346*da0073e9SAndroid Build Coastguard Worker                # type: (Tensor) -> Tuple[Tensor, Tensor]
13347*da0073e9SAndroid Build Coastguard Worker                return F.max_pool1d(x, 1, 1, 0, 1, False, True)
13348*da0073e9SAndroid Build Coastguard Worker            self.checkScript(arg_true, (torch.randn(3, 3, 3),))
13349*da0073e9SAndroid Build Coastguard Worker
    def test_infer_size(self):
        # torch._C._infer_size (presumably the broadcast of two sizes) must be
        # callable from a scripted function; here it gets two identical
        # (2, 4, 2) sizes.
        from torch._C import _infer_size

        def fn(x, y):
            # type: (Tensor, Tensor) -> List[int]
            return _infer_size(x.size(), y.size())

        self.checkScript(fn, (torch.ones(2, 4, 2), torch.ones(2, 4, 2)))
13358*da0073e9SAndroid Build Coastguard Worker
13359*da0073e9SAndroid Build Coastguard Worker    def test_hash(self):
13360*da0073e9SAndroid Build Coastguard Worker        def tester(fn, inputs):
13361*da0073e9SAndroid Build Coastguard Worker            for x in inputs:
13362*da0073e9SAndroid Build Coastguard Worker                for y in inputs:
13363*da0073e9SAndroid Build Coastguard Worker                    if x == y:
13364*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(fn(x), fn(y))
13365*da0073e9SAndroid Build Coastguard Worker                    else:
13366*da0073e9SAndroid Build Coastguard Worker                        self.assertNotEqual(fn(x), fn(y))
13367*da0073e9SAndroid Build Coastguard Worker
13368*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
13369*da0073e9SAndroid Build Coastguard Worker        def int_hash(x):
13370*da0073e9SAndroid Build Coastguard Worker            # type: (int) -> int
13371*da0073e9SAndroid Build Coastguard Worker            return hash(x)
13372*da0073e9SAndroid Build Coastguard Worker
13373*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
13374*da0073e9SAndroid Build Coastguard Worker        def float_hash(x):
13375*da0073e9SAndroid Build Coastguard Worker            # type: (float) -> int
13376*da0073e9SAndroid Build Coastguard Worker            return hash(x)
13377*da0073e9SAndroid Build Coastguard Worker
13378*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
13379*da0073e9SAndroid Build Coastguard Worker        def str_hash(x):
13380*da0073e9SAndroid Build Coastguard Worker            # type: (str) -> int
13381*da0073e9SAndroid Build Coastguard Worker            return hash(x)
13382*da0073e9SAndroid Build Coastguard Worker
13383*da0073e9SAndroid Build Coastguard Worker        tester(int_hash, (20, 21, 22))
13384*da0073e9SAndroid Build Coastguard Worker        tester(float_hash, (20.0, 21.00001, 22.443))
13385*da0073e9SAndroid Build Coastguard Worker        tester(str_hash, ("", "hello", "a"))
13386*da0073e9SAndroid Build Coastguard Worker
    def test_id(self):
        # Scripting a function that calls id() on an int literal and on None
        # must fail with an error matching "Expected a value".
        with self.assertRaisesRegex(RuntimeError, "Expected a value"):
            @torch.jit.script
            def test_id_scalars():
                return id(2) == id(None)

        # Minimal TorchScript class, used only to create two distinct objects.
        @torch.jit.script
        class FooTest:
            def __init__(self, x):
                self.foo = x

            def getFooTest(self):
                return self.foo

        # id() on class instances: distinct objects have distinct ids, and an
        # object's id differs from id(None). Returns True only if every assert
        # inside passes.
        @torch.jit.script
        def test_id_class_types():
            obj1 = FooTest(torch.tensor(3))
            obj2 = FooTest(torch.tensor(2))
            assert obj1 is not obj2
            assert id(obj1) != id(obj2)
            assert id(obj1) != id(None)
            return True

        self.assertTrue(test_id_class_types())
13411*da0073e9SAndroid Build Coastguard Worker
    def test_mutable_dce(self):
        # Dead-code elimination with mutation: `a` is mutated and returned,
        # while `b` is mutated but never used. The graph must keep exactly two
        # aten::rand calls (a and its addend) and one aten::add.
        @torch.jit.script
        def foo():
            a = torch.rand(2, 3)
            a += torch.rand(2, 3)
            b = torch.rand(2, 3)
            b += torch.rand(2, 3)
            # b should be cleaned up but not a
            return a

        FileCheck().check_count("aten::rand", 2, exactly=True) \
            .check_count("aten::add", 1, exactly=True).run(str(foo.graph))
13424*da0073e9SAndroid Build Coastguard Worker
    def test_mutable_dce_block(self):
        # Dead-code elimination inside an if-block. After prim::If, exactly one
        # aten::rand must remain: the in-branch addend for `b` (returned),
        # while the in-branch update of `a` (not returned) is removed. foo is
        # only compiled here, never executed.
        @torch.jit.script
        def foo():
            a = torch.rand(2, 3)
            a += torch.rand(2, 3)
            b = torch.rand(2, 3)
            if bool(a > torch.zeros(2, 3)):
                b += torch.rand(2, 3)
                a += torch.rand(2, 3)
            # a should be cleaned up but not b
            return b

        FileCheck().check("prim::If").check_count("aten::rand", 1, exactly=True) \
            .run(str(foo.graph))
13439*da0073e9SAndroid Build Coastguard Worker
    def test_mutable_dce_graph_input(self):
        # Mutating a graph input is an observable side effect, so the rand and
        # add must survive even though foo returns nothing.
        @torch.jit.script
        def foo(a):
            a += torch.rand(2, 3)
            # shouldn't clean up `a` even though it's not used in the output

        FileCheck().check("aten::rand").check("aten::add").run(str(foo.graph))
13447*da0073e9SAndroid Build Coastguard Worker
    def test_mutable_dce_list(self):
        # `c` is read back out of a list that holds the input `a`, then
        # mutated. Although c is not returned, its update must survive: the
        # graph keeps exactly two aten::rand calls (b and c's addend).
        @torch.jit.script
        def foo(a):
            l = []
            l.append(a)
            c = l[0]
            b = torch.rand(2, 3)
            c += torch.rand(2, 3)
            return b

        # c does not get cleaned up because there is a wildcard + mutation
        FileCheck().check_count("aten::rand", 2, exactly=True).run(str(foo.graph))
13460*da0073e9SAndroid Build Coastguard Worker
13461*da0073e9SAndroid Build Coastguard Worker    def test_mutable_dce_loop(self):
13462*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
13463*da0073e9SAndroid Build Coastguard Worker        def foo(a):
13464*da0073e9SAndroid Build Coastguard Worker            l = []
13465*da0073e9SAndroid Build Coastguard Worker            l.append(a)
13466*da0073e9SAndroid Build Coastguard Worker            i = 0
13467*da0073e9SAndroid Build Coastguard Worker            b = torch.rand(2, 3)
13468*da0073e9SAndroid Build Coastguard Worker            while i < 1:
13469*da0073e9SAndroid Build Coastguard Worker                dead = torch.rand(2, 3)
13470*da0073e9SAndroid Build Coastguard Worker                c = l[0]
13471*da0073e9SAndroid Build Coastguard Worker                c += torch.rand(2, 3)
13472*da0073e9SAndroid Build Coastguard Worker                i += 1
13473*da0073e9SAndroid Build Coastguard Worker            return b
13474*da0073e9SAndroid Build Coastguard Worker
13475*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("prim::Loop").check_not("aten::rand").check("aten::__getitem__") \
13476*da0073e9SAndroid Build Coastguard Worker            .check_count("aten::rand", 1, exactly=True).run(str(foo.graph))
13477*da0073e9SAndroid Build Coastguard Worker
13478*da0073e9SAndroid Build Coastguard Worker    def test_mutable_dce_indirect_wildcards(self):
13479*da0073e9SAndroid Build Coastguard Worker        def fn():
13480*da0073e9SAndroid Build Coastguard Worker            x = torch.ones(2, 3)
13481*da0073e9SAndroid Build Coastguard Worker            x_1 = x.view(-1)
13482*da0073e9SAndroid Build Coastguard Worker            l = []
13483*da0073e9SAndroid Build Coastguard Worker            l.append(x_1)
13484*da0073e9SAndroid Build Coastguard Worker            x_view = l[0]
13485*da0073e9SAndroid Build Coastguard Worker            x.add_(torch.ones(2, 3))
13486*da0073e9SAndroid Build Coastguard Worker            return x_view
13487*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ())
13488*da0073e9SAndroid Build Coastguard Worker
13489*da0073e9SAndroid Build Coastguard Worker    def test_mutable_dce_indirect_wildcard_write(self):
13490*da0073e9SAndroid Build Coastguard Worker        def fn():
13491*da0073e9SAndroid Build Coastguard Worker            indexes = torch.jit.annotate(List[Tensor], [])
13492*da0073e9SAndroid Build Coastguard Worker            word_ids = torch.zeros(10, dtype=torch.int32)
13493*da0073e9SAndroid Build Coastguard Worker            word_ids[1] = 1
13494*da0073e9SAndroid Build Coastguard Worker            indexes.append(word_ids)
13495*da0073e9SAndroid Build Coastguard Worker
13496*da0073e9SAndroid Build Coastguard Worker            return word_ids
13497*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ())
13498*da0073e9SAndroid Build Coastguard Worker
13499*da0073e9SAndroid Build Coastguard Worker    def test_mutable_dce_wildcards(self):
13500*da0073e9SAndroid Build Coastguard Worker        def fn():
13501*da0073e9SAndroid Build Coastguard Worker            x = torch.ones(2, 3)
13502*da0073e9SAndroid Build Coastguard Worker            l = []
13503*da0073e9SAndroid Build Coastguard Worker            l.append(x)
13504*da0073e9SAndroid Build Coastguard Worker            x_view = l[0]
13505*da0073e9SAndroid Build Coastguard Worker            x.add_(torch.ones(2, 3))
13506*da0073e9SAndroid Build Coastguard Worker            return x_view
13507*da0073e9SAndroid Build Coastguard Worker
13508*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (), profiling=ProfilingMode.SIMPLE)
13509*da0073e9SAndroid Build Coastguard Worker
13510*da0073e9SAndroid Build Coastguard Worker    def test_cpp_function_tensor_str(self):
13511*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2)
13512*da0073e9SAndroid Build Coastguard Worker        scale = torch.randn(2, 2, requires_grad=True)
13513*da0073e9SAndroid Build Coastguard Worker        shift = torch.randn(2, 2, requires_grad=True)
13514*da0073e9SAndroid Build Coastguard Worker
13515*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
13516*da0073e9SAndroid Build Coastguard Worker        def fn(x, scale, shift):
13517*da0073e9SAndroid Build Coastguard Worker            return scale * x + shift
13518*da0073e9SAndroid Build Coastguard Worker
13519*da0073e9SAndroid Build Coastguard Worker        with self.capture_stdout() as captured:
13520*da0073e9SAndroid Build Coastguard Worker            print(fn(x, scale, shift))
13521*da0073e9SAndroid Build Coastguard Worker
13522*da0073e9SAndroid Build Coastguard Worker    def test_string_index(self):
13523*da0073e9SAndroid Build Coastguard Worker        def fn(x):
13524*da0073e9SAndroid Build Coastguard Worker            # type: (str)
13525*da0073e9SAndroid Build Coastguard Worker            return x[2], x[-1]
13526*da0073e9SAndroid Build Coastguard Worker
13527*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ("abcde",))
13528*da0073e9SAndroid Build Coastguard Worker
13529*da0073e9SAndroid Build Coastguard Worker    def test_ord(self):
13530*da0073e9SAndroid Build Coastguard Worker        def fn(x):
13531*da0073e9SAndroid Build Coastguard Worker            # type: (str) -> int
13532*da0073e9SAndroid Build Coastguard Worker            return ord(x)
13533*da0073e9SAndroid Build Coastguard Worker
13534*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ("h"))
13535*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ("y"))
13536*da0073e9SAndroid Build Coastguard Worker
13537*da0073e9SAndroid Build Coastguard Worker        def index_str_to_tensor(s):
13538*da0073e9SAndroid Build Coastguard Worker            # type: (str) -> Tensor
13539*da0073e9SAndroid Build Coastguard Worker            return torch.tensor(ord(s))  # noqa: T484
13540*da0073e9SAndroid Build Coastguard Worker
13541*da0073e9SAndroid Build Coastguard Worker        s = '\u00a3'.encode()[:1]
13542*da0073e9SAndroid Build Coastguard Worker        self.checkScript(index_str_to_tensor, (s,))
13543*da0073e9SAndroid Build Coastguard Worker
13544*da0073e9SAndroid Build Coastguard Worker    def test_chr(self):
13545*da0073e9SAndroid Build Coastguard Worker        def fn(x):
13546*da0073e9SAndroid Build Coastguard Worker            # type: (int) -> str
13547*da0073e9SAndroid Build Coastguard Worker            return chr(x)
13548*da0073e9SAndroid Build Coastguard Worker
13549*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (1,))
13550*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (97,))
13551*da0073e9SAndroid Build Coastguard Worker
13552*da0073e9SAndroid Build Coastguard Worker    def test_round(self):
13553*da0073e9SAndroid Build Coastguard Worker        def round_float(x):
13554*da0073e9SAndroid Build Coastguard Worker            # type: (float) -> float
13555*da0073e9SAndroid Build Coastguard Worker            return round(x)
13556*da0073e9SAndroid Build Coastguard Worker
13557*da0073e9SAndroid Build Coastguard Worker        def round_int(x):
13558*da0073e9SAndroid Build Coastguard Worker            # type: (int) -> float
13559*da0073e9SAndroid Build Coastguard Worker            return round(x)
13560*da0073e9SAndroid Build Coastguard Worker
13561*da0073e9SAndroid Build Coastguard Worker        self.checkScript(round_float, (1.5,))
13562*da0073e9SAndroid Build Coastguard Worker        self.checkScript(round_int, (2,))
13563*da0073e9SAndroid Build Coastguard Worker
13564*da0073e9SAndroid Build Coastguard Worker    def test_convert_base(self):
13565*da0073e9SAndroid Build Coastguard Worker        def test_hex(x):
13566*da0073e9SAndroid Build Coastguard Worker            # type: (int) -> str
13567*da0073e9SAndroid Build Coastguard Worker            return hex(x)
13568*da0073e9SAndroid Build Coastguard Worker
13569*da0073e9SAndroid Build Coastguard Worker        def test_oct(x):
13570*da0073e9SAndroid Build Coastguard Worker            # type: (int) -> str
13571*da0073e9SAndroid Build Coastguard Worker            return oct(x)
13572*da0073e9SAndroid Build Coastguard Worker
13573*da0073e9SAndroid Build Coastguard Worker        def test_bin(x):
13574*da0073e9SAndroid Build Coastguard Worker            # type: (int) -> str
13575*da0073e9SAndroid Build Coastguard Worker            return bin(x)
13576*da0073e9SAndroid Build Coastguard Worker
13577*da0073e9SAndroid Build Coastguard Worker        numbers = [-1000, -10, 0, 1, 10, 2343]
13578*da0073e9SAndroid Build Coastguard Worker        for n in numbers:
13579*da0073e9SAndroid Build Coastguard Worker            self.checkScript(test_bin, (n,))
13580*da0073e9SAndroid Build Coastguard Worker            self.checkScript(test_oct, (n,))
13581*da0073e9SAndroid Build Coastguard Worker            self.checkScript(test_hex, (n,))
13582*da0073e9SAndroid Build Coastguard Worker
13583*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle")
13584*da0073e9SAndroid Build Coastguard Worker    def test_get_set_state(self):
13585*da0073e9SAndroid Build Coastguard Worker        class Root(torch.jit.ScriptModule):
13586*da0073e9SAndroid Build Coastguard Worker            __constants__ = ['number']
13587*da0073e9SAndroid Build Coastguard Worker
13588*da0073e9SAndroid Build Coastguard Worker            def __init__(self, number):
13589*da0073e9SAndroid Build Coastguard Worker                super().__init__()
13590*da0073e9SAndroid Build Coastguard Worker                self.buffer1 = nn.Buffer(torch.ones(2, 2))
13591*da0073e9SAndroid Build Coastguard Worker                self.buffer2 = nn.Buffer(torch.ones(2, 2))
13592*da0073e9SAndroid Build Coastguard Worker                self.number = number
13593*da0073e9SAndroid Build Coastguard Worker
13594*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
13595*da0073e9SAndroid Build Coastguard Worker            def __getstate__(self):
13596*da0073e9SAndroid Build Coastguard Worker                return (self.buffer1, self.buffer2, 74, self.training)
13597*da0073e9SAndroid Build Coastguard Worker
13598*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
13599*da0073e9SAndroid Build Coastguard Worker            def __setstate__(self, state):
13600*da0073e9SAndroid Build Coastguard Worker                self.buffer1 = state[0] + 10
13601*da0073e9SAndroid Build Coastguard Worker                self.buffer2 = state[1] + 10
13602*da0073e9SAndroid Build Coastguard Worker                self.training = state[3]
13603*da0073e9SAndroid Build Coastguard Worker
13604*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
13605*da0073e9SAndroid Build Coastguard Worker            __constants__ = ['number']
13606*da0073e9SAndroid Build Coastguard Worker
13607*da0073e9SAndroid Build Coastguard Worker            def __init__(self, number, submodule):
13608*da0073e9SAndroid Build Coastguard Worker                super().__init__()
13609*da0073e9SAndroid Build Coastguard Worker                self.buffer1 = nn.Buffer(torch.ones(2, 2))
13610*da0073e9SAndroid Build Coastguard Worker                self.buffer2 = nn.Buffer(torch.ones(2, 2))
13611*da0073e9SAndroid Build Coastguard Worker                self.number = number
13612*da0073e9SAndroid Build Coastguard Worker                self.submodule = submodule
13613*da0073e9SAndroid Build Coastguard Worker
13614*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
13615*da0073e9SAndroid Build Coastguard Worker            def __getstate__(self):
13616*da0073e9SAndroid Build Coastguard Worker                return (self.buffer1, self.buffer2, 74, self.submodule, self.training)
13617*da0073e9SAndroid Build Coastguard Worker
13618*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
13619*da0073e9SAndroid Build Coastguard Worker            def __setstate__(self, state):
13620*da0073e9SAndroid Build Coastguard Worker                self.buffer1 = state[0] + 10
13621*da0073e9SAndroid Build Coastguard Worker                self.buffer2 = state[1] + 10
13622*da0073e9SAndroid Build Coastguard Worker                self.submodule = state[3]
13623*da0073e9SAndroid Build Coastguard Worker                self.training = state[4]
13624*da0073e9SAndroid Build Coastguard Worker
13625*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName() as fname:
13626*da0073e9SAndroid Build Coastguard Worker            m = M(23, submodule=Root(99))
13627*da0073e9SAndroid Build Coastguard Worker            m.save(fname)
13628*da0073e9SAndroid Build Coastguard Worker            loaded = torch.jit.load(fname)
13629*da0073e9SAndroid Build Coastguard Worker
13630*da0073e9SAndroid Build Coastguard Worker        # Check original module
13631*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m.buffer1, torch.ones(2, 2))
13632*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m.buffer2, torch.ones(2, 2))
13633*da0073e9SAndroid Build Coastguard Worker
13634*da0073e9SAndroid Build Coastguard Worker        # Check top level module
13635*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(loaded.buffer1, torch.ones(2, 2) + 10)
13636*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(loaded.buffer2, torch.ones(2, 2) + 10)
13637*da0073e9SAndroid Build Coastguard Worker
13638*da0073e9SAndroid Build Coastguard Worker        # Check submodule
13639*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(loaded.submodule.buffer1, torch.ones(2, 2) + 10)
13640*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(loaded.submodule.buffer2, torch.ones(2, 2) + 10)
13641*da0073e9SAndroid Build Coastguard Worker
13642*da0073e9SAndroid Build Coastguard Worker        # Check simpler module
13643*da0073e9SAndroid Build Coastguard Worker        class NoArgState(torch.nn.Module):
13644*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
13645*da0073e9SAndroid Build Coastguard Worker                super().__init__()
13646*da0073e9SAndroid Build Coastguard Worker                self.buffer1 = nn.Buffer(torch.ones(2, 2))
13647*da0073e9SAndroid Build Coastguard Worker                self.buffer2 = nn.Buffer(torch.ones(2, 2))
13648*da0073e9SAndroid Build Coastguard Worker
13649*da0073e9SAndroid Build Coastguard Worker            def forward(self):
13650*da0073e9SAndroid Build Coastguard Worker                pass
13651*da0073e9SAndroid Build Coastguard Worker
13652*da0073e9SAndroid Build Coastguard Worker            @torch.jit.export
13653*da0073e9SAndroid Build Coastguard Worker            def __getstate__(self):
13654*da0073e9SAndroid Build Coastguard Worker                return 5, self.training
13655*da0073e9SAndroid Build Coastguard Worker
13656*da0073e9SAndroid Build Coastguard Worker            @torch.jit.export
13657*da0073e9SAndroid Build Coastguard Worker            def __setstate__(self, state):
13658*da0073e9SAndroid Build Coastguard Worker                self.buffer1 = torch.ones(2, 2) + state[0]
13659*da0073e9SAndroid Build Coastguard Worker                self.buffer2 = torch.ones(2, 2) + 10
13660*da0073e9SAndroid Build Coastguard Worker                self.training = state[1]
13661*da0073e9SAndroid Build Coastguard Worker
13662*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName() as fname:
13663*da0073e9SAndroid Build Coastguard Worker            m = torch.jit.script(NoArgState())
13664*da0073e9SAndroid Build Coastguard Worker            m.save(fname)
13665*da0073e9SAndroid Build Coastguard Worker            loaded = torch.jit.load(fname)
13666*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(loaded.buffer1, torch.ones(2, 2) + 5)
13667*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(loaded.buffer2, torch.ones(2, 2) + 10)
13668*da0073e9SAndroid Build Coastguard Worker
13669*da0073e9SAndroid Build Coastguard Worker
13670*da0073e9SAndroid Build Coastguard Worker
13671*da0073e9SAndroid Build Coastguard Worker    def test_string_slicing(self):
13672*da0073e9SAndroid Build Coastguard Worker        def fn1(x):
13673*da0073e9SAndroid Build Coastguard Worker            # type: (str) -> str
13674*da0073e9SAndroid Build Coastguard Worker            return x[1:3]
13675*da0073e9SAndroid Build Coastguard Worker
13676*da0073e9SAndroid Build Coastguard Worker        def fn2(x):
13677*da0073e9SAndroid Build Coastguard Worker            # type: (str) -> str
13678*da0073e9SAndroid Build Coastguard Worker            return x[-1:3]
13679*da0073e9SAndroid Build Coastguard Worker
13680*da0073e9SAndroid Build Coastguard Worker        def fn3(x):
13681*da0073e9SAndroid Build Coastguard Worker            # type: (str) -> str
13682*da0073e9SAndroid Build Coastguard Worker            return x[3:1]
13683*da0073e9SAndroid Build Coastguard Worker
13684*da0073e9SAndroid Build Coastguard Worker        def fn4(x):
13685*da0073e9SAndroid Build Coastguard Worker            # type: (str) -> str
13686*da0073e9SAndroid Build Coastguard Worker            return x[3:100]
13687*da0073e9SAndroid Build Coastguard Worker
13688*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn1, ("abcdefghi",))
13689*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn2, ("abcdefghi",))
13690*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn3, ("abcdefghi",))
13691*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn4, ("abcdefghi",))
13692*da0073e9SAndroid Build Coastguard Worker
13693*da0073e9SAndroid Build Coastguard Worker    def test_early_return_closure(self):
13694*da0073e9SAndroid Build Coastguard Worker        code = dedent('''
13695*da0073e9SAndroid Build Coastguard Worker            def tanh(self):
13696*da0073e9SAndroid Build Coastguard Worker                output = torch.tanh(self)
13697*da0073e9SAndroid Build Coastguard Worker                def backward(grad_output):
13698*da0073e9SAndroid Build Coastguard Worker                    pass
13699*da0073e9SAndroid Build Coastguard Worker                return output, backward
13700*da0073e9SAndroid Build Coastguard Worker        ''')
13701*da0073e9SAndroid Build Coastguard Worker        cu = torch.jit.CompilationUnit(code)
13702*da0073e9SAndroid Build Coastguard Worker        g = cu.tanh.graph
13703*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_count("prim::Closure_0", 2).check("NoneType = prim::Constant") \
13704*da0073e9SAndroid Build Coastguard Worker                   .check_next("return").run(g)
13705*da0073e9SAndroid Build Coastguard Worker
13706*da0073e9SAndroid Build Coastguard Worker        code = dedent('''
13707*da0073e9SAndroid Build Coastguard Worker            def tanh(self):
13708*da0073e9SAndroid Build Coastguard Worker                output = torch.tanh(self)
13709*da0073e9SAndroid Build Coastguard Worker                def backward(grad_output):
13710*da0073e9SAndroid Build Coastguard Worker                    a = 1
13711*da0073e9SAndroid Build Coastguard Worker                    if output:
13712*da0073e9SAndroid Build Coastguard Worker                        return 1
13713*da0073e9SAndroid Build Coastguard Worker                    else:
13714*da0073e9SAndroid Build Coastguard Worker                        a = 2
13715*da0073e9SAndroid Build Coastguard Worker                    return a
13716*da0073e9SAndroid Build Coastguard Worker                return output, backward
13717*da0073e9SAndroid Build Coastguard Worker        ''')
13718*da0073e9SAndroid Build Coastguard Worker        cu = torch.jit.CompilationUnit(code)
13719*da0073e9SAndroid Build Coastguard Worker        g = cu.tanh.graph
13720*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_count("prim::Closure_0", 2).check("int = prim::If") \
13721*da0073e9SAndroid Build Coastguard Worker                   .run(g)
13722*da0073e9SAndroid Build Coastguard Worker
13723*da0073e9SAndroid Build Coastguard Worker        code = dedent('''
13724*da0073e9SAndroid Build Coastguard Worker            def loop_in_closure(self):
13725*da0073e9SAndroid Build Coastguard Worker                output = torch.tanh(self)
13726*da0073e9SAndroid Build Coastguard Worker                def backward(grad_output):
13727*da0073e9SAndroid Build Coastguard Worker                    for i in range(3):
13728*da0073e9SAndroid Build Coastguard Worker                        return 1
13729*da0073e9SAndroid Build Coastguard Worker                    return 4
13730*da0073e9SAndroid Build Coastguard Worker                return output, backward
13731*da0073e9SAndroid Build Coastguard Worker        ''')
13732*da0073e9SAndroid Build Coastguard Worker        cu = torch.jit.CompilationUnit(code)
13733*da0073e9SAndroid Build Coastguard Worker        fc = FileCheck()
13734*da0073e9SAndroid Build Coastguard Worker        fc.check("prim::Closure").check("(Tensor, NoneType) = prim::TupleConstruct")
13735*da0073e9SAndroid Build Coastguard Worker        # Loop then two if's added in exit transform
13736*da0073e9SAndroid Build Coastguard Worker        fc.check("prim::Closure").check("prim::Loop").check_count("prim::If", 2)
13737*da0073e9SAndroid Build Coastguard Worker        fc.run(cu.loop_in_closure.graph)
13738*da0073e9SAndroid Build Coastguard Worker
13739*da0073e9SAndroid Build Coastguard Worker        code = dedent('''
13740*da0073e9SAndroid Build Coastguard Worker            def tanh(self):
13741*da0073e9SAndroid Build Coastguard Worker                output = torch.tanh(self)
13742*da0073e9SAndroid Build Coastguard Worker                def backward(grad_output):
13743*da0073e9SAndroid Build Coastguard Worker                    if 1 == 1:
13744*da0073e9SAndroid Build Coastguard Worker                        return 1
13745*da0073e9SAndroid Build Coastguard Worker                    else:
13746*da0073e9SAndroid Build Coastguard Worker                        return 1.
13747*da0073e9SAndroid Build Coastguard Worker                return output, backward
13748*da0073e9SAndroid Build Coastguard Worker        ''')
13749*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "returned a value of type int but"):
13750*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit(code)
13751*da0073e9SAndroid Build Coastguard Worker
13752*da0073e9SAndroid Build Coastguard Worker    @_inline_everything
13753*da0073e9SAndroid Build Coastguard Worker    def test_early_return_fork_join(self):
13754*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
13755*da0073e9SAndroid Build Coastguard Worker        def foo(x):
13756*da0073e9SAndroid Build Coastguard Worker            if x.dim() == 2:
13757*da0073e9SAndroid Build Coastguard Worker                return torch.neg(x), x
13758*da0073e9SAndroid Build Coastguard Worker            else:
13759*da0073e9SAndroid Build Coastguard Worker                return torch.neg(x), x + 1
13760*da0073e9SAndroid Build Coastguard Worker
13761*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
13762*da0073e9SAndroid Build Coastguard Worker
13763*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
13764*da0073e9SAndroid Build Coastguard Worker        def wait_script(x):
13765*da0073e9SAndroid Build Coastguard Worker            fut = torch.jit._fork(foo, x)
13766*da0073e9SAndroid Build Coastguard Worker            y_hat = foo(x)
13767*da0073e9SAndroid Build Coastguard Worker            y = torch.jit._wait(fut)
13768*da0073e9SAndroid Build Coastguard Worker            return y, y_hat
13769*da0073e9SAndroid Build Coastguard Worker
13770*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("with prim::fork").check("prim::If").check("return")\
13771*da0073e9SAndroid Build Coastguard Worker                   .run(wait_script.graph)
13772*da0073e9SAndroid Build Coastguard Worker
13773*da0073e9SAndroid Build Coastguard Worker    def test_early_return_type_refinement(self):
13774*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
13775*da0073e9SAndroid Build Coastguard Worker        def test(x):
13776*da0073e9SAndroid Build Coastguard Worker            # type: (Optional[int]) -> int
13777*da0073e9SAndroid Build Coastguard Worker            if x is None:
13778*da0073e9SAndroid Build Coastguard Worker                return 1
13779*da0073e9SAndroid Build Coastguard Worker            else:
13780*da0073e9SAndroid Build Coastguard Worker                return x
13781*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(test(None), 1)
13782*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(test(2), 2)
13783*da0073e9SAndroid Build Coastguard Worker
13784*da0073e9SAndroid Build Coastguard Worker    def test_exceptions_with_control_flow(self):
13785*da0073e9SAndroid Build Coastguard Worker        def test_num_ifs(func, num_ifs):
13786*da0073e9SAndroid Build Coastguard Worker            g = torch.jit.script(func).graph
13787*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_count("prim::If", num_ifs, exactly=True).run(g)
13788*da0073e9SAndroid Build Coastguard Worker
13789*da0073e9SAndroid Build Coastguard Worker        def no_guard_ifs_added(x):
13790*da0073e9SAndroid Build Coastguard Worker            # type: (int) -> int
13791*da0073e9SAndroid Build Coastguard Worker            if x == 1:
13792*da0073e9SAndroid Build Coastguard Worker                return 1
13793*da0073e9SAndroid Build Coastguard Worker            else:
13794*da0073e9SAndroid Build Coastguard Worker                if x == 2:
13795*da0073e9SAndroid Build Coastguard Worker                    raise RuntimeError("hi")
13796*da0073e9SAndroid Build Coastguard Worker                else:
13797*da0073e9SAndroid Build Coastguard Worker                    raise RuntimeError("hi")
13798*da0073e9SAndroid Build Coastguard Worker
13799*da0073e9SAndroid Build Coastguard Worker        self.checkScript(no_guard_ifs_added, (1,))
13800*da0073e9SAndroid Build Coastguard Worker        self.checkScriptRaisesRegex(no_guard_ifs_added, (2,), Exception, "")
13801*da0073e9SAndroid Build Coastguard Worker        test_num_ifs(no_guard_ifs_added, 2)
13802*da0073e9SAndroid Build Coastguard Worker
13803*da0073e9SAndroid Build Coastguard Worker        # FUNCTION LOOKS LIKE:
13804*da0073e9SAndroid Build Coastguard Worker        # graph(%x.1 : int):
13805*da0073e9SAndroid Build Coastguard Worker        #   %7 : str = prim::Constant[value="Exception"]()
13806*da0073e9SAndroid Build Coastguard Worker        #   %2 : int = prim::Constant[value=1]()
13807*da0073e9SAndroid Build Coastguard Worker        #   %5 : int = prim::Constant[value=2]()
13808*da0073e9SAndroid Build Coastguard Worker        #   %19 : int = prim::Uninitialized()
13809*da0073e9SAndroid Build Coastguard Worker        #   %3 : bool = aten::eq(%x.1, %2)
13810*da0073e9SAndroid Build Coastguard Worker        #   %20 : int = prim::If(%3)
13811*da0073e9SAndroid Build Coastguard Worker        #     block0():
13812*da0073e9SAndroid Build Coastguard Worker        #       -> (%2)
13813*da0073e9SAndroid Build Coastguard Worker        #     block1():
13814*da0073e9SAndroid Build Coastguard Worker        #       %6 : bool = aten::eq(%x.1, %5)
13815*da0073e9SAndroid Build Coastguard Worker        #        = prim::If(%6)
13816*da0073e9SAndroid Build Coastguard Worker        #         block0():
13817*da0073e9SAndroid Build Coastguard Worker        #            = prim::RaiseException(%7)
13818*da0073e9SAndroid Build Coastguard Worker        #           -> ()
13819*da0073e9SAndroid Build Coastguard Worker        #         block1():
13820*da0073e9SAndroid Build Coastguard Worker        #            = prim::RaiseException(%7)
13821*da0073e9SAndroid Build Coastguard Worker        #           -> ()
13822*da0073e9SAndroid Build Coastguard Worker        #       -> (%19)
13823*da0073e9SAndroid Build Coastguard Worker        #   return (%20)
13824*da0073e9SAndroid Build Coastguard Worker
13825*da0073e9SAndroid Build Coastguard Worker        def no_ifs_added(x):
13826*da0073e9SAndroid Build Coastguard Worker            # type: (int) -> int
13827*da0073e9SAndroid Build Coastguard Worker            if x < 0:
13828*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError("hi")
13829*da0073e9SAndroid Build Coastguard Worker            return x
13830*da0073e9SAndroid Build Coastguard Worker
13831*da0073e9SAndroid Build Coastguard Worker        self.checkScript(no_ifs_added, (1,))
13832*da0073e9SAndroid Build Coastguard Worker        self.checkScriptRaisesRegex(no_ifs_added, (-2,), Exception, "")
13833*da0073e9SAndroid Build Coastguard Worker        test_num_ifs(no_ifs_added, 1)
13834*da0073e9SAndroid Build Coastguard Worker
13835*da0073e9SAndroid Build Coastguard Worker        def test_if_might(x):
13836*da0073e9SAndroid Build Coastguard Worker            # type: (int)
13837*da0073e9SAndroid Build Coastguard Worker            if x > 0:
13838*da0073e9SAndroid Build Coastguard Worker                if x == 1:
13839*da0073e9SAndroid Build Coastguard Worker                    return 1
13840*da0073e9SAndroid Build Coastguard Worker                else:
13841*da0073e9SAndroid Build Coastguard Worker                    a = 2
13842*da0073e9SAndroid Build Coastguard Worker            else:
13843*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError("hi")
13844*da0073e9SAndroid Build Coastguard Worker            return a + 2
13845*da0073e9SAndroid Build Coastguard Worker
13846*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_if_might, (1,))
13847*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_if_might, (3,))
13848*da0073e9SAndroid Build Coastguard Worker        self.checkScriptRaisesRegex(no_ifs_added, (-2,), Exception, "")
13849*da0073e9SAndroid Build Coastguard Worker        test_num_ifs(test_if_might, 3)  # one if added to guard a + 2
13850*da0073e9SAndroid Build Coastguard Worker
13851*da0073e9SAndroid Build Coastguard Worker        def test_loop_no_escape(x):
13852*da0073e9SAndroid Build Coastguard Worker            # type: (int)
13853*da0073e9SAndroid Build Coastguard Worker            if x >= 0:
13854*da0073e9SAndroid Build Coastguard Worker                for i in range(x):
13855*da0073e9SAndroid Build Coastguard Worker                    raise RuntimeError("hi")
13856*da0073e9SAndroid Build Coastguard Worker            else:
13857*da0073e9SAndroid Build Coastguard Worker                return 5
13858*da0073e9SAndroid Build Coastguard Worker            return x + 3
13859*da0073e9SAndroid Build Coastguard Worker
13860*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_loop_no_escape, (0,))
13861*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_loop_no_escape, (-1,))
13862*da0073e9SAndroid Build Coastguard Worker        self.checkScriptRaisesRegex(test_loop_no_escape, (1,), Exception, "")
13863*da0073e9SAndroid Build Coastguard Worker
13864*da0073e9SAndroid Build Coastguard Worker        # if guard gets optimized away
13865*da0073e9SAndroid Build Coastguard Worker        test_num_ifs(test_loop_no_escape, 1)
13866*da0073e9SAndroid Build Coastguard Worker
13867*da0073e9SAndroid Build Coastguard Worker        def test_loop_exception_with_continue(x):
13868*da0073e9SAndroid Build Coastguard Worker            # type: (int)
13869*da0073e9SAndroid Build Coastguard Worker            i = 0
13870*da0073e9SAndroid Build Coastguard Worker            for i in range(5):
13871*da0073e9SAndroid Build Coastguard Worker                if i == x:
13872*da0073e9SAndroid Build Coastguard Worker                    raise RuntimeError("hi")
13873*da0073e9SAndroid Build Coastguard Worker                else:
13874*da0073e9SAndroid Build Coastguard Worker                    continue
13875*da0073e9SAndroid Build Coastguard Worker                print(i)
13876*da0073e9SAndroid Build Coastguard Worker            return i + 5
13877*da0073e9SAndroid Build Coastguard Worker
13878*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_loop_exception_with_continue, (-1,))
13879*da0073e9SAndroid Build Coastguard Worker        self.checkScriptRaisesRegex(test_loop_exception_with_continue, (1,), Exception, "")
13880*da0073e9SAndroid Build Coastguard Worker        test_num_ifs(test_loop_exception_with_continue, 1)  # no ifs added to guard print
13881*da0073e9SAndroid Build Coastguard Worker
13882*da0073e9SAndroid Build Coastguard Worker
13883*da0073e9SAndroid Build Coastguard Worker    def test_exception_exits_closure(self):
13884*da0073e9SAndroid Build Coastguard Worker        code = dedent('''
13885*da0073e9SAndroid Build Coastguard Worker            def no_return_func(self):
13886*da0073e9SAndroid Build Coastguard Worker                # type: (Tensor) -> Tensor
13887*da0073e9SAndroid Build Coastguard Worker                output = torch.tanh(self)
13888*da0073e9SAndroid Build Coastguard Worker                def backward(grad_output):
13889*da0073e9SAndroid Build Coastguard Worker                    raise RuntimeError("Hi")
13890*da0073e9SAndroid Build Coastguard Worker        ''')
13891*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "does not return along all"):
13892*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit(code)
13893*da0073e9SAndroid Build Coastguard Worker
13894*da0073e9SAndroid Build Coastguard Worker        code = dedent('''
13895*da0073e9SAndroid Build Coastguard Worker            def test_exit_pair_reset(x):
13896*da0073e9SAndroid Build Coastguard Worker                # type: (int) -> int
13897*da0073e9SAndroid Build Coastguard Worker                if x > 0:
13898*da0073e9SAndroid Build Coastguard Worker                    a = 0
13899*da0073e9SAndroid Build Coastguard Worker                    def backward(grad_output):
13900*da0073e9SAndroid Build Coastguard Worker                        raise RuntimeError("Hi")
13901*da0073e9SAndroid Build Coastguard Worker                    a = a + 1
13902*da0073e9SAndroid Build Coastguard Worker                else:
13903*da0073e9SAndroid Build Coastguard Worker                    return x
13904*da0073e9SAndroid Build Coastguard Worker                return a + 1
13905*da0073e9SAndroid Build Coastguard Worker        ''')
13906*da0073e9SAndroid Build Coastguard Worker        func = torch.jit.CompilationUnit(code).test_exit_pair_reset
13907*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(func(1,), 2)
13908*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(func(-1,), -1)
13909*da0073e9SAndroid Build Coastguard Worker        # final a + 1 gets inlined into the first branch and optimized away
13910*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_count("prim::If", 1, exactly=True).run(func.graph)
13911*da0073e9SAndroid Build Coastguard Worker
    def test_non_final_return(self):
        """Script functions whose ``return`` is not the last statement.

        Covers several shapes:
        - early returns inside (nested) branches;
        - branches that fall through to a later ``return``;
        - statements placed after a ``return`` that can never execute.

        Each function is run through ``self.checkScript``. The inner function
        bodies are the compiler input under test, so their exact shape matters.
        """
        # Both branches return, so the trailing raise is unreachable.
        def simple(x):
            if bool(x > 3):
                return x + 1
            else:
                return x + 2
            raise RuntimeError("nope")

        # An inner `if` without `else`, nested in a branch that then returns.
        def nest(x):
            x = x + 1
            if bool(x > 3):
                if bool(x > 4):
                    x += 1
                return x + 1
            else:
                return x + 2

        # Early return from an `if` with no `else`; otherwise fall through.
        def early_ret(x):
            x = x + 1
            if bool(x > 3):
                return x + 1
            x = x + 1
            return x + 2

        # Two levels of early return before the fall-through path.
        def nest_early_ret(x):
            x = x + 1
            if bool(x > 3):
                if bool(x > 4):
                    return x + 2
                return x + 1
            x = x + 1
            return x + 2

        # Only the innermost branch returns early; every other path rejoins
        # and reaches the final return.
        def not_early_ret(x):
            s = ""
            if bool(x > 3):
                if bool(x > 4):
                    return 1, s
                s += "foo"
            else:
                s += "5"
            s += "hi"
            return 7, s

        # The inner if/else returns on both sides; only the outer `else`
        # falls through to the final return.
        def not_total_ret(x):
            s = ""
            if bool(x > 3):
                if bool(x > 4):
                    return 1, s
                else:
                    return 2, s
            else:
                s += "5"
            return 7, s

        # Inputs 2.5, 3.5 and 4.5 send the `> 3` / `> 4` conditions down
        # different branches.
        for i in range(3):
            for func in [simple, nest, early_ret, nest_early_ret, not_early_ret,
                         not_total_ret]:
                self.checkScript(func, (torch.tensor(2.5 + i),))

        # y and z are bound only on the `else` path. The `if` path returns
        # first, so every path that reaches `x + y * z` has them defined.
        def vars_used_after_ret(x):
            # type: (int) -> int
            if x == 0:
                return x
            else:
                y = 2
                z = 3
            return x + y * z

        self.checkScript(vars_used_after_ret, (1,))
        self.checkScript(vars_used_after_ret, (0,))

        # Nested early returns, each followed by an unreachable assert.
        # a and b are bound on every path that reaches `return a + b`.
        def complicated(x):
            # type: (int) -> int
            if x:
                if x == 2:
                    return 1
                    assert 1 == 2
                else:
                    if x == 3:
                        return 2
                        assert 1 == 2
                    else:
                        a = 2
                        b = 3
            else:
                a = 4
                b = 1
            return a + b
            assert 1 == 2

        # Path taken by each x:
        #   0 -> outer else;  1 -> innermost else;
        #   2 -> `return 1`;  3 -> `return 2`.
        for i in range(4):
            self.checkScript(complicated, (i,))
14005*da0073e9SAndroid Build Coastguard Worker
    def test_partial_returns(self):
        """Functions with a non-None return type must return on every path.

        Functions with no return annotation and no ``return`` statement are
        accepted and return ``None``.
        """
        # Declared to return int, but the body is just `pass`.
        with self.assertRaisesRegex(RuntimeError, "does not return along all"):
            @torch.jit.script
            def no_ret():
                # type: () -> int
                pass

        # Returns only when `x` is truthy; the other path falls off the end.
        with self.assertRaisesRegex(RuntimeError, "does not return along all"):
            @torch.jit.script
            def partial(x):
                # type: (Tensor) -> int
                if x:
                    return 1

        # Even an Optional return type requires an explicit return.
        with self.assertRaisesRegex(RuntimeError, "does not return along all"):
            @torch.jit.script
            def typed_none():
                # type: () -> Optional[int]
                pass

        # No annotation and no return: compiles, returns None, and the graph
        # output is typed as None.
        @torch.jit.script
        def none_ret():
            pass

        self.assertIs(none_ret(), None)
        FileCheck().check(": None").run(none_ret.graph)
14032*da0073e9SAndroid Build Coastguard Worker
14033*da0073e9SAndroid Build Coastguard Worker    def test_early_returns_loops(self):
14034*da0073e9SAndroid Build Coastguard Worker        def nest_while_ret(x):
14035*da0073e9SAndroid Build Coastguard Worker            # type: (int) -> int
14036*da0073e9SAndroid Build Coastguard Worker            y = 4
14037*da0073e9SAndroid Build Coastguard Worker            while x < 4:
14038*da0073e9SAndroid Build Coastguard Worker                if x < 3:
14039*da0073e9SAndroid Build Coastguard Worker                    return y
14040*da0073e9SAndroid Build Coastguard Worker                else:
14041*da0073e9SAndroid Build Coastguard Worker                    y = y + 1
14042*da0073e9SAndroid Build Coastguard Worker                    break
14043*da0073e9SAndroid Build Coastguard Worker                y = y + 2
14044*da0073e9SAndroid Build Coastguard Worker            y = y + 1
14045*da0073e9SAndroid Build Coastguard Worker            return y
14046*da0073e9SAndroid Build Coastguard Worker
14047*da0073e9SAndroid Build Coastguard Worker        self.checkScript(nest_while_ret, (2,))
14048*da0073e9SAndroid Build Coastguard Worker        self.checkScript(nest_while_ret, (3,))
14049*da0073e9SAndroid Build Coastguard Worker        self.checkScript(nest_while_ret, (4,))
14050*da0073e9SAndroid Build Coastguard Worker
14051*da0073e9SAndroid Build Coastguard Worker        def loop_ret(x, y):
14052*da0073e9SAndroid Build Coastguard Worker            # type: (int, int) -> (int)
14053*da0073e9SAndroid Build Coastguard Worker            i = 0
14054*da0073e9SAndroid Build Coastguard Worker            for i in range(x):
14055*da0073e9SAndroid Build Coastguard Worker                if x == y:
14056*da0073e9SAndroid Build Coastguard Worker                    return x + y
14057*da0073e9SAndroid Build Coastguard Worker                i = i + y
14058*da0073e9SAndroid Build Coastguard Worker            i = i - 1
14059*da0073e9SAndroid Build Coastguard Worker            return i
14060*da0073e9SAndroid Build Coastguard Worker
14061*da0073e9SAndroid Build Coastguard Worker        self.checkScript(loop_ret, (3, 3))
14062*da0073e9SAndroid Build Coastguard Worker        self.checkScript(loop_ret, (2, 3))
14063*da0073e9SAndroid Build Coastguard Worker        self.checkScript(loop_ret, (3, 1))
14064*da0073e9SAndroid Build Coastguard Worker
14065*da0073e9SAndroid Build Coastguard Worker        def test_will_ret(y):
14066*da0073e9SAndroid Build Coastguard Worker            # type: (int) -> int
14067*da0073e9SAndroid Build Coastguard Worker            for i in range(y):
14068*da0073e9SAndroid Build Coastguard Worker                return 2
14069*da0073e9SAndroid Build Coastguard Worker            return 1
14070*da0073e9SAndroid Build Coastguard Worker
14071*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_will_ret, (0,))
14072*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_will_ret, (1,))
14073*da0073e9SAndroid Build Coastguard Worker
14074*da0073e9SAndroid Build Coastguard Worker        def test_loop_nest_ret(y):
14075*da0073e9SAndroid Build Coastguard Worker            # type: (int) -> int
14076*da0073e9SAndroid Build Coastguard Worker            for i in range(y):
14077*da0073e9SAndroid Build Coastguard Worker                for i in range(y - 2):
14078*da0073e9SAndroid Build Coastguard Worker                    return 10
14079*da0073e9SAndroid Build Coastguard Worker                return 5
14080*da0073e9SAndroid Build Coastguard Worker            return 0
14081*da0073e9SAndroid Build Coastguard Worker
14082*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_loop_nest_ret, (0,))
14083*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_loop_nest_ret, (1,))
14084*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_loop_nest_ret, (2,))
14085*da0073e9SAndroid Build Coastguard Worker
14086*da0073e9SAndroid Build Coastguard Worker    def test_nn_init(self):
14087*da0073e9SAndroid Build Coastguard Worker        tests = (
14088*da0073e9SAndroid Build Coastguard Worker            ('constant_', (lambda: (torch.ones(2, 2), 2.5)), "Tensor, float"),
14089*da0073e9SAndroid Build Coastguard Worker            ('ones_', (lambda: (torch.ones(2, 2),)), "Tensor"),
14090*da0073e9SAndroid Build Coastguard Worker            ('zeros_', (lambda: (torch.ones(2, 2),)), "Tensor"),
14091*da0073e9SAndroid Build Coastguard Worker            ('uniform_', (lambda: (torch.ones(2, 2),)), "Tensor"),
14092*da0073e9SAndroid Build Coastguard Worker            ('normal_', (lambda: (torch.ones(2, 2),)), "Tensor"),
14093*da0073e9SAndroid Build Coastguard Worker            ('xavier_normal_', (lambda: (torch.ones(2, 2),)), "Tensor"),
14094*da0073e9SAndroid Build Coastguard Worker            ('xavier_uniform_', (lambda: (torch.ones(2, 2),)), "Tensor"),
14095*da0073e9SAndroid Build Coastguard Worker        )
14096*da0073e9SAndroid Build Coastguard Worker
14097*da0073e9SAndroid Build Coastguard Worker        for name, args_fn, type_str in tests:
14098*da0073e9SAndroid Build Coastguard Worker            # Build test code
14099*da0073e9SAndroid Build Coastguard Worker            arg_str = ', '.join([chr(i + ord('a')) for i in range(len(args_fn()))])
14100*da0073e9SAndroid Build Coastguard Worker
14101*da0073e9SAndroid Build Coastguard Worker            code = dedent('''
14102*da0073e9SAndroid Build Coastguard Worker                def test({arg_str}):
14103*da0073e9SAndroid Build Coastguard Worker                    # type: ({type_str})
14104*da0073e9SAndroid Build Coastguard Worker                    return torch.nn.init.{name}({arg_str})
14105*da0073e9SAndroid Build Coastguard Worker            ''').format(arg_str=arg_str, type_str=type_str, name=name)
14106*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit(code)
14107*da0073e9SAndroid Build Coastguard Worker
14108*da0073e9SAndroid Build Coastguard Worker            # Compare functions
14109*da0073e9SAndroid Build Coastguard Worker            init_fn = getattr(torch.nn.init, name)
14110*da0073e9SAndroid Build Coastguard Worker            script_out = self.runAndSaveRNG(cu.test, args_fn())
14111*da0073e9SAndroid Build Coastguard Worker            eager_out = self.runAndSaveRNG(init_fn, args_fn())
14112*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(script_out, eager_out)
14113*da0073e9SAndroid Build Coastguard Worker
14114*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_not("prim::PythonOp").run(cu.test.graph)
14115*da0073e9SAndroid Build Coastguard Worker
14116*da0073e9SAndroid Build Coastguard Worker    def test_nn_init_generator(self):
14117*da0073e9SAndroid Build Coastguard Worker        init_fns = (
14118*da0073e9SAndroid Build Coastguard Worker            'uniform_', 'normal_', 'xavier_normal_', 'xavier_uniform_',
14119*da0073e9SAndroid Build Coastguard Worker        )
14120*da0073e9SAndroid Build Coastguard Worker
14121*da0073e9SAndroid Build Coastguard Worker        for name in init_fns:
14122*da0073e9SAndroid Build Coastguard Worker            # Build test code
14123*da0073e9SAndroid Build Coastguard Worker            code = dedent('''
14124*da0073e9SAndroid Build Coastguard Worker                def test(tensor, generator):
14125*da0073e9SAndroid Build Coastguard Worker                    # type: (Tensor, Generator)
14126*da0073e9SAndroid Build Coastguard Worker                    return torch.nn.init.{name}(tensor, generator=generator)
14127*da0073e9SAndroid Build Coastguard Worker            ''').format(name=name)
14128*da0073e9SAndroid Build Coastguard Worker            cu = torch.jit.CompilationUnit(code)
14129*da0073e9SAndroid Build Coastguard Worker
14130*da0073e9SAndroid Build Coastguard Worker            # Compare functions
14131*da0073e9SAndroid Build Coastguard Worker            init_fn = getattr(torch.nn.init, name)
14132*da0073e9SAndroid Build Coastguard Worker
14133*da0073e9SAndroid Build Coastguard Worker            torch.manual_seed(1)
14134*da0073e9SAndroid Build Coastguard Worker
14135*da0073e9SAndroid Build Coastguard Worker            g = torch.Generator()
14136*da0073e9SAndroid Build Coastguard Worker            g.manual_seed(2023)
14137*da0073e9SAndroid Build Coastguard Worker            script_out = cu.test(torch.ones(2, 2), g)
14138*da0073e9SAndroid Build Coastguard Worker
14139*da0073e9SAndroid Build Coastguard Worker            # Change the seed of the default generator to make
14140*da0073e9SAndroid Build Coastguard Worker            # sure that we're using the provided generator
14141*da0073e9SAndroid Build Coastguard Worker            torch.manual_seed(2)
14142*da0073e9SAndroid Build Coastguard Worker
14143*da0073e9SAndroid Build Coastguard Worker            g = torch.Generator()
14144*da0073e9SAndroid Build Coastguard Worker            g.manual_seed(2023)
14145*da0073e9SAndroid Build Coastguard Worker            eager_out = init_fn(torch.ones(2, 2), generator=g)
14146*da0073e9SAndroid Build Coastguard Worker
14147*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(script_out, eager_out)
14148*da0073e9SAndroid Build Coastguard Worker
14149*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_not("prim::PythonOp").run(cu.test.graph)
14150*da0073e9SAndroid Build Coastguard Worker
    def test_early_return_rewrite(self):
        """Early returns must not add ``prim::If`` nodes beyond the source ``if``s.

        Results are checked with ``self.checkScript``. The scripted graphs
        must contain exactly as many ``prim::If`` nodes as the function has
        source-level conditionals.
        """
        # One conditional early return -> exactly one prim::If.
        def test_foo(x: bool):
            if x:
                return 1
            return 2

        self.checkScript(test_foo, (True,))
        self.checkScript(test_foo, (False,))
        FileCheck().check_count("prim::If", 1, exactly=True).run(torch.jit.script(test_foo).graph)

        # Three source-level ifs, several returning early. `y` is bound only
        # in the non-returning `else`, and `abc` only in the non-returning
        # branch of the last if.
        def test_multiple(x: int):
            if x == 5:
                return x * x
            else:
                y = 2 * x

            z = y * 2
            if z == 8:
                return 1

            if z != 16:
                z = z - 2
                abc = 4
            else:
                return 3

            z = z * abc
            return z * z * z

        # Return taken by each x:
        #   5 -> `x * x`;  2 -> z == 8;  4 -> z == 16;
        #   3 and 10 -> fall through to the final return.
        self.checkScript(test_multiple, (5,))
        self.checkScript(test_multiple, (2,))
        self.checkScript(test_multiple, (4,))
        self.checkScript(test_multiple, (3,))
        self.checkScript(test_multiple, (10,))

        graph = torch.jit.script(test_multiple).graph
        FileCheck().check_count("prim::If", 3, exactly=True).run(graph)
14188*da0073e9SAndroid Build Coastguard Worker
    def test_is_scripting_metacompile(self):
        """``torch.jit.is_scripting()`` is a compile-time constant under scripting.

        The ``else`` branch contains code that is invalid in TorchScript
        (``print(...) + 2``). Scripting still succeeds because only the
        taken branch is compiled.
        """
        @torch.jit.script
        def foo():
            if torch.jit.is_scripting():
                return 1
            else:
                print("hello") + 2  # will not be compiled

        self.assertEqual(foo(), 1)
14198*da0073e9SAndroid Build Coastguard Worker
    def test_boolean_literal_constant_metacompile(self):
        """Conditions on constant booleans compile only the taken branch.

        In each case the two branches return different types (int vs str).
        Those types would not unify if both branches were compiled.
        """
        # `val` is listed in __constants__, so `if self.val` is resolved when
        # the module is scripted; checked with both True and False.
        class Mod(torch.nn.Module):
            __constants__ = ['val']

            def __init__(self, val):
                super().__init__()
                self.val = val

            def forward(self):
                if self.val:
                    return 1
                else:
                    return "2"

        self.checkModule(Mod(True), ())
        self.checkModule(Mod(False), ())

        # The same with a literal `True` condition in a free function.
        @torch.jit.script
        def foo():
            if True:
                return 1
            else:
                return "2"

        self.assertEqual(foo(), 1)
14224*da0073e9SAndroid Build Coastguard Worker
    def test_assert_is_scripting_metacompile(self):
        """An ``assert`` that is statically false under scripting ends compilation.

        Under scripting, ``not torch.jit.is_scripting()`` is always false.
        The code after the assert (invalid in TorchScript) is therefore not
        compiled, and calling the function raises the assert message as a
        ``torch.jit.Error``.
        """
        def foo():
            assert not torch.jit.is_scripting(), "TestErrorMsg"
            print("hello") + 2  # will not be compiled

        f = torch.jit.script(foo)
        with self.assertRaisesRegex(torch.jit.Error, "TestErrorMsg"):
            f()
14233*da0073e9SAndroid Build Coastguard Worker
    def test_isinstance_metacompile(self):
        """``isinstance`` checks with statically known results prune dead branches."""
        # `x` is declared int, so `isinstance(x, int)` is known at compile time.
        # Passing a float is rejected by argument type checking at call time.
        @torch.jit.script
        def test_primitive_type(x):
            # type: (int) -> int
            if isinstance(x, int):
                return x + 1
            else:
                return x - 1

        self.assertEqual(test_primitive_type(1), 2)
        with self.assertRaisesRegex(Exception, "Expected a value of type"):
            test_primitive_type(1.5)

        _MyNamedTuple = namedtuple('_MyNamedTuple', ['value'])

        # `isinstance(1, _MyNamedTuple)` is statically False and
        # `isinstance(x, _MyNamedTuple)` is statically True. The int returns
        # in the pruned branches (`return 10`, `return 1`) would otherwise
        # conflict with the declared Tensor return type.
        @torch.jit.script
        def test_non_primitive_types(x):
            # type: (_MyNamedTuple) -> Tensor
            if isinstance(1, _MyNamedTuple):
                return 10

            if isinstance(x, _MyNamedTuple):
                return x.value + 1
            else:
                return 1

        out = test_non_primitive_types(_MyNamedTuple(value=torch.tensor(5.0)))
        self.assertEqual(out, torch.tensor(6.0))
14262*da0073e9SAndroid Build Coastguard Worker
    def test_namedtuple_type_inference(self):
        """NamedTuple field types come from annotations; unannotated fields are inferred as Tensor."""
        _AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('value', int)])  # noqa: UP014
        _UnannotatedNamedTuple = namedtuple('_NamedTupleUnAnnotated', ['value'])

        # The annotated field is typed int, so constructing it with 1 scripts fine.
        def test_check_named_tuple_value():
            named_tuple = _AnnotatedNamedTuple(1)
            return named_tuple.value

        self.checkScript(test_check_named_tuple_value, ())

        # The plain collections.namedtuple field is inferred as Tensor, so
        # passing an int is a type error at script time.
        def test_error():
            return _UnannotatedNamedTuple(1)

        with self.assertRaisesRegex(RuntimeError, r"Expected a value of type \'Tensor \(inferred\)\' "
                                                  r"for argument \'value\' but instead found type \'int\'."):
            torch.jit.script(test_error)
14279*da0073e9SAndroid Build Coastguard Worker
    def test_namedtuple_default_values_simple_type(self):
        """A NamedTuple with simple-typed field defaults round-trips through a scripted module.

        The field defaults must also appear in the type printed in the graph.
        """

        # Every field has a default, so Point() is valid.
        class Point(NamedTuple):
            x: Optional[int] = None
            y: int = 2

        # NOTE(review): make_global is defined elsewhere; presumably it exposes
        # Point at module scope so TorchScript can resolve the annotation.
        make_global(Point)

        class M(torch.nn.Module):
            def forward(self, point: Point):
                return point

        p = Point(x=3, y=2)

        # Explicit values and all-defaults instances.
        self.checkModule(M(), (p,))
        self.checkModule(M(), (Point(),))

        m = torch.jit.script(M())

        FileCheck().check(r"NamedTuple(x : int? = None, y : int = 2))")   \
                   .run(m.graph)
14301*da0073e9SAndroid Build Coastguard Worker
14302*da0073e9SAndroid Build Coastguard Worker    def test_namedtuple_default_values_missing(self):
14303*da0073e9SAndroid Build Coastguard Worker
14304*da0073e9SAndroid Build Coastguard Worker        class Point(NamedTuple):
14305*da0073e9SAndroid Build Coastguard Worker            x: Optional[int]
14306*da0073e9SAndroid Build Coastguard Worker            y: int
14307*da0073e9SAndroid Build Coastguard Worker            z: int = 3
14308*da0073e9SAndroid Build Coastguard Worker
14309*da0073e9SAndroid Build Coastguard Worker        make_global(Point)
14310*da0073e9SAndroid Build Coastguard Worker
14311*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
14312*da0073e9SAndroid Build Coastguard Worker            def forward(self, point: Point):
14313*da0073e9SAndroid Build Coastguard Worker                return point
14314*da0073e9SAndroid Build Coastguard Worker
14315*da0073e9SAndroid Build Coastguard Worker        p1 = Point(x=3, y=2)
14316*da0073e9SAndroid Build Coastguard Worker        p2 = Point(x=3, y=2, z=1)
14317*da0073e9SAndroid Build Coastguard Worker
14318*da0073e9SAndroid Build Coastguard Worker        self.checkModule(M(), (p1,))
14319*da0073e9SAndroid Build Coastguard Worker        self.checkModule(M(), (p2,))
14320*da0073e9SAndroid Build Coastguard Worker
14321*da0073e9SAndroid Build Coastguard Worker        m = torch.jit.script(M())
14322*da0073e9SAndroid Build Coastguard Worker
14323*da0073e9SAndroid Build Coastguard Worker        FileCheck().check(r"NamedTuple(x : int?, y : int, z : int = 3))")   \
14324*da0073e9SAndroid Build Coastguard Worker                   .run(m.graph)
14325*da0073e9SAndroid Build Coastguard Worker
14326*da0073e9SAndroid Build Coastguard Worker    def test_namedtuple_default_values_container_type(self):
14327*da0073e9SAndroid Build Coastguard Worker
14328*da0073e9SAndroid Build Coastguard Worker        class Point(NamedTuple):
14329*da0073e9SAndroid Build Coastguard Worker            x: Optional[List[int]] = None
14330*da0073e9SAndroid Build Coastguard Worker            y: List[int] = [1, 2, 3]
14331*da0073e9SAndroid Build Coastguard Worker            z: Optional[Dict[str, int]] = {"a": 1}
14332*da0073e9SAndroid Build Coastguard Worker
14333*da0073e9SAndroid Build Coastguard Worker        make_global(Point)
14334*da0073e9SAndroid Build Coastguard Worker
14335*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
14336*da0073e9SAndroid Build Coastguard Worker            def forward(self, point: Point):
14337*da0073e9SAndroid Build Coastguard Worker                return point
14338*da0073e9SAndroid Build Coastguard Worker
14339*da0073e9SAndroid Build Coastguard Worker        p = Point(x=[4, 5, 6], y=[3, 2, 1], z={"b": 2})
14340*da0073e9SAndroid Build Coastguard Worker
14341*da0073e9SAndroid Build Coastguard Worker        self.checkModule(M(), (p,))
14342*da0073e9SAndroid Build Coastguard Worker        self.checkModule(M(), (Point(),))
14343*da0073e9SAndroid Build Coastguard Worker
14344*da0073e9SAndroid Build Coastguard Worker        m = torch.jit.script(M())
14345*da0073e9SAndroid Build Coastguard Worker
14346*da0073e9SAndroid Build Coastguard Worker        first_line = r"NamedTuple(x : int[]? = None, y : int[] = "    \
14347*da0073e9SAndroid Build Coastguard Worker                     r"[1, 2, 3], z : Dict(str, int)? = {a: 1}))"
14348*da0073e9SAndroid Build Coastguard Worker
14349*da0073e9SAndroid Build Coastguard Worker        FileCheck().check(first_line)   \
14350*da0073e9SAndroid Build Coastguard Worker                   .run(m.graph)
14351*da0073e9SAndroid Build Coastguard Worker
14352*da0073e9SAndroid Build Coastguard Worker    def test_namedtuple_default_values_Tensor_type(self):
14353*da0073e9SAndroid Build Coastguard Worker
14354*da0073e9SAndroid Build Coastguard Worker        class Point(NamedTuple):
14355*da0073e9SAndroid Build Coastguard Worker            x: torch.Tensor = torch.rand(2, 3)
14356*da0073e9SAndroid Build Coastguard Worker
14357*da0073e9SAndroid Build Coastguard Worker        make_global(Point)
14358*da0073e9SAndroid Build Coastguard Worker
14359*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
14360*da0073e9SAndroid Build Coastguard Worker            def forward(self, point: Point):
14361*da0073e9SAndroid Build Coastguard Worker                return point
14362*da0073e9SAndroid Build Coastguard Worker
14363*da0073e9SAndroid Build Coastguard Worker        p = Point(x=torch.rand(2, 3))
14364*da0073e9SAndroid Build Coastguard Worker
14365*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Tensors are not "
14366*da0073e9SAndroid Build Coastguard Worker                                    "supported as default NamedTuple "
14367*da0073e9SAndroid Build Coastguard Worker                                    "fields"):
14368*da0073e9SAndroid Build Coastguard Worker            m = torch.jit.script(M())
14369*da0073e9SAndroid Build Coastguard Worker            m(p)
14370*da0073e9SAndroid Build Coastguard Worker
14371*da0073e9SAndroid Build Coastguard Worker    def test_namedtuple_default_values_using_factory_constructor(self):
14372*da0073e9SAndroid Build Coastguard Worker        Pair = namedtuple("Pair", ["x", "y"], defaults=(1, 2))
14373*da0073e9SAndroid Build Coastguard Worker
14374*da0073e9SAndroid Build Coastguard Worker        make_global(Pair)
14375*da0073e9SAndroid Build Coastguard Worker
14376*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
14377*da0073e9SAndroid Build Coastguard Worker        def fn(x: Pair) -> Pair:
14378*da0073e9SAndroid Build Coastguard Worker            return x
14379*da0073e9SAndroid Build Coastguard Worker
14380*da0073e9SAndroid Build Coastguard Worker        # TODO: We can't use `checkScript` with the NamedTuple factory
14381*da0073e9SAndroid Build Coastguard Worker        # constructor. Using the factory constructor with TorchScript
14382*da0073e9SAndroid Build Coastguard Worker        # TorchScript creates an anonymous `NamedTuple` class instead of
14383*da0073e9SAndroid Build Coastguard Worker        # preserving the actual name. For example, the actual generated
14384*da0073e9SAndroid Build Coastguard Worker        # signature in this case is:
14385*da0073e9SAndroid Build Coastguard Worker        #   graph(%x.1 : NamedTuple(x : Tensor, y : Tensor))
14386*da0073e9SAndroid Build Coastguard Worker        # It looks like similar test cases have had this issue as well
14387*da0073e9SAndroid Build Coastguard Worker        # (see: `test_namedtuple_python`).
14388*da0073e9SAndroid Build Coastguard Worker        FileCheck().check(r"NamedTuple(x : Tensor = 1, y : Tensor = 2))")   \
14389*da0073e9SAndroid Build Coastguard Worker                   .check_next(r"return (%x.1)")    \
14390*da0073e9SAndroid Build Coastguard Worker                   .run(fn.graph)
14391*da0073e9SAndroid Build Coastguard Worker
14392*da0073e9SAndroid Build Coastguard Worker    def test_isinstance_dynamic(self):
14393*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
14394*da0073e9SAndroid Build Coastguard Worker        def foo(a):
14395*da0073e9SAndroid Build Coastguard Worker            # type: (Optional[List[int]]) -> int
14396*da0073e9SAndroid Build Coastguard Worker            b = 0
14397*da0073e9SAndroid Build Coastguard Worker            if isinstance(a, (int, (float,), list, str)):
14398*da0073e9SAndroid Build Coastguard Worker                b += 1
14399*da0073e9SAndroid Build Coastguard Worker            if isinstance(a, (int, str)):
14400*da0073e9SAndroid Build Coastguard Worker                b += 1
14401*da0073e9SAndroid Build Coastguard Worker            if isinstance(a, List[int]):
14402*da0073e9SAndroid Build Coastguard Worker                b += 1
14403*da0073e9SAndroid Build Coastguard Worker            return b
14404*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo([3, 4]), 2)
14405*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(None), 0)
14406*da0073e9SAndroid Build Coastguard Worker
14407*da0073e9SAndroid Build Coastguard Worker    def test_function_overloads(self):
14408*da0073e9SAndroid Build Coastguard Worker        # TODO: pyflakes currently does not compose @overload annotation with other
14409*da0073e9SAndroid Build Coastguard Worker        # decorators. This is fixed on master but not on version 2.1.1.
14410*da0073e9SAndroid Build Coastguard Worker        # Next version update remove noqa and add @typing.overload annotation
14411*da0073e9SAndroid Build Coastguard Worker
14412*da0073e9SAndroid Build Coastguard Worker        @torch.jit._overload  # noqa: F811
14413*da0073e9SAndroid Build Coastguard Worker        def test_simple(x1):  # noqa: F811
14414*da0073e9SAndroid Build Coastguard Worker            # type: (int) -> int
14415*da0073e9SAndroid Build Coastguard Worker            pass
14416*da0073e9SAndroid Build Coastguard Worker
14417*da0073e9SAndroid Build Coastguard Worker        @torch.jit._overload  # noqa: F811
14418*da0073e9SAndroid Build Coastguard Worker        def test_simple(x1):  # noqa: F811
14419*da0073e9SAndroid Build Coastguard Worker            # type: (float) -> float
14420*da0073e9SAndroid Build Coastguard Worker            pass
14421*da0073e9SAndroid Build Coastguard Worker
14422*da0073e9SAndroid Build Coastguard Worker        def test_simple(x1):  # noqa: F811
14423*da0073e9SAndroid Build Coastguard Worker            return x1
14424*da0073e9SAndroid Build Coastguard Worker
14425*da0073e9SAndroid Build Coastguard Worker        def invoke_function():
14426*da0073e9SAndroid Build Coastguard Worker            return test_simple(1.0), test_simple(.5)
14427*da0073e9SAndroid Build Coastguard Worker
14428*da0073e9SAndroid Build Coastguard Worker        self.checkScript(invoke_function, ())
14429*da0073e9SAndroid Build Coastguard Worker
14430*da0073e9SAndroid Build Coastguard Worker        # testing that the functions are cached
14431*da0073e9SAndroid Build Coastguard Worker        compiled_fns_1 = torch.jit._script._get_overloads(test_simple)
14432*da0073e9SAndroid Build Coastguard Worker        compiled_fns_2 = torch.jit._script._get_overloads(test_simple)
14433*da0073e9SAndroid Build Coastguard Worker        for a, b in zip(compiled_fns_1, compiled_fns_2):
14434*da0073e9SAndroid Build Coastguard Worker            self.assertIs(a.graph, b.graph)
14435*da0073e9SAndroid Build Coastguard Worker
14436*da0073e9SAndroid Build Coastguard Worker        old_func = test_simple
14437*da0073e9SAndroid Build Coastguard Worker
14438*da0073e9SAndroid Build Coastguard Worker        # testing that new functions added work with caching
14439*da0073e9SAndroid Build Coastguard Worker        @torch.jit._overload  # noqa: F811
14440*da0073e9SAndroid Build Coastguard Worker        def test_simple(x1):  # noqa: F811
14441*da0073e9SAndroid Build Coastguard Worker            # type: (str) -> str
14442*da0073e9SAndroid Build Coastguard Worker            pass
14443*da0073e9SAndroid Build Coastguard Worker
14444*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
14445*da0073e9SAndroid Build Coastguard Worker        def my_func():
14446*da0073e9SAndroid Build Coastguard Worker            return old_func("hi")
14447*da0073e9SAndroid Build Coastguard Worker
14448*da0073e9SAndroid Build Coastguard Worker        # testing new function same qualified name
14449*da0073e9SAndroid Build Coastguard Worker        @torch.jit._overload  # noqa: F811
14450*da0073e9SAndroid Build Coastguard Worker        def test_simple(a, b):  # noqa: F811
14451*da0073e9SAndroid Build Coastguard Worker            # type: (int, int) -> int
14452*da0073e9SAndroid Build Coastguard Worker            pass
14453*da0073e9SAndroid Build Coastguard Worker
14454*da0073e9SAndroid Build Coastguard Worker        def test_simple(a, b):
14455*da0073e9SAndroid Build Coastguard Worker            return a + b
14456*da0073e9SAndroid Build Coastguard Worker
14457*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
14458*da0073e9SAndroid Build Coastguard Worker        def fn():
14459*da0073e9SAndroid Build Coastguard Worker            return test_simple(3, 4)
14460*da0073e9SAndroid Build Coastguard Worker
14461*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(), 7)
14462*da0073e9SAndroid Build Coastguard Worker
14463*da0073e9SAndroid Build Coastguard Worker        # currently we take the default values have to be specified in the
14464*da0073e9SAndroid Build Coastguard Worker        # overload as well - TODO take them from implementation and apply
14465*da0073e9SAndroid Build Coastguard Worker        # where the type is valid.
14466*da0073e9SAndroid Build Coastguard Worker        @torch.jit._overload  # noqa: F811
14467*da0073e9SAndroid Build Coastguard Worker        def identity(x1):  # noqa: F811
14468*da0073e9SAndroid Build Coastguard Worker            # type: (str) -> str
14469*da0073e9SAndroid Build Coastguard Worker            pass
14470*da0073e9SAndroid Build Coastguard Worker
14471*da0073e9SAndroid Build Coastguard Worker        @torch.jit._overload  # noqa: F811
14472*da0073e9SAndroid Build Coastguard Worker        def identity(x1):  # noqa: F811
14473*da0073e9SAndroid Build Coastguard Worker            # type: (float) -> float
14474*da0073e9SAndroid Build Coastguard Worker            pass
14475*da0073e9SAndroid Build Coastguard Worker
14476*da0073e9SAndroid Build Coastguard Worker        def identity(x1=1.0):  # noqa: F811
14477*da0073e9SAndroid Build Coastguard Worker            return x1
14478*da0073e9SAndroid Build Coastguard Worker
14479*da0073e9SAndroid Build Coastguard Worker        def invoke():
14480*da0073e9SAndroid Build Coastguard Worker            return identity(), identity(.5), identity("hi")
14481*da0073e9SAndroid Build Coastguard Worker
14482*da0073e9SAndroid Build Coastguard Worker        self.checkScript(invoke, ())
14483*da0073e9SAndroid Build Coastguard Worker
14484*da0073e9SAndroid Build Coastguard Worker        def schema_match_failure():
14485*da0073e9SAndroid Build Coastguard Worker            return identity((1, 2))
14486*da0073e9SAndroid Build Coastguard Worker
14487*da0073e9SAndroid Build Coastguard Worker        thrown = False
14488*da0073e9SAndroid Build Coastguard Worker        try:
14489*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(schema_match_failure)
14490*da0073e9SAndroid Build Coastguard Worker        except Exception as e:
14491*da0073e9SAndroid Build Coastguard Worker            thrown = True
14492*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(r"of type 'str'" in str(e) and r"of type 'float" in str(e))
14493*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(thrown)
14494*da0073e9SAndroid Build Coastguard Worker
14495*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(Exception, "cannot be directly compiled"):
14496*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(identity)
14497*da0073e9SAndroid Build Coastguard Worker
14498*da0073e9SAndroid Build Coastguard Worker        @torch.jit._overload  # noqa: F811
14499*da0073e9SAndroid Build Coastguard Worker        def impl_compile_failure(x, y):  # noqa: F811
14500*da0073e9SAndroid Build Coastguard Worker            # type: (str, str) -> (str)
14501*da0073e9SAndroid Build Coastguard Worker            pass
14502*da0073e9SAndroid Build Coastguard Worker
14503*da0073e9SAndroid Build Coastguard Worker        @torch.jit._overload  # noqa: F811
14504*da0073e9SAndroid Build Coastguard Worker        def impl_compile_failure(x, y):  # noqa: F811
14505*da0073e9SAndroid Build Coastguard Worker            # type: (int, int) -> (int)
14506*da0073e9SAndroid Build Coastguard Worker            pass
14507*da0073e9SAndroid Build Coastguard Worker
14508*da0073e9SAndroid Build Coastguard Worker        def impl_compile_failure(x, y):  # noqa: F811
14509*da0073e9SAndroid Build Coastguard Worker            return x - y
14510*da0073e9SAndroid Build Coastguard Worker
14511*da0073e9SAndroid Build Coastguard Worker        def test():
14512*da0073e9SAndroid Build Coastguard Worker            impl_compile_failure("one", "two")
14513*da0073e9SAndroid Build Coastguard Worker
14514*da0073e9SAndroid Build Coastguard Worker
14515*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(Exception, "Arguments for call are not valid"):
14516*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(test)
14517*da0073e9SAndroid Build Coastguard Worker
14518*da0073e9SAndroid Build Coastguard Worker        @torch.jit._overload  # noqa: F811
14519*da0073e9SAndroid Build Coastguard Worker        def good_overload(x=1):  # noqa: F811
14520*da0073e9SAndroid Build Coastguard Worker            # type: (int) -> (int)
14521*da0073e9SAndroid Build Coastguard Worker            pass
14522*da0073e9SAndroid Build Coastguard Worker
14523*da0073e9SAndroid Build Coastguard Worker        def good_overload(x=1):  # noqa: F811
14524*da0073e9SAndroid Build Coastguard Worker            return x
14525*da0073e9SAndroid Build Coastguard Worker
14526*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
14527*da0073e9SAndroid Build Coastguard Worker        def foo():
14528*da0073e9SAndroid Build Coastguard Worker            return good_overload()
14529*da0073e9SAndroid Build Coastguard Worker
14530*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(), 1)
14531*da0073e9SAndroid Build Coastguard Worker
14532*da0073e9SAndroid Build Coastguard Worker
14533*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(Exception, "must equal to the default parameter"):
14534*da0073e9SAndroid Build Coastguard Worker            @torch.jit._overload  # noqa: F811
14535*da0073e9SAndroid Build Coastguard Worker            def bad_default_on_overload(x, y=2):  # noqa: F811
14536*da0073e9SAndroid Build Coastguard Worker                # type: (int, int) -> (int)
14537*da0073e9SAndroid Build Coastguard Worker                pass
14538*da0073e9SAndroid Build Coastguard Worker
14539*da0073e9SAndroid Build Coastguard Worker            def bad_default_on_overload(x, y=1):  # noqa: F811
14540*da0073e9SAndroid Build Coastguard Worker                # type: (int, int) -> (int)
14541*da0073e9SAndroid Build Coastguard Worker                pass
14542*da0073e9SAndroid Build Coastguard Worker
14543*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
14544*da0073e9SAndroid Build Coastguard Worker            def test():
14545*da0073e9SAndroid Build Coastguard Worker                return bad_default_on_overload(1, 2)
14546*da0073e9SAndroid Build Coastguard Worker
14547*da0073e9SAndroid Build Coastguard Worker        @torch.jit._overload  # noqa: F811
14548*da0073e9SAndroid Build Coastguard Worker        def diff_default(x):  # noqa: F811
14549*da0073e9SAndroid Build Coastguard Worker            # type: (int) -> int
14550*da0073e9SAndroid Build Coastguard Worker            pass
14551*da0073e9SAndroid Build Coastguard Worker
14552*da0073e9SAndroid Build Coastguard Worker        @torch.jit._overload  # noqa: F811
14553*da0073e9SAndroid Build Coastguard Worker        def diff_default(x):  # noqa: F811
14554*da0073e9SAndroid Build Coastguard Worker            # type: (str) -> str
14555*da0073e9SAndroid Build Coastguard Worker            pass
14556*da0073e9SAndroid Build Coastguard Worker
14557*da0073e9SAndroid Build Coastguard Worker        def diff_default(x="hi"):  # noqa: F811
14558*da0073e9SAndroid Build Coastguard Worker            return x
14559*da0073e9SAndroid Build Coastguard Worker
14560*da0073e9SAndroid Build Coastguard Worker        def test():
14561*da0073e9SAndroid Build Coastguard Worker            return diff_default(), diff_default(2), diff_default("abc")
14562*da0073e9SAndroid Build Coastguard Worker
14563*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(test(), torch.jit.script(test)())
14564*da0073e9SAndroid Build Coastguard Worker
14565*da0073e9SAndroid Build Coastguard Worker        @torch.jit._overload  # noqa: F811
14566*da0073e9SAndroid Build Coastguard Worker        def diff_num_params(x):  # noqa: F811
14567*da0073e9SAndroid Build Coastguard Worker            # type: (float) -> float
14568*da0073e9SAndroid Build Coastguard Worker            pass
14569*da0073e9SAndroid Build Coastguard Worker
14570*da0073e9SAndroid Build Coastguard Worker        @torch.jit._overload  # noqa: F811
14571*da0073e9SAndroid Build Coastguard Worker        def diff_num_params(x, y):  # noqa: F811
14572*da0073e9SAndroid Build Coastguard Worker            # type: (int, int) -> int
14573*da0073e9SAndroid Build Coastguard Worker            pass
14574*da0073e9SAndroid Build Coastguard Worker
14575*da0073e9SAndroid Build Coastguard Worker        def diff_num_params(x, y=2, z=3):  # noqa: F811
14576*da0073e9SAndroid Build Coastguard Worker            # type: (Union[float, int], int, int)
14577*da0073e9SAndroid Build Coastguard Worker            return x + y + z
14578*da0073e9SAndroid Build Coastguard Worker
14579*da0073e9SAndroid Build Coastguard Worker        def test():
14580*da0073e9SAndroid Build Coastguard Worker            return diff_num_params(1.0), diff_num_params(1, 2), diff_num_params(1), diff_num_params(1, 2, 3)
14581*da0073e9SAndroid Build Coastguard Worker
14582*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(test(), torch.jit.script(test)())
14583*da0073e9SAndroid Build Coastguard Worker
14584*da0073e9SAndroid Build Coastguard Worker        @torch.jit._overload  # noqa: F811
14585*da0073e9SAndroid Build Coastguard Worker        def diff_num_params_no_annot():
14586*da0073e9SAndroid Build Coastguard Worker            # type: () -> int
14587*da0073e9SAndroid Build Coastguard Worker            pass
14588*da0073e9SAndroid Build Coastguard Worker
14589*da0073e9SAndroid Build Coastguard Worker        def diff_num_params_no_annot(x=1):    # noqa: F811
14590*da0073e9SAndroid Build Coastguard Worker            return x
14591*da0073e9SAndroid Build Coastguard Worker
14592*da0073e9SAndroid Build Coastguard Worker        def test():
14593*da0073e9SAndroid Build Coastguard Worker            return diff_num_params_no_annot(1.0)
14594*da0073e9SAndroid Build Coastguard Worker
14595*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(Exception, "Parameters not specified"):
14596*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(test)
14597*da0073e9SAndroid Build Coastguard Worker
14598*da0073e9SAndroid Build Coastguard Worker    def test_function_overload_misuse(self):
14599*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Only `pass` statement or `...` can be the body"):
14600*da0073e9SAndroid Build Coastguard Worker            @torch.jit._overload
14601*da0073e9SAndroid Build Coastguard Worker            def wrong_decl_body(x: str) -> str:
14602*da0073e9SAndroid Build Coastguard Worker                return x + "0"
14603*da0073e9SAndroid Build Coastguard Worker
14604*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Only `pass` statement or `...` can be the body"):
14605*da0073e9SAndroid Build Coastguard Worker            class MyClass:
14606*da0073e9SAndroid Build Coastguard Worker                @torch.jit._overload_method
14607*da0073e9SAndroid Build Coastguard Worker                def method(self):
14608*da0073e9SAndroid Build Coastguard Worker                    return 0
14609*da0073e9SAndroid Build Coastguard Worker
14610*da0073e9SAndroid Build Coastguard Worker        @torch.jit._overload
14611*da0073e9SAndroid Build Coastguard Worker        def null_overload(x: int) -> int: ...  # noqa: E704
14612*da0073e9SAndroid Build Coastguard Worker
14613*da0073e9SAndroid Build Coastguard Worker        @torch.jit._overload  # noqa: F811
14614*da0073e9SAndroid Build Coastguard Worker        def null_overload(x: str) -> str:  # noqa: F811
14615*da0073e9SAndroid Build Coastguard Worker            pass
14616*da0073e9SAndroid Build Coastguard Worker
14617*da0073e9SAndroid Build Coastguard Worker        def null_overload_driver():
14618*da0073e9SAndroid Build Coastguard Worker            return null_overload(0)
14619*da0073e9SAndroid Build Coastguard Worker
14620*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'Implementation for the function ".+" is missing.'):
14621*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(null_overload_driver)
14622*da0073e9SAndroid Build Coastguard Worker
14623*da0073e9SAndroid Build Coastguard Worker        class OverloadMisuse(torch.nn.Module):
14624*da0073e9SAndroid Build Coastguard Worker            @torch.jit._overload_method
14625*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: int):
14626*da0073e9SAndroid Build Coastguard Worker                pass
14627*da0073e9SAndroid Build Coastguard Worker
14628*da0073e9SAndroid Build Coastguard Worker            @torch.jit._overload_method  # noqa: F811
14629*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: Tensor):  # noqa: F811
14630*da0073e9SAndroid Build Coastguard Worker                pass
14631*da0073e9SAndroid Build Coastguard Worker
14632*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'Implementation for the method ".+" is missing.'):
14633*da0073e9SAndroid Build Coastguard Worker            m = torch.jit.script(OverloadMisuse())
14634*da0073e9SAndroid Build Coastguard Worker
14635*da0073e9SAndroid Build Coastguard Worker
14636*da0073e9SAndroid Build Coastguard Worker    def test_script_method_torch_function_overload(self):
14637*da0073e9SAndroid Build Coastguard Worker        class MyCustomTensor(torch.Tensor):
14638*da0073e9SAndroid Build Coastguard Worker            pass
14639*da0073e9SAndroid Build Coastguard Worker
14640*da0073e9SAndroid Build Coastguard Worker        class MyCustomModule(torch.nn.Module):
14641*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
14642*da0073e9SAndroid Build Coastguard Worker                return torch.relu(x)
14643*da0073e9SAndroid Build Coastguard Worker
14644*da0073e9SAndroid Build Coastguard Worker        scripted_mod = torch.jit.script(MyCustomModule())
14645*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor([3.0])
14646*da0073e9SAndroid Build Coastguard Worker        ref_out = scripted_mod(t)
14647*da0073e9SAndroid Build Coastguard Worker
14648*da0073e9SAndroid Build Coastguard Worker        t_custom = MyCustomTensor([3.0])
14649*da0073e9SAndroid Build Coastguard Worker        out1 = scripted_mod(t_custom)
14650*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out1, ref_out)
14651*da0073e9SAndroid Build Coastguard Worker
14652*da0073e9SAndroid Build Coastguard Worker        out2 = scripted_mod.forward(t_custom)
14653*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out2, ref_out)
14654*da0073e9SAndroid Build Coastguard Worker
14655*da0073e9SAndroid Build Coastguard Worker    def test_function_overloading_isinstance(self):
14656*da0073e9SAndroid Build Coastguard Worker        @torch.jit._overload  # noqa: F811
14657*da0073e9SAndroid Build Coastguard Worker        def my_conv(x, y):  # noqa: F811
14658*da0073e9SAndroid Build Coastguard Worker            # type: (float, str) -> (float)
14659*da0073e9SAndroid Build Coastguard Worker            pass
14660*da0073e9SAndroid Build Coastguard Worker
14661*da0073e9SAndroid Build Coastguard Worker        @torch.jit._overload  # noqa: F811
14662*da0073e9SAndroid Build Coastguard Worker        def my_conv(x, y):  # noqa: F811
14663*da0073e9SAndroid Build Coastguard Worker            # type: (float, float) -> (float)
14664*da0073e9SAndroid Build Coastguard Worker            pass
14665*da0073e9SAndroid Build Coastguard Worker
14666*da0073e9SAndroid Build Coastguard Worker        def my_conv(x, y=2.0):  # noqa: F811
14667*da0073e9SAndroid Build Coastguard Worker            if isinstance(y, str):
14668*da0073e9SAndroid Build Coastguard Worker                if y == "hi":
14669*da0073e9SAndroid Build Coastguard Worker                    return 4.0 - x
14670*da0073e9SAndroid Build Coastguard Worker                else:
14671*da0073e9SAndroid Build Coastguard Worker                    return 5.0 - x
14672*da0073e9SAndroid Build Coastguard Worker            else:
14673*da0073e9SAndroid Build Coastguard Worker                return 2.0 + x
14674*da0073e9SAndroid Build Coastguard Worker
14675*da0073e9SAndroid Build Coastguard Worker        def test_uses():
14676*da0073e9SAndroid Build Coastguard Worker            return my_conv(1.5), my_conv(1.5, "hi"), my_conv(1.5, 5.0)
14677*da0073e9SAndroid Build Coastguard Worker
14678*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_uses, ())
14679*da0073e9SAndroid Build Coastguard Worker
14680*da0073e9SAndroid Build Coastguard Worker    def test_method_overloading(self):
14681*da0073e9SAndroid Build Coastguard Worker        class Over(torch.nn.Module):
14682*da0073e9SAndroid Build Coastguard Worker            @torch.jit._overload_method  # noqa: F811
14683*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):  # noqa: F811
14684*da0073e9SAndroid Build Coastguard Worker                # type: (Tuple[Tensor, Tensor]) -> Tensor
14685*da0073e9SAndroid Build Coastguard Worker                pass
14686*da0073e9SAndroid Build Coastguard Worker
14687*da0073e9SAndroid Build Coastguard Worker            @torch.jit._overload_method  # noqa: F811
14688*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):  # noqa: F811
14689*da0073e9SAndroid Build Coastguard Worker                # type: (Tensor) -> Tensor
14690*da0073e9SAndroid Build Coastguard Worker                pass
14691*da0073e9SAndroid Build Coastguard Worker
14692*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):  # noqa: F811
14693*da0073e9SAndroid Build Coastguard Worker                if isinstance(x, Tensor):
14694*da0073e9SAndroid Build Coastguard Worker                    return x + 20
14695*da0073e9SAndroid Build Coastguard Worker                else:
14696*da0073e9SAndroid Build Coastguard Worker                    return x[0] + 5
14697*da0073e9SAndroid Build Coastguard Worker
14698*da0073e9SAndroid Build Coastguard Worker        class S(torch.jit.ScriptModule):
14699*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
14700*da0073e9SAndroid Build Coastguard Worker                super().__init__()
14701*da0073e9SAndroid Build Coastguard Worker                self.weak = Over()
14702*da0073e9SAndroid Build Coastguard Worker
14703*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
14704*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
14705*da0073e9SAndroid Build Coastguard Worker                return self.weak(x) + self.weak((x, x))
14706*da0073e9SAndroid Build Coastguard Worker
14707*da0073e9SAndroid Build Coastguard Worker        s_mod = S()
14708*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(1)
14709*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(s_mod(x), x + 20 + 5 + x)
14710*da0073e9SAndroid Build Coastguard Worker
14711*da0073e9SAndroid Build Coastguard Worker        over = Over()
14712*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(over((x, x)), x + 5)
14713*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(over(x), x + 20)
14714*da0073e9SAndroid Build Coastguard Worker
14715*da0073e9SAndroid Build Coastguard Worker        class Unannotated(torch.nn.Module):
14716*da0073e9SAndroid Build Coastguard Worker            @torch.jit._overload_method  # noqa: F811
14717*da0073e9SAndroid Build Coastguard Worker            def hello(self, x):  # noqa: F811
14718*da0073e9SAndroid Build Coastguard Worker                pass
14719*da0073e9SAndroid Build Coastguard Worker
14720*da0073e9SAndroid Build Coastguard Worker            @torch.jit._overload_method  # noqa: F811
14721*da0073e9SAndroid Build Coastguard Worker            def hello(self, x):  # noqa: F811
14722*da0073e9SAndroid Build Coastguard Worker                # type: (int) -> (int)
14723*da0073e9SAndroid Build Coastguard Worker                pass
14724*da0073e9SAndroid Build Coastguard Worker
14725*da0073e9SAndroid Build Coastguard Worker            def hello(self, x):  # noqa: F811
14726*da0073e9SAndroid Build Coastguard Worker                return x + 3
14727*da0073e9SAndroid Build Coastguard Worker
14728*da0073e9SAndroid Build Coastguard Worker            def forward(self):
14729*da0073e9SAndroid Build Coastguard Worker                return self.hello(1), self.hello(.5)
14730*da0073e9SAndroid Build Coastguard Worker
14731*da0073e9SAndroid Build Coastguard Worker        w = Unannotated()
14732*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(Exception, "explicitly add type annotations to overloaded functions"):
14733*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(w)
14734*da0073e9SAndroid Build Coastguard Worker
14735*da0073e9SAndroid Build Coastguard Worker        class CompileOverloadError(torch.nn.Module):
14736*da0073e9SAndroid Build Coastguard Worker            @torch.jit._overload_method  # noqa: F811
14737*da0073e9SAndroid Build Coastguard Worker            def hello(self, x):  # noqa: F811
14738*da0073e9SAndroid Build Coastguard Worker                # type: (str) -> (int)
14739*da0073e9SAndroid Build Coastguard Worker                pass
14740*da0073e9SAndroid Build Coastguard Worker
14741*da0073e9SAndroid Build Coastguard Worker            @torch.jit._overload_method  # noqa: F811
14742*da0073e9SAndroid Build Coastguard Worker            def hello(self, x):  # noqa: F811
14743*da0073e9SAndroid Build Coastguard Worker                # type: (int) -> (int)
14744*da0073e9SAndroid Build Coastguard Worker                pass
14745*da0073e9SAndroid Build Coastguard Worker
14746*da0073e9SAndroid Build Coastguard Worker            def hello(self, x):  # noqa: F811
14747*da0073e9SAndroid Build Coastguard Worker                return x + 1
14748*da0073e9SAndroid Build Coastguard Worker
14749*da0073e9SAndroid Build Coastguard Worker            def forward(self):
14750*da0073e9SAndroid Build Coastguard Worker                return self.hello("hi"), self.hello(.5)
14751*da0073e9SAndroid Build Coastguard Worker
14752*da0073e9SAndroid Build Coastguard Worker        w = CompileOverloadError()
14753*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(Exception, "but instead found type 'str'"):
14754*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(w)
14755*da0073e9SAndroid Build Coastguard Worker
14756*da0073e9SAndroid Build Coastguard Worker        # testing overload declared first, then non-overload
14757*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(Exception, "Overloads are not useable when a module"):
14758*da0073e9SAndroid Build Coastguard Worker            class W3(torch.nn.Module):
14759*da0073e9SAndroid Build Coastguard Worker                @torch.jit._overload_method  # noqa: F811
14760*da0073e9SAndroid Build Coastguard Worker                def forward(self, x):  # noqa: F811
14761*da0073e9SAndroid Build Coastguard Worker                    # type: (int) -> int
14762*da0073e9SAndroid Build Coastguard Worker                    pass
14763*da0073e9SAndroid Build Coastguard Worker
14764*da0073e9SAndroid Build Coastguard Worker                @torch.jit._overload_method  # noqa: F811
14765*da0073e9SAndroid Build Coastguard Worker                def forward(self, x):  # noqa: F811
14766*da0073e9SAndroid Build Coastguard Worker                    # type: (Tensor) -> Tensor
14767*da0073e9SAndroid Build Coastguard Worker                    pass
14768*da0073e9SAndroid Build Coastguard Worker
14769*da0073e9SAndroid Build Coastguard Worker                def forward(self, x):  # noqa: F811
14770*da0073e9SAndroid Build Coastguard Worker                    return x + 5
14771*da0073e9SAndroid Build Coastguard Worker
14772*da0073e9SAndroid Build Coastguard Worker            a = W3()
14773*da0073e9SAndroid Build Coastguard Worker            b = torch.jit.script(a)
14774*da0073e9SAndroid Build Coastguard Worker
14775*da0073e9SAndroid Build Coastguard Worker            class W3(torch.nn.Module):
14776*da0073e9SAndroid Build Coastguard Worker                def forward(self, x):  # noqa: F811
14777*da0073e9SAndroid Build Coastguard Worker                    return x + 5 + 10
14778*da0073e9SAndroid Build Coastguard Worker
14779*da0073e9SAndroid Build Coastguard Worker            a = W3()
14780*da0073e9SAndroid Build Coastguard Worker            b = torch.jit.script(a)
14781*da0073e9SAndroid Build Coastguard Worker
14782*da0073e9SAndroid Build Coastguard Worker        # testing non-overload declared first, then overload
14783*da0073e9SAndroid Build Coastguard Worker        class W2(torch.nn.Module):
14784*da0073e9SAndroid Build Coastguard Worker            def hello(self, x1, x2):
14785*da0073e9SAndroid Build Coastguard Worker                return x1 + x2
14786*da0073e9SAndroid Build Coastguard Worker
14787*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
14788*da0073e9SAndroid Build Coastguard Worker                return self.hello(x, x)
14789*da0073e9SAndroid Build Coastguard Worker
14790*da0073e9SAndroid Build Coastguard Worker        a = torch.jit.script(W2())
14791*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a(torch.tensor(1)), torch.tensor(2))
14792*da0073e9SAndroid Build Coastguard Worker
14793*da0073e9SAndroid Build Coastguard Worker        class W2(torch.nn.Module):
14794*da0073e9SAndroid Build Coastguard Worker            @torch.jit._overload_method  # noqa: F811
14795*da0073e9SAndroid Build Coastguard Worker            def hello(self, x):  # noqa: F811
14796*da0073e9SAndroid Build Coastguard Worker                pass
14797*da0073e9SAndroid Build Coastguard Worker
14798*da0073e9SAndroid Build Coastguard Worker            @torch.jit._overload_method  # noqa: F811
14799*da0073e9SAndroid Build Coastguard Worker            def hello(self, x):  # noqa: F811
14800*da0073e9SAndroid Build Coastguard Worker                # type: (int) -> (int)
14801*da0073e9SAndroid Build Coastguard Worker                pass
14802*da0073e9SAndroid Build Coastguard Worker
14803*da0073e9SAndroid Build Coastguard Worker            def hello(self, x):  # noqa: F811
14804*da0073e9SAndroid Build Coastguard Worker                return x + 5 + 10
14805*da0073e9SAndroid Build Coastguard Worker
14806*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
14807*da0073e9SAndroid Build Coastguard Worker                return self.hello(1), self.hello(x)
14808*da0073e9SAndroid Build Coastguard Worker
14809*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(Exception, "Overloads are not useable when a module"):
14810*da0073e9SAndroid Build Coastguard Worker            a = torch.jit.script(W2())
14811*da0073e9SAndroid Build Coastguard Worker
14812*da0073e9SAndroid Build Coastguard Worker    def test_narrow_copy(self):
14813*da0073e9SAndroid Build Coastguard Worker        def foo(a):
14814*da0073e9SAndroid Build Coastguard Worker            return a.narrow_copy(0, 0, 5)
14815*da0073e9SAndroid Build Coastguard Worker
14816*da0073e9SAndroid Build Coastguard Worker        self.checkScript(foo, [torch.rand(10)])
14817*da0073e9SAndroid Build Coastguard Worker
14818*da0073e9SAndroid Build Coastguard Worker    def test_select_after_chunk(self):
14819*da0073e9SAndroid Build Coastguard Worker        def foo(x):
14820*da0073e9SAndroid Build Coastguard Worker            chunked = torch.chunk(x, 1)
14821*da0073e9SAndroid Build Coastguard Worker            foo = chunked[0]
14822*da0073e9SAndroid Build Coastguard Worker            foo.add_(5)
14823*da0073e9SAndroid Build Coastguard Worker            return x
14824*da0073e9SAndroid Build Coastguard Worker
14825*da0073e9SAndroid Build Coastguard Worker        self.checkScript(foo, [torch.rand(2, 3)])
14826*da0073e9SAndroid Build Coastguard Worker
14827*da0073e9SAndroid Build Coastguard Worker    def test_nn_LSTM_with_layers(self):
14828*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
14829*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
14830*da0073e9SAndroid Build Coastguard Worker                super().__init__()
14831*da0073e9SAndroid Build Coastguard Worker                self.rnn = nn.LSTM(2, 3, 2, dropout=0)
14832*da0073e9SAndroid Build Coastguard Worker
14833*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
14834*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, lengths, h0, c0):
14835*da0073e9SAndroid Build Coastguard Worker                return self.rnn(x, (h0, c0))[0]
14836*da0073e9SAndroid Build Coastguard Worker
14837*da0073e9SAndroid Build Coastguard Worker        class Eager(torch.nn.Module):
14838*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
14839*da0073e9SAndroid Build Coastguard Worker                super().__init__()
14840*da0073e9SAndroid Build Coastguard Worker                self.rnn = nn.LSTM(2, 3, 2, dropout=0)
14841*da0073e9SAndroid Build Coastguard Worker
14842*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, lengths, h0, c0):
14843*da0073e9SAndroid Build Coastguard Worker                return self.rnn(x, (h0, c0))[0]
14844*da0073e9SAndroid Build Coastguard Worker
14845*da0073e9SAndroid Build Coastguard Worker        inputs = (torch.randn(1, 1, 2), torch.LongTensor([7]), torch.randn(2, 1, 3), torch.randn(2, 1, 3))
14846*da0073e9SAndroid Build Coastguard Worker        eager_out = self.runAndSaveRNG(lambda: Eager()(*inputs), ())[0]
14847*da0073e9SAndroid Build Coastguard Worker        script_out = self.runAndSaveRNG(lambda: M()(*inputs), ())[0]
14848*da0073e9SAndroid Build Coastguard Worker
14849*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(eager_out, script_out)
14850*da0073e9SAndroid Build Coastguard Worker
14851*da0073e9SAndroid Build Coastguard Worker    def test_nn_LSTM(self):
14852*da0073e9SAndroid Build Coastguard Worker        input = torch.nn.utils.rnn.pack_sequence([torch.randn(5, 5)])
14853*da0073e9SAndroid Build Coastguard Worker
14854*da0073e9SAndroid Build Coastguard Worker        class S(torch.jit.ScriptModule):
14855*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
14856*da0073e9SAndroid Build Coastguard Worker                super().__init__()
14857*da0073e9SAndroid Build Coastguard Worker                self.x = torch.nn.LSTM(5, 5)
14858*da0073e9SAndroid Build Coastguard Worker
14859*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
14860*da0073e9SAndroid Build Coastguard Worker            def forward(self, input: PackedSequence) -> Tuple[PackedSequence, Tuple[torch.Tensor, torch.Tensor]]:
14861*da0073e9SAndroid Build Coastguard Worker                return self.x(input)
14862*da0073e9SAndroid Build Coastguard Worker
14863*da0073e9SAndroid Build Coastguard Worker        eager_out = self.runAndSaveRNG(lambda x: torch.nn.LSTM(5, 5)(x), (input,))[0]
14864*da0073e9SAndroid Build Coastguard Worker        script_out = self.runAndSaveRNG(lambda x: S()(x), (input,))[0]
14865*da0073e9SAndroid Build Coastguard Worker
14866*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(eager_out, script_out)
14867*da0073e9SAndroid Build Coastguard Worker
14868*da0073e9SAndroid Build Coastguard Worker    def test_nn_GRU(self):
14869*da0073e9SAndroid Build Coastguard Worker        seq_input = torch.nn.utils.rnn.pack_sequence([torch.randn(5, 5)])
14870*da0073e9SAndroid Build Coastguard Worker        tensor_input = torch.randn(5, 5, 5)
14871*da0073e9SAndroid Build Coastguard Worker
14872*da0073e9SAndroid Build Coastguard Worker        class SeqLengthGRU(torch.jit.ScriptModule):
14873*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
14874*da0073e9SAndroid Build Coastguard Worker                super().__init__()
14875*da0073e9SAndroid Build Coastguard Worker                self.x = torch.nn.GRU(5, 5)
14876*da0073e9SAndroid Build Coastguard Worker
14877*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
14878*da0073e9SAndroid Build Coastguard Worker            def forward(self, input: PackedSequence) -> Tuple[PackedSequence, torch.Tensor]:
14879*da0073e9SAndroid Build Coastguard Worker                return self.x(input)
14880*da0073e9SAndroid Build Coastguard Worker
14881*da0073e9SAndroid Build Coastguard Worker        class TensorGRU(torch.jit.ScriptModule):
14882*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
14883*da0073e9SAndroid Build Coastguard Worker                super().__init__()
14884*da0073e9SAndroid Build Coastguard Worker                self.x = torch.nn.GRU(5, 5)
14885*da0073e9SAndroid Build Coastguard Worker
14886*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
14887*da0073e9SAndroid Build Coastguard Worker            def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
14888*da0073e9SAndroid Build Coastguard Worker                return self.x(input)
14889*da0073e9SAndroid Build Coastguard Worker
14890*da0073e9SAndroid Build Coastguard Worker        seq_eager_out = self.runAndSaveRNG(lambda x: torch.nn.GRU(5, 5)(x), (seq_input,))[0]
14891*da0073e9SAndroid Build Coastguard Worker        seq_script_out = self.runAndSaveRNG(lambda x: SeqLengthGRU()(x), (seq_input,))[0]
14892*da0073e9SAndroid Build Coastguard Worker        tensor_eager_out = self.runAndSaveRNG(lambda x: torch.nn.GRU(5, 5)(x), (tensor_input,))[0]
14893*da0073e9SAndroid Build Coastguard Worker        tensor_script_out = self.runAndSaveRNG(lambda x: TensorGRU()(x), (tensor_input,))[0]
14894*da0073e9SAndroid Build Coastguard Worker
14895*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(seq_eager_out, seq_script_out)
14896*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor_eager_out, tensor_script_out)
14897*da0073e9SAndroid Build Coastguard Worker
14898*da0073e9SAndroid Build Coastguard Worker    def test_torchscript_memoryformat(self):
14899*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
14900*da0073e9SAndroid Build Coastguard Worker        def fn(x):
14901*da0073e9SAndroid Build Coastguard Worker            return x.contiguous(memory_format=torch.channels_last)
14902*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4, 3, 6, 6)
14903*da0073e9SAndroid Build Coastguard Worker        y = fn(x)
14904*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(y.is_contiguous(memory_format=torch.channels_last))
14905*da0073e9SAndroid Build Coastguard Worker
14906*da0073e9SAndroid Build Coastguard Worker    def test_torchscript_multi_head_attn(self):
14907*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
14908*da0073e9SAndroid Build Coastguard Worker        def jit_multihead_attn_forward(query,                   # type: Tensor
14909*da0073e9SAndroid Build Coastguard Worker                                       key,                     # type: Tensor
14910*da0073e9SAndroid Build Coastguard Worker                                       value,                   # type: Tensor
14911*da0073e9SAndroid Build Coastguard Worker                                       embed_dim_to_check,      # type: int
14912*da0073e9SAndroid Build Coastguard Worker                                       num_heads,               # type: int
14913*da0073e9SAndroid Build Coastguard Worker                                       in_proj_weight,          # type: Tensor
14914*da0073e9SAndroid Build Coastguard Worker                                       in_proj_bias,            # type: Tensor
14915*da0073e9SAndroid Build Coastguard Worker                                       bias_k,                  # type: Optional[Tensor]
14916*da0073e9SAndroid Build Coastguard Worker                                       bias_v,                  # type: Optional[Tensor]
14917*da0073e9SAndroid Build Coastguard Worker                                       add_zero_attn,           # type: bool
14918*da0073e9SAndroid Build Coastguard Worker                                       dropout,                 # type: float
14919*da0073e9SAndroid Build Coastguard Worker                                       out_proj_weight,         # type: Tensor
14920*da0073e9SAndroid Build Coastguard Worker                                       out_proj_bias,           # type: Tensor
14921*da0073e9SAndroid Build Coastguard Worker                                       training=True,           # type: bool
14922*da0073e9SAndroid Build Coastguard Worker                                       key_padding_mask=None,   # type: Optional[Tensor]
14923*da0073e9SAndroid Build Coastguard Worker                                       need_weights=True,       # type: bool
14924*da0073e9SAndroid Build Coastguard Worker                                       attn_mask=None           # type: Optional[Tensor]
14925*da0073e9SAndroid Build Coastguard Worker                                       ):
14926*da0073e9SAndroid Build Coastguard Worker            # type: (...) -> Tuple[Tensor, Optional[Tensor]]
14927*da0073e9SAndroid Build Coastguard Worker            return torch.nn.functional.multi_head_attention_forward(query, key, value,
14928*da0073e9SAndroid Build Coastguard Worker                                                                    embed_dim_to_check, num_heads,
14929*da0073e9SAndroid Build Coastguard Worker                                                                    in_proj_weight, in_proj_bias,
14930*da0073e9SAndroid Build Coastguard Worker                                                                    bias_k, bias_v,
14931*da0073e9SAndroid Build Coastguard Worker                                                                    add_zero_attn, dropout,
14932*da0073e9SAndroid Build Coastguard Worker                                                                    out_proj_weight, out_proj_bias,
14933*da0073e9SAndroid Build Coastguard Worker                                                                    training, key_padding_mask,
14934*da0073e9SAndroid Build Coastguard Worker                                                                    need_weights, attn_mask)
14935*da0073e9SAndroid Build Coastguard Worker
14936*da0073e9SAndroid Build Coastguard Worker        src_l = 3
14937*da0073e9SAndroid Build Coastguard Worker        bsz = 5
14938*da0073e9SAndroid Build Coastguard Worker        embed_size = 8
14939*da0073e9SAndroid Build Coastguard Worker        nhead = 2
14940*da0073e9SAndroid Build Coastguard Worker        multi_head_attn = torch.nn.MultiheadAttention(embed_size, nhead)
14941*da0073e9SAndroid Build Coastguard Worker        query = torch.rand((src_l, bsz, embed_size))
14942*da0073e9SAndroid Build Coastguard Worker        key = torch.rand((src_l, bsz, embed_size))
14943*da0073e9SAndroid Build Coastguard Worker        value = torch.rand((src_l, bsz, embed_size))
14944*da0073e9SAndroid Build Coastguard Worker
14945*da0073e9SAndroid Build Coastguard Worker        mask = (torch.triu(torch.ones(src_l, src_l)) == 1).transpose(0, 1)
14946*da0073e9SAndroid Build Coastguard Worker        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, 0.0).to(torch.get_default_dtype())
14947*da0073e9SAndroid Build Coastguard Worker
14948*da0073e9SAndroid Build Coastguard Worker        jit_out = jit_multihead_attn_forward(query, key, value,
14949*da0073e9SAndroid Build Coastguard Worker                                             embed_size, nhead,
14950*da0073e9SAndroid Build Coastguard Worker                                             multi_head_attn.in_proj_weight,
14951*da0073e9SAndroid Build Coastguard Worker                                             multi_head_attn.in_proj_bias,
14952*da0073e9SAndroid Build Coastguard Worker                                             multi_head_attn.bias_k, multi_head_attn.bias_v,
14953*da0073e9SAndroid Build Coastguard Worker                                             multi_head_attn.add_zero_attn, multi_head_attn.dropout,
14954*da0073e9SAndroid Build Coastguard Worker                                             multi_head_attn.out_proj.weight,
14955*da0073e9SAndroid Build Coastguard Worker                                             multi_head_attn.out_proj.bias, attn_mask=mask)[0]
14956*da0073e9SAndroid Build Coastguard Worker
14957*da0073e9SAndroid Build Coastguard Worker        py_out = torch.nn.functional.multi_head_attention_forward(query, key, value,
14958*da0073e9SAndroid Build Coastguard Worker                                                                  embed_size, nhead,
14959*da0073e9SAndroid Build Coastguard Worker                                                                  multi_head_attn.in_proj_weight,
14960*da0073e9SAndroid Build Coastguard Worker                                                                  multi_head_attn.in_proj_bias,
14961*da0073e9SAndroid Build Coastguard Worker                                                                  multi_head_attn.bias_k,
14962*da0073e9SAndroid Build Coastguard Worker                                                                  multi_head_attn.bias_v,
14963*da0073e9SAndroid Build Coastguard Worker                                                                  multi_head_attn.add_zero_attn,
14964*da0073e9SAndroid Build Coastguard Worker                                                                  multi_head_attn.dropout,
14965*da0073e9SAndroid Build Coastguard Worker                                                                  multi_head_attn.out_proj.weight,
14966*da0073e9SAndroid Build Coastguard Worker                                                                  multi_head_attn.out_proj.bias,
14967*da0073e9SAndroid Build Coastguard Worker                                                                  attn_mask=mask)[0]
14968*da0073e9SAndroid Build Coastguard Worker        # print("rel. error: ")
14969*da0073e9SAndroid Build Coastguard Worker        # print(jit_out / py_out - 1)
14970*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(jit_out, py_out, atol=5e-4, rtol=1e-4)
14971*da0073e9SAndroid Build Coastguard Worker
14972*da0073e9SAndroid Build Coastguard Worker    def test_torchscript_multi_head_attn_fast_path(self):
14973*da0073e9SAndroid Build Coastguard Worker        src_l = 3
14974*da0073e9SAndroid Build Coastguard Worker        bsz = 5
14975*da0073e9SAndroid Build Coastguard Worker        embed_size = 8
14976*da0073e9SAndroid Build Coastguard Worker        nhead = 2
14977*da0073e9SAndroid Build Coastguard Worker        multi_head_attn = torch.nn.MultiheadAttention(embed_size, nhead, batch_first=True)
14978*da0073e9SAndroid Build Coastguard Worker        multi_head_attn = multi_head_attn.eval()
14979*da0073e9SAndroid Build Coastguard Worker
14980*da0073e9SAndroid Build Coastguard Worker        query = key = value = torch.rand((bsz, src_l, embed_size))
14981*da0073e9SAndroid Build Coastguard Worker
14982*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
14983*da0073e9SAndroid Build Coastguard Worker            py_out = multi_head_attn(query, key, value)
14984*da0073e9SAndroid Build Coastguard Worker            mha = torch.jit.script(multi_head_attn)
14985*da0073e9SAndroid Build Coastguard Worker            jit_out = mha(query, key, value)
14986*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(jit_out, py_out)
14987*da0073e9SAndroid Build Coastguard Worker
14988*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "no CUDA")
14989*da0073e9SAndroid Build Coastguard Worker    def test_scriptmodule_multi_head_attn_cuda(self):
14990*da0073e9SAndroid Build Coastguard Worker
14991*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.jit.ScriptModule):
14992*da0073e9SAndroid Build Coastguard Worker            def __init__(self, embed_dim, num_heads):
14993*da0073e9SAndroid Build Coastguard Worker                super().__init__()
14994*da0073e9SAndroid Build Coastguard Worker                sample_q = torch.randn(3, 2, embed_dim)
14995*da0073e9SAndroid Build Coastguard Worker                sample_kv = torch.randn(3, 2, embed_dim)
14996*da0073e9SAndroid Build Coastguard Worker                attention = nn.MultiheadAttention(embed_dim, num_heads)
14997*da0073e9SAndroid Build Coastguard Worker                attention.eval()
14998*da0073e9SAndroid Build Coastguard Worker
14999*da0073e9SAndroid Build Coastguard Worker                self.mod = torch.jit.trace(attention,
15000*da0073e9SAndroid Build Coastguard Worker                                           (sample_q, sample_kv, sample_kv))
15001*da0073e9SAndroid Build Coastguard Worker
15002*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
15003*da0073e9SAndroid Build Coastguard Worker            def forward(self, q, k, v):
15004*da0073e9SAndroid Build Coastguard Worker                return self.mod(q, k, v)
15005*da0073e9SAndroid Build Coastguard Worker
15006*da0073e9SAndroid Build Coastguard Worker        embed_dim = 8
15007*da0073e9SAndroid Build Coastguard Worker        num_heads = 2
15008*da0073e9SAndroid Build Coastguard Worker        sl = 3
15009*da0073e9SAndroid Build Coastguard Worker        bs = 2
15010*da0073e9SAndroid Build Coastguard Worker        model = MyModule(embed_dim, num_heads).cuda()
15011*da0073e9SAndroid Build Coastguard Worker        q = torch.randn(sl, bs, embed_dim, device="cuda")
15012*da0073e9SAndroid Build Coastguard Worker        kv = torch.randn(sl, bs, embed_dim, device="cuda")
15013*da0073e9SAndroid Build Coastguard Worker
15014*da0073e9SAndroid Build Coastguard Worker        jit_out = model(q, kv, kv)[0]
15015*da0073e9SAndroid Build Coastguard Worker        py_out = torch.nn.functional.multi_head_attention_forward(q, kv, kv,
15016*da0073e9SAndroid Build Coastguard Worker                                                                  embed_dim, num_heads,
15017*da0073e9SAndroid Build Coastguard Worker                                                                  model.mod.in_proj_weight,
15018*da0073e9SAndroid Build Coastguard Worker                                                                  model.mod.in_proj_bias,
15019*da0073e9SAndroid Build Coastguard Worker                                                                  None, None, None, 0.0,
15020*da0073e9SAndroid Build Coastguard Worker                                                                  model.mod.out_proj.weight,
15021*da0073e9SAndroid Build Coastguard Worker                                                                  model.mod.out_proj.bias)[0]
15022*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(jit_out, py_out, atol=5e-4, rtol=1e-4)
15023*da0073e9SAndroid Build Coastguard Worker
15024*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "no CUDA")
15025*da0073e9SAndroid Build Coastguard Worker    def test_scriptmodule_transformer_cuda(self):
15026*da0073e9SAndroid Build Coastguard Worker
15027*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.jit.ScriptModule):
15028*da0073e9SAndroid Build Coastguard Worker            def __init__(self, transformer, sample_q, sample_kv):
15029*da0073e9SAndroid Build Coastguard Worker                super().__init__()
15030*da0073e9SAndroid Build Coastguard Worker                transformer.eval()
15031*da0073e9SAndroid Build Coastguard Worker
15032*da0073e9SAndroid Build Coastguard Worker                self.mod = torch.jit.trace(transformer,
15033*da0073e9SAndroid Build Coastguard Worker                                           (sample_q, sample_kv))
15034*da0073e9SAndroid Build Coastguard Worker
15035*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
15036*da0073e9SAndroid Build Coastguard Worker            def forward(self, q, k):
15037*da0073e9SAndroid Build Coastguard Worker                return self.mod(q, k)
15038*da0073e9SAndroid Build Coastguard Worker
15039*da0073e9SAndroid Build Coastguard Worker        d_model = 8
15040*da0073e9SAndroid Build Coastguard Worker        nhead = 2
15041*da0073e9SAndroid Build Coastguard Worker        num_encoder_layers = 2
15042*da0073e9SAndroid Build Coastguard Worker        num_decoder_layers = 2
15043*da0073e9SAndroid Build Coastguard Worker        dim_feedforward = 16
15044*da0073e9SAndroid Build Coastguard Worker        bsz = 2
15045*da0073e9SAndroid Build Coastguard Worker        seq_length = 5
15046*da0073e9SAndroid Build Coastguard Worker        tgt_length = 3
15047*da0073e9SAndroid Build Coastguard Worker
15048*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
15049*da0073e9SAndroid Build Coastguard Worker            src = torch.randn(seq_length, bsz, d_model)
15050*da0073e9SAndroid Build Coastguard Worker            tgt = torch.randn(tgt_length, bsz, d_model)
15051*da0073e9SAndroid Build Coastguard Worker            transformer = nn.Transformer(d_model, nhead, num_encoder_layers,
15052*da0073e9SAndroid Build Coastguard Worker                                         num_decoder_layers, dim_feedforward, dropout=0.0)
15053*da0073e9SAndroid Build Coastguard Worker            model = MyModule(transformer, tgt, src)
15054*da0073e9SAndroid Build Coastguard Worker
15055*da0073e9SAndroid Build Coastguard Worker            src = torch.randn(seq_length, bsz, d_model)
15056*da0073e9SAndroid Build Coastguard Worker            tgt = torch.randn(tgt_length, bsz, d_model)
15057*da0073e9SAndroid Build Coastguard Worker            jit_out = model(tgt, src)
15058*da0073e9SAndroid Build Coastguard Worker            py_out = transformer(tgt, src)
15059*da0073e9SAndroid Build Coastguard Worker
15060*da0073e9SAndroid Build Coastguard Worker            # print(jit_out/py_out-1)
15061*da0073e9SAndroid Build Coastguard Worker            # print(torch.allclose(jit_out, py_out, atol=5e-4, rtol=1e-4))
15062*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(jit_out, py_out, atol=5e-4, rtol=1e-4)
15063*da0073e9SAndroid Build Coastguard Worker
15064*da0073e9SAndroid Build Coastguard Worker    def test_list_python_op(self):
15065*da0073e9SAndroid Build Coastguard Worker        def python_list_op(lst):
15066*da0073e9SAndroid Build Coastguard Worker            # type: (List[Tensor]) -> Tensor
15067*da0073e9SAndroid Build Coastguard Worker            return lst[0]
15068*da0073e9SAndroid Build Coastguard Worker
15069*da0073e9SAndroid Build Coastguard Worker        def fn(lst):
15070*da0073e9SAndroid Build Coastguard Worker            # type: (List[Tensor]) -> Tensor
15071*da0073e9SAndroid Build Coastguard Worker            return python_list_op(lst)
15072*da0073e9SAndroid Build Coastguard Worker
15073*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ([torch.ones(2) + 2, torch.ones(2)],))
15074*da0073e9SAndroid Build Coastguard Worker
15075*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "no CUDA")
15076*da0073e9SAndroid Build Coastguard Worker    def test_weak_cuda(self):
15077*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
15078*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
15079*da0073e9SAndroid Build Coastguard Worker                super().__init__()
15080*da0073e9SAndroid Build Coastguard Worker                self.lstm = torch.nn.LSTM(5, 5)
15081*da0073e9SAndroid Build Coastguard Worker                self.lstm.cuda()
15082*da0073e9SAndroid Build Coastguard Worker
15083*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
15084*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
15085*da0073e9SAndroid Build Coastguard Worker                return self.lstm(x)
15086*da0073e9SAndroid Build Coastguard Worker
15087*da0073e9SAndroid Build Coastguard Worker        m = M()
15088*da0073e9SAndroid Build Coastguard Worker        m.cuda()
15089*da0073e9SAndroid Build Coastguard Worker        out = m(torch.ones(5, 5, 5).cuda())
15090*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(out[0].is_cuda)
15091*da0073e9SAndroid Build Coastguard Worker
15092*da0073e9SAndroid Build Coastguard Worker    def test_ignore_decorator(self):
15093*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as warns:
15094*da0073e9SAndroid Build Coastguard Worker            class M(torch.jit.ScriptModule):
15095*da0073e9SAndroid Build Coastguard Worker                def __init__(self) -> None:
15096*da0073e9SAndroid Build Coastguard Worker                    super().__init__()
15097*da0073e9SAndroid Build Coastguard Worker                    tensor = torch.zeros(1, requires_grad=False)
15098*da0073e9SAndroid Build Coastguard Worker                    self.some_state = nn.Buffer(torch.nn.Parameter(tensor))
15099*da0073e9SAndroid Build Coastguard Worker
15100*da0073e9SAndroid Build Coastguard Worker                @torch.jit.script_method
15101*da0073e9SAndroid Build Coastguard Worker                def forward(self, x):
15102*da0073e9SAndroid Build Coastguard Worker                    self.ignored_code(x)
15103*da0073e9SAndroid Build Coastguard Worker                    return x
15104*da0073e9SAndroid Build Coastguard Worker
15105*da0073e9SAndroid Build Coastguard Worker                @torch.jit.ignore(drop_on_export=True)
15106*da0073e9SAndroid Build Coastguard Worker                def ignored_code(self, x):
15107*da0073e9SAndroid Build Coastguard Worker                    self.some_state = torch.tensor((100,))
15108*da0073e9SAndroid Build Coastguard Worker
15109*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("TorchScript will now drop the function").run(str(warns[0]))
15110*da0073e9SAndroid Build Coastguard Worker
15111*da0073e9SAndroid Build Coastguard Worker        # Assert ignored code is run
15112*da0073e9SAndroid Build Coastguard Worker        m = M()
15113*da0073e9SAndroid Build Coastguard Worker
15114*da0073e9SAndroid Build Coastguard Worker        m2 = self.getExportImportCopy(m)
15115*da0073e9SAndroid Build Coastguard Worker        pp = str(m2.forward.code)
15116*da0073e9SAndroid Build Coastguard Worker        self.assertNotIn('ignored_code', pp)
15117*da0073e9SAndroid Build Coastguard Worker
15118*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(torch.jit.Error, "annotated to be ignored and cannot be run"):
15119*da0073e9SAndroid Build Coastguard Worker            m2.forward(torch.ones(1))
15120*da0073e9SAndroid Build Coastguard Worker
15121*da0073e9SAndroid Build Coastguard Worker    def test_ignored_as_value(self):
15122*da0073e9SAndroid Build Coastguard Worker        class Model(nn.Module):
15123*da0073e9SAndroid Build Coastguard Worker            @torch.jit.unused
15124*da0073e9SAndroid Build Coastguard Worker            def tuple_ignored(self, x):
15125*da0073e9SAndroid Build Coastguard Worker                # type: (Tensor) -> Tuple[Tensor, Tensor]
15126*da0073e9SAndroid Build Coastguard Worker                return x, x
15127*da0073e9SAndroid Build Coastguard Worker
15128*da0073e9SAndroid Build Coastguard Worker            @torch.jit.unused
15129*da0073e9SAndroid Build Coastguard Worker            def single_val_ignored(self, x, y):
15130*da0073e9SAndroid Build Coastguard Worker                # type: (Tensor, Tensor) -> Tensor
15131*da0073e9SAndroid Build Coastguard Worker                return x
15132*da0073e9SAndroid Build Coastguard Worker
15133*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, use_ignore_path):
15134*da0073e9SAndroid Build Coastguard Worker                # type: (Tensor, bool) -> Tuple[Tensor, Tensor]
15135*da0073e9SAndroid Build Coastguard Worker                if 1 == 2:
15136*da0073e9SAndroid Build Coastguard Worker                    return self.tuple_ignored(x)
15137*da0073e9SAndroid Build Coastguard Worker                if use_ignore_path:
15138*da0073e9SAndroid Build Coastguard Worker                    return self.single_val_ignored(x, x), self.single_val_ignored(x, x)
15139*da0073e9SAndroid Build Coastguard Worker                return x, x
15140*da0073e9SAndroid Build Coastguard Worker
15141*da0073e9SAndroid Build Coastguard Worker        original = Model()
15142*da0073e9SAndroid Build Coastguard Worker        scripted = torch.jit.script(original)
15143*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted(torch.tensor(.5), False), (torch.tensor(.5), torch.tensor(.5)))
15144*da0073e9SAndroid Build Coastguard Worker
15145*da0073e9SAndroid Build Coastguard Worker        buffer = io.BytesIO()
15146*da0073e9SAndroid Build Coastguard Worker        torch.jit.save(scripted, buffer)
15147*da0073e9SAndroid Build Coastguard Worker        buffer.seek(0)
15148*da0073e9SAndroid Build Coastguard Worker        loaded = torch.jit.load(buffer)
15149*da0073e9SAndroid Build Coastguard Worker
15150*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(torch.jit.Error, "annotated to be ignored and cannot be run"):
15151*da0073e9SAndroid Build Coastguard Worker            loaded(torch.tensor(.5), True)
15152*da0073e9SAndroid Build Coastguard Worker
15153*da0073e9SAndroid Build Coastguard Worker    def test_module_error(self):
15154*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
15155*da0073e9SAndroid Build Coastguard Worker            def forward(self, foo):
15156*da0073e9SAndroid Build Coastguard Worker                return foo
15157*da0073e9SAndroid Build Coastguard Worker
15158*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "cannot be compiled since it inherits from nn.Module"):
15159*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(MyModule)
15160*da0073e9SAndroid Build Coastguard Worker
15161*da0073e9SAndroid Build Coastguard Worker    def test_view_write(self):
15162*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
15163*da0073e9SAndroid Build Coastguard Worker            l = []
15164*da0073e9SAndroid Build Coastguard Worker            l.append(x)
15165*da0073e9SAndroid Build Coastguard Worker            x_view = l[0]
15166*da0073e9SAndroid Build Coastguard Worker            a = x + x
15167*da0073e9SAndroid Build Coastguard Worker            x_view.add_(y)
15168*da0073e9SAndroid Build Coastguard Worker            b = x + x
15169*da0073e9SAndroid Build Coastguard Worker            return a == b
15170*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (torch.rand(2, 3), torch.rand(2, 3)))
15171*da0073e9SAndroid Build Coastguard Worker
15172*da0073e9SAndroid Build Coastguard Worker    def test_module_attrs(self):
15173*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
15174*da0073e9SAndroid Build Coastguard Worker            def __init__(self, table):
15175*da0073e9SAndroid Build Coastguard Worker                super().__init__()
15176*da0073e9SAndroid Build Coastguard Worker                self.table = torch.jit.Attribute(table, Dict[str, torch.Tensor])
15177*da0073e9SAndroid Build Coastguard Worker                self.x = torch.nn.Parameter(torch.tensor([100.0]))
15178*da0073e9SAndroid Build Coastguard Worker
15179*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
15180*da0073e9SAndroid Build Coastguard Worker            def forward(self, key):
15181*da0073e9SAndroid Build Coastguard Worker                # type: (str) -> Tensor
15182*da0073e9SAndroid Build Coastguard Worker                return self.table[key] + self.x
15183*da0073e9SAndroid Build Coastguard Worker
15184*da0073e9SAndroid Build Coastguard Worker        with torch._jit_internal._disable_emit_hooks():
15185*da0073e9SAndroid Build Coastguard Worker            # TODO: re-enable module hook when Python printing of attributes is
15186*da0073e9SAndroid Build Coastguard Worker            # supported
15187*da0073e9SAndroid Build Coastguard Worker            m = M({char : torch.ones(1) + ord(char) - ord("a") for char in "abcdefg"})
15188*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(m("c"), torch.tensor([103.]))
15189*da0073e9SAndroid Build Coastguard Worker
15190*da0073e9SAndroid Build Coastguard Worker    def test_module_none_attrs(self):
15191*da0073e9SAndroid Build Coastguard Worker        class MyMod(torch.jit.ScriptModule):
15192*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
15193*da0073e9SAndroid Build Coastguard Worker                super().__init__()
15194*da0073e9SAndroid Build Coastguard Worker                self.optional_value = None
15195*da0073e9SAndroid Build Coastguard Worker
15196*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
15197*da0073e9SAndroid Build Coastguard Worker            def forward(self):
15198*da0073e9SAndroid Build Coastguard Worker                return self.optional_value
15199*da0073e9SAndroid Build Coastguard Worker
15200*da0073e9SAndroid Build Coastguard Worker        graph = MyMod().forward.graph
15201*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("prim::GetAttr").run(graph)
15202*da0073e9SAndroid Build Coastguard Worker        self.run_pass('peephole', graph)
15203*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("prim::GetAttr").run(graph)
15204*da0073e9SAndroid Build Coastguard Worker
15205*da0073e9SAndroid Build Coastguard Worker    def test_tensor_import_export(self):
15206*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
15207*da0073e9SAndroid Build Coastguard Worker        def foo(x):
15208*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor(1)
15209*da0073e9SAndroid Build Coastguard Worker            b = torch.tensor([1, 2])
15210*da0073e9SAndroid Build Coastguard Worker            c = [a, b]
15211*da0073e9SAndroid Build Coastguard Worker            return c
15212*da0073e9SAndroid Build Coastguard Worker
15213*da0073e9SAndroid Build Coastguard Worker        self.run_pass('constant_propagation', foo.graph)
15214*da0073e9SAndroid Build Coastguard Worker        m = self.createFunctionFromGraph(foo.graph)
15215*da0073e9SAndroid Build Coastguard Worker        self.getExportImportCopy(m)
15216*da0073e9SAndroid Build Coastguard Worker
15217*da0073e9SAndroid Build Coastguard Worker    def get_pickle_values(self):
15218*da0073e9SAndroid Build Coastguard Worker        return (('dict', {"I": "am", "a test": "test"}, Dict[str, str]),
15219*da0073e9SAndroid Build Coastguard Worker                ('float', 2.3, float),
15220*da0073e9SAndroid Build Coastguard Worker                ('int', 99, int),
15221*da0073e9SAndroid Build Coastguard Worker                ('bool', False, bool),
15222*da0073e9SAndroid Build Coastguard Worker                ('tuple', (1, 2, 3, 4), Tuple[int, int, int, int]),
15223*da0073e9SAndroid Build Coastguard Worker                ('list', [(1, 2), (3, 4)], List[Tuple[int, int]]),
15224*da0073e9SAndroid Build Coastguard Worker                ('tensor', torch.randn(2, 2), torch.Tensor),
15225*da0073e9SAndroid Build Coastguard Worker                ('int_list', [1, 2, 3, 4], List[int]),
15226*da0073e9SAndroid Build Coastguard Worker                ('tensor_list', [torch.ones(2, 2) + i for i in range(4)], List[torch.Tensor]),
15227*da0073e9SAndroid Build Coastguard Worker                ('bool_list', [True, True, False, True], List[bool]),
15228*da0073e9SAndroid Build Coastguard Worker                ('float_list', [1., 2., 3., 4.], List[float]),
15229*da0073e9SAndroid Build Coastguard Worker                ('str_list', ['hello', 'bye'], List[str]),
15230*da0073e9SAndroid Build Coastguard Worker                ('none', None, Optional[int]),
15231*da0073e9SAndroid Build Coastguard Worker                ('a_device', torch.device('cpu'), torch.device),
15232*da0073e9SAndroid Build Coastguard Worker                ('another_device', torch.device('cuda:1'), torch.device))
15233*da0073e9SAndroid Build Coastguard Worker
15234*da0073e9SAndroid Build Coastguard Worker    def test_attribute_serialization(self):
15235*da0073e9SAndroid Build Coastguard Worker        tester = self
15236*da0073e9SAndroid Build Coastguard Worker
15237*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
15238*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
15239*da0073e9SAndroid Build Coastguard Worker                super().__init__()
15240*da0073e9SAndroid Build Coastguard Worker                for name, value, the_type in tester.get_pickle_values():
15241*da0073e9SAndroid Build Coastguard Worker                    setattr(self, name, torch.jit.Attribute(value, the_type))
15242*da0073e9SAndroid Build Coastguard Worker
15243*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
15244*da0073e9SAndroid Build Coastguard Worker            def forward(self):
15245*da0073e9SAndroid Build Coastguard Worker                return (self.dict, self.float, self.int, self.bool, self.tuple,
15246*da0073e9SAndroid Build Coastguard Worker                        self.list, self.int_list, self.tensor_list, self.bool_list,
15247*da0073e9SAndroid Build Coastguard Worker                        self.float_list, self.str_list, self.none)
15248*da0073e9SAndroid Build Coastguard Worker
15249*da0073e9SAndroid Build Coastguard Worker        m = M()
15250*da0073e9SAndroid Build Coastguard Worker        imported_m = self.getExportImportCopy(m)
15251*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m(), imported_m())
15252*da0073e9SAndroid Build Coastguard Worker
15253*da0073e9SAndroid Build Coastguard Worker    def test_string_len(self):
15254*da0073e9SAndroid Build Coastguard Worker        def fn(x):
15255*da0073e9SAndroid Build Coastguard Worker            # type: (str) -> int
15256*da0073e9SAndroid Build Coastguard Worker            return len(x)
15257*da0073e9SAndroid Build Coastguard Worker
15258*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ("",))
15259*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ("h",))
15260*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ("hello",))
15261*da0073e9SAndroid Build Coastguard Worker
15262*da0073e9SAndroid Build Coastguard Worker    def test_multiline_optional_future_refinement(self):
15263*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
15264*da0073e9SAndroid Build Coastguard Worker        def fun() -> int:
15265*da0073e9SAndroid Build Coastguard Worker            future: Optional[
15266*da0073e9SAndroid Build Coastguard Worker                torch.jit.Future[Tuple[torch.Tensor]]
15267*da0073e9SAndroid Build Coastguard Worker            ] = None
15268*da0073e9SAndroid Build Coastguard Worker
15269*da0073e9SAndroid Build Coastguard Worker            return 1
15270*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fun(), 1)
15271*da0073e9SAndroid Build Coastguard Worker
15272*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle")
15273*da0073e9SAndroid Build Coastguard Worker    def test_attribute_unpickling(self):
15274*da0073e9SAndroid Build Coastguard Worker        tensor = torch.randn(2, 2)
15275*da0073e9SAndroid Build Coastguard Worker        tester = self
15276*da0073e9SAndroid Build Coastguard Worker
15277*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
15278*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
15279*da0073e9SAndroid Build Coastguard Worker                super().__init__()
15280*da0073e9SAndroid Build Coastguard Worker                for name, value, the_type in tester.get_pickle_values():
15281*da0073e9SAndroid Build Coastguard Worker                    setattr(self, "_" + name, torch.jit.Attribute(value, the_type))
15282*da0073e9SAndroid Build Coastguard Worker
15283*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
15284*da0073e9SAndroid Build Coastguard Worker            def forward(self):
15285*da0073e9SAndroid Build Coastguard Worker                return (self._dict, self._float, self._int, self._bool, self._tuple,
15286*da0073e9SAndroid Build Coastguard Worker                        self._list, self._int_list, self._tensor_list, self._bool_list,
15287*da0073e9SAndroid Build Coastguard Worker                        self._float_list, self._str_list, self._none)
15288*da0073e9SAndroid Build Coastguard Worker
15289*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName() as fname:
15290*da0073e9SAndroid Build Coastguard Worker            M().save(fname)
15291*da0073e9SAndroid Build Coastguard Worker            loaded = torch.jit.load(fname)
15292*da0073e9SAndroid Build Coastguard Worker
15293*da0073e9SAndroid Build Coastguard Worker            def is_tensor_value(item):
15294*da0073e9SAndroid Build Coastguard Worker                if isinstance(item, torch.Tensor):
15295*da0073e9SAndroid Build Coastguard Worker                    return True
15296*da0073e9SAndroid Build Coastguard Worker                if isinstance(item, list):
15297*da0073e9SAndroid Build Coastguard Worker                    return is_tensor_value(item[0])
15298*da0073e9SAndroid Build Coastguard Worker                return False
15299*da0073e9SAndroid Build Coastguard Worker            for name, value, the_type in self.get_pickle_values():
15300*da0073e9SAndroid Build Coastguard Worker                if is_tensor_value(value):
15301*da0073e9SAndroid Build Coastguard Worker                    continue
15302*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(value, getattr(loaded, "_" + name))
15303*da0073e9SAndroid Build Coastguard Worker
15304*da0073e9SAndroid Build Coastguard Worker
15305*da0073e9SAndroid Build Coastguard Worker    def test_submodule_attribute_serialization(self):
15306*da0073e9SAndroid Build Coastguard Worker        class S(torch.jit.ScriptModule):
15307*da0073e9SAndroid Build Coastguard Worker            def __init__(self, list_data):
15308*da0073e9SAndroid Build Coastguard Worker                super().__init__()
15309*da0073e9SAndroid Build Coastguard Worker                self.table = torch.jit.Attribute({"I": "am", "a test": "test"}, Dict[str, str])
15310*da0073e9SAndroid Build Coastguard Worker                self.list = torch.jit.Attribute(list_data, List[Tuple[int, int]])
15311*da0073e9SAndroid Build Coastguard Worker
15312*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
15313*da0073e9SAndroid Build Coastguard Worker            def forward(self):
15314*da0073e9SAndroid Build Coastguard Worker                return (self.table, self.list)
15315*da0073e9SAndroid Build Coastguard Worker
15316*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
15317*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
15318*da0073e9SAndroid Build Coastguard Worker                super().__init__()
15319*da0073e9SAndroid Build Coastguard Worker                self.table = torch.jit.Attribute({"this": "is", "a different": "dict"}, Dict[str, str])
15320*da0073e9SAndroid Build Coastguard Worker                self.tensor = torch.jit.Attribute(torch.randn(2, 2), torch.Tensor)
15321*da0073e9SAndroid Build Coastguard Worker                self.s1 = S([(1, 2)])
15322*da0073e9SAndroid Build Coastguard Worker                self.s2 = S([(4, 5)])
15323*da0073e9SAndroid Build Coastguard Worker
15324*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
15325*da0073e9SAndroid Build Coastguard Worker            def forward(self):
15326*da0073e9SAndroid Build Coastguard Worker                return (self.table, self.tensor, self.s1.table, self.s2.list, self.s1.list)
15327*da0073e9SAndroid Build Coastguard Worker
15328*da0073e9SAndroid Build Coastguard Worker        m = M()
15329*da0073e9SAndroid Build Coastguard Worker        imported_m = self.getExportImportCopy(m)
15330*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m(), imported_m())
15331*da0073e9SAndroid Build Coastguard Worker
15332*da0073e9SAndroid Build Coastguard Worker    def test_serialization_big_ints(self):
15333*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
15334*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
15335*da0073e9SAndroid Build Coastguard Worker                super().__init__()
15336*da0073e9SAndroid Build Coastguard Worker                self.int32_max = torch.jit.Attribute(2**31 - 1, int)
15337*da0073e9SAndroid Build Coastguard Worker                self.int32_min = torch.jit.Attribute(-2**31, int)
15338*da0073e9SAndroid Build Coastguard Worker                self.uint32_max = torch.jit.Attribute(2**32, int)
15339*da0073e9SAndroid Build Coastguard Worker
15340*da0073e9SAndroid Build Coastguard Worker                self.int64_max = torch.jit.Attribute(2**63 - 1, int)
15341*da0073e9SAndroid Build Coastguard Worker                self.int64_min = torch.jit.Attribute(-2**63, int)
15342*da0073e9SAndroid Build Coastguard Worker
15343*da0073e9SAndroid Build Coastguard Worker                self.tensor = torch.nn.Parameter(torch.ones(2, 2))
15344*da0073e9SAndroid Build Coastguard Worker
15345*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
15346*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
15347*da0073e9SAndroid Build Coastguard Worker                # type: (int) -> (int)
15348*da0073e9SAndroid Build Coastguard Worker                return x + (self.int32_max + self.int32_min) + (self.int64_max + self.int64_min)
15349*da0073e9SAndroid Build Coastguard Worker
15350*da0073e9SAndroid Build Coastguard Worker        m = M()
15351*da0073e9SAndroid Build Coastguard Worker        imported = self.getExportImportCopy(m)
15352*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m(10), imported(10))
15353*da0073e9SAndroid Build Coastguard Worker
15354*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m.int32_max, imported.int32_max)
15355*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m.int32_min, imported.int32_min)
15356*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m.uint32_max, imported.uint32_max)
15357*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m.int64_max, imported.int64_max)
15358*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m.int64_min, imported.int64_min)
15359*da0073e9SAndroid Build Coastguard Worker
15360*da0073e9SAndroid Build Coastguard Worker    def test_script_scope(self):
15361*da0073e9SAndroid Build Coastguard Worker        scripted = torch.jit.script(torch.nn.functional.triplet_margin_loss)
15362*da0073e9SAndroid Build Coastguard Worker
15363*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, "NYI: TemporaryFileName on Windows")
15364*da0073e9SAndroid Build Coastguard Worker    def test_serialization_sharing(self):
15365*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
15366*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
15367*da0073e9SAndroid Build Coastguard Worker                super().__init__()
15368*da0073e9SAndroid Build Coastguard Worker                self.list = torch.jit.Attribute([], List[str])
15369*da0073e9SAndroid Build Coastguard Worker
15370*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
15371*da0073e9SAndroid Build Coastguard Worker            def forward(self, key):
15372*da0073e9SAndroid Build Coastguard Worker                # type: (str) -> List[str]
15373*da0073e9SAndroid Build Coastguard Worker                self.list.append(key)
15374*da0073e9SAndroid Build Coastguard Worker                self.list.append(key)
15375*da0073e9SAndroid Build Coastguard Worker                self.list.append(key)
15376*da0073e9SAndroid Build Coastguard Worker                return self.list
15377*da0073e9SAndroid Build Coastguard Worker
15378*da0073e9SAndroid Build Coastguard Worker        # the text of the string should only appear once in the pickling
15379*da0073e9SAndroid Build Coastguard Worker        m = M()
15380*da0073e9SAndroid Build Coastguard Worker        s1 = "a long string"
15381*da0073e9SAndroid Build Coastguard Worker        s2 = "a different, even longer string"
15382*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m(s1), [s1] * 3)
15383*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m(s2), [s1] * 3 + [s2] * 3)
15384*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName() as fname:
15385*da0073e9SAndroid Build Coastguard Worker            m.save(fname)
15386*da0073e9SAndroid Build Coastguard Worker            archive_name = os.path.basename(os.path.normpath(fname))
15387*da0073e9SAndroid Build Coastguard Worker            archive = zipfile.ZipFile(fname, 'r')
15388*da0073e9SAndroid Build Coastguard Worker            pickled_data = archive.read(os.path.join(archive_name, 'data.pkl'))
15389*da0073e9SAndroid Build Coastguard Worker
15390*da0073e9SAndroid Build Coastguard Worker            out = io.StringIO()
15391*da0073e9SAndroid Build Coastguard Worker            pickletools.dis(pickled_data, out=out)
15392*da0073e9SAndroid Build Coastguard Worker            disassembled = out.getvalue()
15393*da0073e9SAndroid Build Coastguard Worker
15394*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_count(s1, 1, exactly=True) \
15395*da0073e9SAndroid Build Coastguard Worker                .check_count("BINGET", 2, exactly=True) \
15396*da0073e9SAndroid Build Coastguard Worker                .check_count(s2, 1, exactly=True) \
15397*da0073e9SAndroid Build Coastguard Worker                .check_count("BINGET", 2, exactly=True).run(out.getvalue())
15398*da0073e9SAndroid Build Coastguard Worker
15399*da0073e9SAndroid Build Coastguard Worker    def test_sys_stdout_override(self):
15400*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
15401*da0073e9SAndroid Build Coastguard Worker        def foo():
15402*da0073e9SAndroid Build Coastguard Worker            print('foo')
15403*da0073e9SAndroid Build Coastguard Worker
15404*da0073e9SAndroid Build Coastguard Worker        class Redirect:
15405*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
15406*da0073e9SAndroid Build Coastguard Worker                self.s = ''
15407*da0073e9SAndroid Build Coastguard Worker
15408*da0073e9SAndroid Build Coastguard Worker            def write(self, s):
15409*da0073e9SAndroid Build Coastguard Worker                self.s += s
15410*da0073e9SAndroid Build Coastguard Worker
15411*da0073e9SAndroid Build Coastguard Worker        old_stdout = sys.stdout
15412*da0073e9SAndroid Build Coastguard Worker        redirect = Redirect()
15413*da0073e9SAndroid Build Coastguard Worker        try:
15414*da0073e9SAndroid Build Coastguard Worker            sys.stdout = redirect
15415*da0073e9SAndroid Build Coastguard Worker            foo()
15416*da0073e9SAndroid Build Coastguard Worker        finally:
15417*da0073e9SAndroid Build Coastguard Worker            sys.stdout = old_stdout
15418*da0073e9SAndroid Build Coastguard Worker
15419*da0073e9SAndroid Build Coastguard Worker        FileCheck().check('foo').run(redirect.s)
15420*da0073e9SAndroid Build Coastguard Worker
15421*da0073e9SAndroid Build Coastguard Worker    def test_dtype_attr(self):
15422*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
15423*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
15424*da0073e9SAndroid Build Coastguard Worker                super().__init__()
15425*da0073e9SAndroid Build Coastguard Worker                self.dtype = torch.zeros([]).dtype
15426*da0073e9SAndroid Build Coastguard Worker
15427*da0073e9SAndroid Build Coastguard Worker            def forward(self):
15428*da0073e9SAndroid Build Coastguard Worker                return torch.zeros(3, 4, dtype=self.dtype)
15429*da0073e9SAndroid Build Coastguard Worker
15430*da0073e9SAndroid Build Coastguard Worker        f = Foo()
15431*da0073e9SAndroid Build Coastguard Worker        torch.jit.script(f)
15432*da0073e9SAndroid Build Coastguard Worker
15433*da0073e9SAndroid Build Coastguard Worker
15434*da0073e9SAndroid Build Coastguard Worker    def test_named_buffers_are_iterable(self):
15435*da0073e9SAndroid Build Coastguard Worker        class MyMod(torch.nn.Module):
15436*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
15437*da0073e9SAndroid Build Coastguard Worker                super().__init__()
15438*da0073e9SAndroid Build Coastguard Worker                self.mod = (torch.nn.ReLU())
15439*da0073e9SAndroid Build Coastguard Worker                self.mod2 = (torch.nn.ReLU())
15440*da0073e9SAndroid Build Coastguard Worker                self.mod3 = torch.nn.Sequential(torch.nn.Sequential(torch.nn.ReLU()))
15441*da0073e9SAndroid Build Coastguard Worker                self.x = nn.Buffer(torch.zeros(3))
15442*da0073e9SAndroid Build Coastguard Worker                self.y = nn.Buffer(torch.zeros(3))
15443*da0073e9SAndroid Build Coastguard Worker                self.z = torch.zeros(3)
15444*da0073e9SAndroid Build Coastguard Worker
15445*da0073e9SAndroid Build Coastguard Worker            def bleh(self):
15446*da0073e9SAndroid Build Coastguard Worker                return self.z + 4
15447*da0073e9SAndroid Build Coastguard Worker
15448*da0073e9SAndroid Build Coastguard Worker            @torch.jit.export
15449*da0073e9SAndroid Build Coastguard Worker            def method(self):
15450*da0073e9SAndroid Build Coastguard Worker                names = [""]
15451*da0073e9SAndroid Build Coastguard Worker                vals = []
15452*da0073e9SAndroid Build Coastguard Worker                for name, buffer in self.named_buffers():
15453*da0073e9SAndroid Build Coastguard Worker                    names.append(name)
15454*da0073e9SAndroid Build Coastguard Worker                    vals.append(buffer + 2)
15455*da0073e9SAndroid Build Coastguard Worker
15456*da0073e9SAndroid Build Coastguard Worker                return names, vals
15457*da0073e9SAndroid Build Coastguard Worker
15458*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
15459*da0073e9SAndroid Build Coastguard Worker                return x
15460*da0073e9SAndroid Build Coastguard Worker
15461*da0073e9SAndroid Build Coastguard Worker        model = MyMod()
15462*da0073e9SAndroid Build Coastguard Worker        x = torch.jit.script(model)
15463*da0073e9SAndroid Build Coastguard Worker        z = self.getExportImportCopy(x)
15464*da0073e9SAndroid Build Coastguard Worker
15465*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z.method(), x.method())
15466*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z.method(), model.method())
15467*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.method(), model.method())
15468*da0073e9SAndroid Build Coastguard Worker        names = x.method()
15469*da0073e9SAndroid Build Coastguard Worker        for name in names:
15470*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual('z', name)
15471*da0073e9SAndroid Build Coastguard Worker
15472*da0073e9SAndroid Build Coastguard Worker
15473*da0073e9SAndroid Build Coastguard Worker    def test_static_if_prop(self):
15474*da0073e9SAndroid Build Coastguard Worker        class MaybeHasAttr(torch.nn.Module):
15475*da0073e9SAndroid Build Coastguard Worker            def __init__(self, add_attr):
15476*da0073e9SAndroid Build Coastguard Worker                super().__init__()
15477*da0073e9SAndroid Build Coastguard Worker                if add_attr:
15478*da0073e9SAndroid Build Coastguard Worker                    self.maybe_attr = 1
15479*da0073e9SAndroid Build Coastguard Worker
15480*da0073e9SAndroid Build Coastguard Worker            def forward(self):
15481*da0073e9SAndroid Build Coastguard Worker                if hasattr(self, "maybe_attr") and True:
15482*da0073e9SAndroid Build Coastguard Worker                    return self.maybe_attr
15483*da0073e9SAndroid Build Coastguard Worker                else:
15484*da0073e9SAndroid Build Coastguard Worker                    return 0
15485*da0073e9SAndroid Build Coastguard Worker
15486*da0073e9SAndroid Build Coastguard Worker        class MaybeHasAttr2(torch.nn.Module):
15487*da0073e9SAndroid Build Coastguard Worker            def __init__(self, add_attr):
15488*da0073e9SAndroid Build Coastguard Worker                super().__init__()
15489*da0073e9SAndroid Build Coastguard Worker                if add_attr:
15490*da0073e9SAndroid Build Coastguard Worker                    self.maybe_attr = 1
15491*da0073e9SAndroid Build Coastguard Worker
15492*da0073e9SAndroid Build Coastguard Worker            def forward(self):
15493*da0073e9SAndroid Build Coastguard Worker                if not hasattr(self, "maybe_attr") or False:
15494*da0073e9SAndroid Build Coastguard Worker                    return 0
15495*da0073e9SAndroid Build Coastguard Worker                else:
15496*da0073e9SAndroid Build Coastguard Worker                    return self.maybe_attr
15497*da0073e9SAndroid Build Coastguard Worker
15498*da0073e9SAndroid Build Coastguard Worker        torch.jit.script(MaybeHasAttr(True))
15499*da0073e9SAndroid Build Coastguard Worker        torch.jit.script(MaybeHasAttr(False))
15500*da0073e9SAndroid Build Coastguard Worker        torch.jit.script(MaybeHasAttr2(True))
15501*da0073e9SAndroid Build Coastguard Worker        torch.jit.script(MaybeHasAttr2(False))
15502*da0073e9SAndroid Build Coastguard Worker
15503*da0073e9SAndroid Build Coastguard Worker        class MyMod(torch.nn.Module):
15504*da0073e9SAndroid Build Coastguard Worker            def forward(self):
15505*da0073e9SAndroid Build Coastguard Worker                if hasattr(self, "foo"):
15506*da0073e9SAndroid Build Coastguard Worker                    return 1
15507*da0073e9SAndroid Build Coastguard Worker                else:
15508*da0073e9SAndroid Build Coastguard Worker                    return 0
15509*da0073e9SAndroid Build Coastguard Worker
15510*da0073e9SAndroid Build Coastguard Worker            @torch.jit.export
15511*da0073e9SAndroid Build Coastguard Worker            def fee(self):
15512*da0073e9SAndroid Build Coastguard Worker                return 1
15513*da0073e9SAndroid Build Coastguard Worker
15514*da0073e9SAndroid Build Coastguard Worker        self.checkModule(MyMod(), ())
15515*da0073e9SAndroid Build Coastguard Worker
15516*da0073e9SAndroid Build Coastguard Worker        class HasAttrMod(torch.nn.Module):
15517*da0073e9SAndroid Build Coastguard Worker            __constants__ = ["fee"]
15518*da0073e9SAndroid Build Coastguard Worker
15519*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
15520*da0073e9SAndroid Build Coastguard Worker                super().__init__()
15521*da0073e9SAndroid Build Coastguard Worker                self.fee = 3
15522*da0073e9SAndroid Build Coastguard Worker
15523*da0073e9SAndroid Build Coastguard Worker            def forward(self):
15524*da0073e9SAndroid Build Coastguard Worker                a = hasattr(self, "fee")
15525*da0073e9SAndroid Build Coastguard Worker                b = hasattr(self, "foo")
15526*da0073e9SAndroid Build Coastguard Worker                c = hasattr(self, "hi")
15527*da0073e9SAndroid Build Coastguard Worker                d = hasattr(self, "nonexistant")
15528*da0073e9SAndroid Build Coastguard Worker                return (a, b, c, d)
15529*da0073e9SAndroid Build Coastguard Worker
15530*da0073e9SAndroid Build Coastguard Worker            def foo(self):
15531*da0073e9SAndroid Build Coastguard Worker                return 1
15532*da0073e9SAndroid Build Coastguard Worker
15533*da0073e9SAndroid Build Coastguard Worker            @torch.jit._overload_method
15534*da0073e9SAndroid Build Coastguard Worker            def hi(self, x: Tensor): ...  # noqa: E704
15535*da0073e9SAndroid Build Coastguard Worker
15536*da0073e9SAndroid Build Coastguard Worker            def hi(self, x):  # noqa: F811
15537*da0073e9SAndroid Build Coastguard Worker                return 2
15538*da0073e9SAndroid Build Coastguard Worker
15539*da0073e9SAndroid Build Coastguard Worker        self.checkModule(HasAttrMod(), ())
15540*da0073e9SAndroid Build Coastguard Worker
15541*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
15542*da0073e9SAndroid Build Coastguard Worker        class FooTest:
15543*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
15544*da0073e9SAndroid Build Coastguard Worker                self.x = 1
15545*da0073e9SAndroid Build Coastguard Worker
15546*da0073e9SAndroid Build Coastguard Worker            def foo(self, y):
15547*da0073e9SAndroid Build Coastguard Worker                return self.x + y
15548*da0073e9SAndroid Build Coastguard Worker
15549*da0073e9SAndroid Build Coastguard Worker        def foo():
15550*da0073e9SAndroid Build Coastguard Worker            a = FooTest()
15551*da0073e9SAndroid Build Coastguard Worker            val1 = hasattr(a, "foo"), hasattr(a, "x"), hasattr(a, "bla")
15552*da0073e9SAndroid Build Coastguard Worker            val2 = hasattr(FooTest, "foo"), hasattr(FooTest, "a")
15553*da0073e9SAndroid Build Coastguard Worker            return val1, val2
15554*da0073e9SAndroid Build Coastguard Worker
15555*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(), torch.jit.script(foo)())
15556*da0073e9SAndroid Build Coastguard Worker
15557*da0073e9SAndroid Build Coastguard Worker    def _test_pickle_checkpoint(self, device):
15558*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName() as fname:
15559*da0073e9SAndroid Build Coastguard Worker            class M(torch.jit.ScriptModule):
15560*da0073e9SAndroid Build Coastguard Worker                __constants__ = ['fname']
15561*da0073e9SAndroid Build Coastguard Worker
15562*da0073e9SAndroid Build Coastguard Worker                def __init__(self, tensor):
15563*da0073e9SAndroid Build Coastguard Worker                    super().__init__()
15564*da0073e9SAndroid Build Coastguard Worker                    self.fname = fname
15565*da0073e9SAndroid Build Coastguard Worker                    self.tensor = torch.nn.Parameter(tensor)
15566*da0073e9SAndroid Build Coastguard Worker
15567*da0073e9SAndroid Build Coastguard Worker                @torch.jit.script_method
15568*da0073e9SAndroid Build Coastguard Worker                def forward(self, x):
15569*da0073e9SAndroid Build Coastguard Worker                    y = self.tensor + x
15570*da0073e9SAndroid Build Coastguard Worker                    torch.save(y, self.fname)
15571*da0073e9SAndroid Build Coastguard Worker                    return y
15572*da0073e9SAndroid Build Coastguard Worker
15573*da0073e9SAndroid Build Coastguard Worker            param = torch.randn(2, 2).to(device)
15574*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(2, 2).to(device)
15575*da0073e9SAndroid Build Coastguard Worker            m = M(param)
15576*da0073e9SAndroid Build Coastguard Worker            m(input)
15577*da0073e9SAndroid Build Coastguard Worker            with open(fname, "rb") as handle:
15578*da0073e9SAndroid Build Coastguard Worker                loaded_tensor = torch.load(fname)
15579*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(loaded_tensor, input + param)
15580*da0073e9SAndroid Build Coastguard Worker
15581*da0073e9SAndroid Build Coastguard Worker    def _test_pickle_checkpoint_views(self, device):
15582*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName() as fname:
15583*da0073e9SAndroid Build Coastguard Worker            class M(torch.jit.ScriptModule):
15584*da0073e9SAndroid Build Coastguard Worker                __constants__ = ['fname']
15585*da0073e9SAndroid Build Coastguard Worker
15586*da0073e9SAndroid Build Coastguard Worker                def __init__(self, tensor):
15587*da0073e9SAndroid Build Coastguard Worker                    super().__init__()
15588*da0073e9SAndroid Build Coastguard Worker                    self.fname = fname
15589*da0073e9SAndroid Build Coastguard Worker                    self.tensor = torch.nn.Parameter(tensor)
15590*da0073e9SAndroid Build Coastguard Worker
15591*da0073e9SAndroid Build Coastguard Worker                @torch.jit.script_method
15592*da0073e9SAndroid Build Coastguard Worker                def forward(self, x):
15593*da0073e9SAndroid Build Coastguard Worker                    y = self.tensor + x
15594*da0073e9SAndroid Build Coastguard Worker                    y_view = y.view(4)
15595*da0073e9SAndroid Build Coastguard Worker                    torch.save((y, y_view, y), self.fname)
15596*da0073e9SAndroid Build Coastguard Worker                    return y
15597*da0073e9SAndroid Build Coastguard Worker
15598*da0073e9SAndroid Build Coastguard Worker            param = torch.randn(2, 2).to(device)
15599*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(2, 2).to(device)
15600*da0073e9SAndroid Build Coastguard Worker            m = M(param)
15601*da0073e9SAndroid Build Coastguard Worker            m(input)
15602*da0073e9SAndroid Build Coastguard Worker            with open(fname, "rb") as handle:
15603*da0073e9SAndroid Build Coastguard Worker                loaded_y, loaded_y_view, loaded_y_2 = torch.load(fname)
15604*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(loaded_y, input + param)
15605*da0073e9SAndroid Build Coastguard Worker                with torch.no_grad():
15606*da0073e9SAndroid Build Coastguard Worker                    loaded_y_view[1] += 20
15607*da0073e9SAndroid Build Coastguard Worker                    # assert that loaded_y changed as well
15608*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(loaded_y.view(4), loaded_y_view)
15609*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(loaded_y_2.view(4), loaded_y_view)
15610*da0073e9SAndroid Build Coastguard Worker
15611*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "no CUDA")
15612*da0073e9SAndroid Build Coastguard Worker    def test_pickle_checkpoint_cuda(self):
15613*da0073e9SAndroid Build Coastguard Worker        self._test_pickle_checkpoint('cuda')
15614*da0073e9SAndroid Build Coastguard Worker        self._test_pickle_checkpoint_views('cuda')
15615*da0073e9SAndroid Build Coastguard Worker
15616*da0073e9SAndroid Build Coastguard Worker    def test_pickle_checkpoint(self):
15617*da0073e9SAndroid Build Coastguard Worker        self._test_pickle_checkpoint('cpu')
15618*da0073e9SAndroid Build Coastguard Worker        self._test_pickle_checkpoint_views('cpu')
15619*da0073e9SAndroid Build Coastguard Worker
15620*da0073e9SAndroid Build Coastguard Worker    def test_pickle_checkpoint_tup(self):
15621*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
15622*da0073e9SAndroid Build Coastguard Worker        def foo(fname):
15623*da0073e9SAndroid Build Coastguard Worker            # type: (str) -> None
15624*da0073e9SAndroid Build Coastguard Worker            torch.save((3, 4), fname)
15625*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName() as name:
15626*da0073e9SAndroid Build Coastguard Worker            foo(name)
15627*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.load(name), (3, 4))
15628*da0073e9SAndroid Build Coastguard Worker
15629*da0073e9SAndroid Build Coastguard Worker    def test_string_list(self):
15630*da0073e9SAndroid Build Coastguard Worker        def fn(string):
15631*da0073e9SAndroid Build Coastguard Worker            # type: (str) -> List[str]
15632*da0073e9SAndroid Build Coastguard Worker            return list(string)
15633*da0073e9SAndroid Build Coastguard Worker
15634*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ("abcdefgh",))
15635*da0073e9SAndroid Build Coastguard Worker
15636*da0073e9SAndroid Build Coastguard Worker    def test_unicode_comments(self):
15637*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
15638*da0073e9SAndroid Build Coastguard Worker        def test(self, a):
15639*da0073e9SAndroid Build Coastguard Worker            # ��������
15640*da0073e9SAndroid Build Coastguard Worker            return torch.nn.functional.relu(a)
15641*da0073e9SAndroid Build Coastguard Worker
15642*da0073e9SAndroid Build Coastguard Worker    def test_get_set_state_with_tensors(self):
15643*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
15644*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
15645*da0073e9SAndroid Build Coastguard Worker                super().__init__()
15646*da0073e9SAndroid Build Coastguard Worker                self.tensor = torch.randn(2, 2)
15647*da0073e9SAndroid Build Coastguard Worker
15648*da0073e9SAndroid Build Coastguard Worker            @torch.jit.export
15649*da0073e9SAndroid Build Coastguard Worker            def __getstate__(self):
15650*da0073e9SAndroid Build Coastguard Worker                return (self.tensor, self.training)
15651*da0073e9SAndroid Build Coastguard Worker
15652*da0073e9SAndroid Build Coastguard Worker            @torch.jit.export
15653*da0073e9SAndroid Build Coastguard Worker            def __setstate__(self, state):
15654*da0073e9SAndroid Build Coastguard Worker                self.tensor = state[0]
15655*da0073e9SAndroid Build Coastguard Worker                self.training = state[1]
15656*da0073e9SAndroid Build Coastguard Worker
15657*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
15658*da0073e9SAndroid Build Coastguard Worker                return x + self.tensor
15659*da0073e9SAndroid Build Coastguard Worker
15660*da0073e9SAndroid Build Coastguard Worker        with TemporaryFileName() as fname:
15661*da0073e9SAndroid Build Coastguard Worker            m = torch.jit.script(M())
15662*da0073e9SAndroid Build Coastguard Worker            m.save(fname)
15663*da0073e9SAndroid Build Coastguard Worker            loaded = torch.jit.load(fname)
15664*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(loaded.tensor, m.tensor)
15665*da0073e9SAndroid Build Coastguard Worker
15666*da0073e9SAndroid Build Coastguard Worker    def test_in_for_and_comp_expr(self):
15667*da0073e9SAndroid Build Coastguard Worker        def fn(d):
15668*da0073e9SAndroid Build Coastguard Worker            # type: (Dict[str, int]) -> List[int]
15669*da0073e9SAndroid Build Coastguard Worker            out = [1]
15670*da0073e9SAndroid Build Coastguard Worker            for i in range(d["hi"] if "hi" in d else 6):
15671*da0073e9SAndroid Build Coastguard Worker                out.append(i)  # noqa: PERF402
15672*da0073e9SAndroid Build Coastguard Worker            return out
15673*da0073e9SAndroid Build Coastguard Worker
15674*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ({'hi': 2, 'bye': 3},))
15675*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ({'bye': 3},))
15676*da0073e9SAndroid Build Coastguard Worker
15677*da0073e9SAndroid Build Coastguard Worker    def test_for_else(self):
15678*da0073e9SAndroid Build Coastguard Worker        def fn():
15679*da0073e9SAndroid Build Coastguard Worker            c = 0
15680*da0073e9SAndroid Build Coastguard Worker            for i in range(4):
15681*da0073e9SAndroid Build Coastguard Worker                c += 10
15682*da0073e9SAndroid Build Coastguard Worker            else:
15683*da0073e9SAndroid Build Coastguard Worker                print("In else block of for...else")
15684*da0073e9SAndroid Build Coastguard Worker
15685*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "else branches of for loops aren't supported"):
15686*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(fn)
15687*da0073e9SAndroid Build Coastguard Worker
15688*da0073e9SAndroid Build Coastguard Worker    def test_split(self):
15689*da0073e9SAndroid Build Coastguard Worker        def split_two(tensor):
15690*da0073e9SAndroid Build Coastguard Worker            a, b, c = torch.split(tensor, 2, dim=1)
15691*da0073e9SAndroid Build Coastguard Worker            return a, b, c
15692*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 6)
15693*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(3, 6)
15694*da0073e9SAndroid Build Coastguard Worker        self.checkScript(split_two, [(x + y)])
15695*da0073e9SAndroid Build Coastguard Worker
15696*da0073e9SAndroid Build Coastguard Worker    def test_conv_error(self):
15697*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
15698*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
15699*da0073e9SAndroid Build Coastguard Worker            return F.conv2d(x, y)
15700*da0073e9SAndroid Build Coastguard Worker
15701*da0073e9SAndroid Build Coastguard Worker        try:
15702*da0073e9SAndroid Build Coastguard Worker            fn(torch.ones(2, 2), torch.ones(4, 4))
15703*da0073e9SAndroid Build Coastguard Worker        except RuntimeError as e:
15704*da0073e9SAndroid Build Coastguard Worker            self.assertFalse('frame' in str(e))
15705*da0073e9SAndroid Build Coastguard Worker
15706*da0073e9SAndroid Build Coastguard Worker    def test_python_op_name(self):
15707*da0073e9SAndroid Build Coastguard Worker        import random
15708*da0073e9SAndroid Build Coastguard Worker
15709*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "randint"):
15710*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
15711*da0073e9SAndroid Build Coastguard Worker            def fn():
15712*da0073e9SAndroid Build Coastguard Worker                return random.randint()
15713*da0073e9SAndroid Build Coastguard Worker
15714*da0073e9SAndroid Build Coastguard Worker    def test_dir(self):
15715*da0073e9SAndroid Build Coastguard Worker        class M(torch.jit.ScriptModule):
15716*da0073e9SAndroid Build Coastguard Worker            def forward(self, t):
15717*da0073e9SAndroid Build Coastguard Worker                return t
15718*da0073e9SAndroid Build Coastguard Worker
15719*da0073e9SAndroid Build Coastguard Worker        self.assertTrue('forward' in dir(M()))
15720*da0073e9SAndroid Build Coastguard Worker
15721*da0073e9SAndroid Build Coastguard Worker    def test_kwarg_expansion_error(self):
15722*da0073e9SAndroid Build Coastguard Worker        @torch.jit.ignore
15723*da0073e9SAndroid Build Coastguard Worker        def something_else(h, i):
15724*da0073e9SAndroid Build Coastguard Worker            pass
15725*da0073e9SAndroid Build Coastguard Worker
15726*da0073e9SAndroid Build Coastguard Worker        def fn(x):
15727*da0073e9SAndroid Build Coastguard Worker            something_else(**x)
15728*da0073e9SAndroid Build Coastguard Worker
15729*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "keyword-arg expansion is not supported"):
15730*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(fn)
15731*da0073e9SAndroid Build Coastguard Worker
15732*da0073e9SAndroid Build Coastguard Worker    def test_kwargs_error_msg(self):
15733*da0073e9SAndroid Build Coastguard Worker        def other(**kwargs):
15734*da0073e9SAndroid Build Coastguard Worker            print(kwargs)
15735*da0073e9SAndroid Build Coastguard Worker
15736*da0073e9SAndroid Build Coastguard Worker        def fn():
15737*da0073e9SAndroid Build Coastguard Worker            return other()
15738*da0073e9SAndroid Build Coastguard Worker
15739*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, 'variable number'):
15740*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(fn)
15741*da0073e9SAndroid Build Coastguard Worker
15742*da0073e9SAndroid Build Coastguard Worker        def another_other(*args):
15743*da0073e9SAndroid Build Coastguard Worker            print(args)
15744*da0073e9SAndroid Build Coastguard Worker
15745*da0073e9SAndroid Build Coastguard Worker        def another_fn():
15746*da0073e9SAndroid Build Coastguard Worker            return another_other()
15747*da0073e9SAndroid Build Coastguard Worker
15748*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, 'variable number'):
15749*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(another_fn)
15750*da0073e9SAndroid Build Coastguard Worker
15751*da0073e9SAndroid Build Coastguard Worker    def test_inferred_error_msg(self):
15752*da0073e9SAndroid Build Coastguard Worker        """
15753*da0073e9SAndroid Build Coastguard Worker        Test that when we get a type mismatch on a function where we inferred
15754*da0073e9SAndroid Build Coastguard Worker        the type to be tensor, a good error message is given.
15755*da0073e9SAndroid Build Coastguard Worker        """
15756*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
15757*da0073e9SAndroid Build Coastguard Worker        def foo(a):
15758*da0073e9SAndroid Build Coastguard Worker            return a
15759*da0073e9SAndroid Build Coastguard Worker
15760*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, (r"Expected a value of type \'Tensor \(inferred\)\'"
15761*da0073e9SAndroid Build Coastguard Worker                                                   r"[\S\s]*Inferred \'a\' to be of type \'Tensor\'")):
15762*da0073e9SAndroid Build Coastguard Worker            foo("1")
15763*da0073e9SAndroid Build Coastguard Worker
15764*da0073e9SAndroid Build Coastguard Worker    def test_type_comments_in_body(self):
15765*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
15766*da0073e9SAndroid Build Coastguard Worker        def foo(a,  # type: int
15767*da0073e9SAndroid Build Coastguard Worker                b,  # type: int
15768*da0073e9SAndroid Build Coastguard Worker                ):
15769*da0073e9SAndroid Build Coastguard Worker            # type: (...) -> int
15770*da0073e9SAndroid Build Coastguard Worker            # type: int
15771*da0073e9SAndroid Build Coastguard Worker            return a + b
15772*da0073e9SAndroid Build Coastguard Worker
15773*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
15774*da0073e9SAndroid Build Coastguard Worker            def __init__(self,
15775*da0073e9SAndroid Build Coastguard Worker                         a,  # type: int
15776*da0073e9SAndroid Build Coastguard Worker                         b   # type: int
15777*da0073e9SAndroid Build Coastguard Worker                         ):
15778*da0073e9SAndroid Build Coastguard Worker                # type: (...) -> None
15779*da0073e9SAndroid Build Coastguard Worker                super().__init__()
15780*da0073e9SAndroid Build Coastguard Worker                self.a = a  # type: int
15781*da0073e9SAndroid Build Coastguard Worker                self.b = b  # type: int
15782*da0073e9SAndroid Build Coastguard Worker
15783*da0073e9SAndroid Build Coastguard Worker        torch.jit.script(M(2, 3))
15784*da0073e9SAndroid Build Coastguard Worker
15785*da0073e9SAndroid Build Coastguard Worker    def test_input_keyword_in_schema(self):
15786*da0073e9SAndroid Build Coastguard Worker        def f(x):
15787*da0073e9SAndroid Build Coastguard Worker            return torch.ceil(input=x)
15788*da0073e9SAndroid Build Coastguard Worker
15789*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn(10)
15790*da0073e9SAndroid Build Coastguard Worker        self.checkScript(f, (inp, ))
15791*da0073e9SAndroid Build Coastguard Worker
15792*da0073e9SAndroid Build Coastguard Worker    def test_module_method_reassignment(self):
15793*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
15794*da0073e9SAndroid Build Coastguard Worker            def _forward(self, x):
15795*da0073e9SAndroid Build Coastguard Worker                return x
15796*da0073e9SAndroid Build Coastguard Worker
15797*da0073e9SAndroid Build Coastguard Worker            forward = _forward
15798*da0073e9SAndroid Build Coastguard Worker
15799*da0073e9SAndroid Build Coastguard Worker        sm = torch.jit.script(Foo())
15800*da0073e9SAndroid Build Coastguard Worker        input = torch.ones(2, 2)
15801*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(input, sm(input))
15802*da0073e9SAndroid Build Coastguard Worker
15803*da0073e9SAndroid Build Coastguard Worker    # Tests the case where a torch.Tensor subclass (like Parameter) is used as
15804*da0073e9SAndroid Build Coastguard Worker    # input.
15805*da0073e9SAndroid Build Coastguard Worker    def test_script_module_tensor_subclass_argument(self):
15806*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
15807*da0073e9SAndroid Build Coastguard Worker        def parameter_script(x: torch.nn.Parameter):
15808*da0073e9SAndroid Build Coastguard Worker            return x
15809*da0073e9SAndroid Build Coastguard Worker
15810*da0073e9SAndroid Build Coastguard Worker        input = torch.ones(2, 2)
15811*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(input, parameter_script(input))
15812*da0073e9SAndroid Build Coastguard Worker
15813*da0073e9SAndroid Build Coastguard Worker    def test_save_load_attr_error(self):
15814*da0073e9SAndroid Build Coastguard Worker        class Inner(nn.Module):
15815*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
15816*da0073e9SAndroid Build Coastguard Worker                return x
15817*da0073e9SAndroid Build Coastguard Worker
15818*da0073e9SAndroid Build Coastguard Worker        class Wrapper(nn.Module):
15819*da0073e9SAndroid Build Coastguard Worker            def __init__(self, inner):
15820*da0073e9SAndroid Build Coastguard Worker                super().__init__()
15821*da0073e9SAndroid Build Coastguard Worker                self.inner = inner
15822*da0073e9SAndroid Build Coastguard Worker
15823*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
15824*da0073e9SAndroid Build Coastguard Worker                # this attribute doesn't exist on `Inner`
15825*da0073e9SAndroid Build Coastguard Worker                return self.inner.b(x)
15826*da0073e9SAndroid Build Coastguard Worker
15827*da0073e9SAndroid Build Coastguard Worker        inner_module = torch.jit.script(Inner())
15828*da0073e9SAndroid Build Coastguard Worker        inner_module = self.getExportImportCopy(inner_module)
15829*da0073e9SAndroid Build Coastguard Worker        wrapped = Wrapper(inner_module)
15830*da0073e9SAndroid Build Coastguard Worker        # This should properly complain that `self.inner` doesn't have the attribute `b`
15831*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'has no attribute'):
15832*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(wrapped)
15833*da0073e9SAndroid Build Coastguard Worker
15834*da0073e9SAndroid Build Coastguard Worker    def test_rescripting_loaded_modules(self):
15835*da0073e9SAndroid Build Coastguard Worker        class InnerSubmod(nn.Module):
15836*da0073e9SAndroid Build Coastguard Worker            __constants__ = ['my_constant']
15837*da0073e9SAndroid Build Coastguard Worker
15838*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
15839*da0073e9SAndroid Build Coastguard Worker                super().__init__()
15840*da0073e9SAndroid Build Coastguard Worker                self.foo = torch.nn.Buffer(torch.ones(1))
15841*da0073e9SAndroid Build Coastguard Worker                self.register_parameter("bar", torch.nn.Parameter(torch.ones(1)))
15842*da0073e9SAndroid Build Coastguard Worker                self.baz = torch.ones(1)
15843*da0073e9SAndroid Build Coastguard Worker                self.my_constant = 1
15844*da0073e9SAndroid Build Coastguard Worker
15845*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
15846*da0073e9SAndroid Build Coastguard Worker                return x + x
15847*da0073e9SAndroid Build Coastguard Worker
15848*da0073e9SAndroid Build Coastguard Worker        class Inner(nn.Module):
15849*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
15850*da0073e9SAndroid Build Coastguard Worker                super().__init__()
15851*da0073e9SAndroid Build Coastguard Worker                self.submod = InnerSubmod()
15852*da0073e9SAndroid Build Coastguard Worker
15853*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
15854*da0073e9SAndroid Build Coastguard Worker                return self.submod(x)
15855*da0073e9SAndroid Build Coastguard Worker
15856*da0073e9SAndroid Build Coastguard Worker        class Wrapper(nn.Module):
15857*da0073e9SAndroid Build Coastguard Worker            def __init__(self, inner):
15858*da0073e9SAndroid Build Coastguard Worker                super().__init__()
15859*da0073e9SAndroid Build Coastguard Worker                self.inner = inner
15860*da0073e9SAndroid Build Coastguard Worker
15861*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
15862*da0073e9SAndroid Build Coastguard Worker                # access inner elements
15863*da0073e9SAndroid Build Coastguard Worker                ret = self.inner.submod(x) + self.inner.submod.foo + self.inner.submod.bar + self.inner.submod.baz
15864*da0073e9SAndroid Build Coastguard Worker                ret = ret + self.inner.submod.my_constant
15865*da0073e9SAndroid Build Coastguard Worker                return ret
15866*da0073e9SAndroid Build Coastguard Worker
15867*da0073e9SAndroid Build Coastguard Worker        inner_module = torch.jit.script(Inner())
15868*da0073e9SAndroid Build Coastguard Worker        wrapped = Wrapper(inner_module)
15869*da0073e9SAndroid Build Coastguard Worker        self.checkModule(wrapped, torch.ones(1))
15870*da0073e9SAndroid Build Coastguard Worker
15871*da0073e9SAndroid Build Coastguard Worker        inner_module_loaded = self.getExportImportCopy(inner_module)
15872*da0073e9SAndroid Build Coastguard Worker        wrapped_loaded = Wrapper(inner_module_loaded)
15873*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(wrapped(torch.ones(1)), wrapped_loaded(torch.ones(1)))
15874*da0073e9SAndroid Build Coastguard Worker
15875*da0073e9SAndroid Build Coastguard Worker    def test_interpret_graph(self):
15876*da0073e9SAndroid Build Coastguard Worker        def fn(x):
15877*da0073e9SAndroid Build Coastguard Worker            return x.unfold(0, 1, 1)
15878*da0073e9SAndroid Build Coastguard Worker
15879*da0073e9SAndroid Build Coastguard Worker        graph_str = """
15880*da0073e9SAndroid Build Coastguard Worker        graph(%a : Tensor, %b : Tensor):
15881*da0073e9SAndroid Build Coastguard Worker          %c : Tensor = aten::mul(%a, %b)
15882*da0073e9SAndroid Build Coastguard Worker          return (%c)
15883*da0073e9SAndroid Build Coastguard Worker        """
15884*da0073e9SAndroid Build Coastguard Worker        graph = parse_ir(graph_str)
15885*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(10)
15886*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(10)
15887*da0073e9SAndroid Build Coastguard Worker        test = torch._C._jit_interpret_graph(graph, (a, b))
15888*da0073e9SAndroid Build Coastguard Worker        ref = a * b
15889*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(test, ref)
15890*da0073e9SAndroid Build Coastguard Worker
15891*da0073e9SAndroid Build Coastguard Worker    def test_signed_float_zero(self):
15892*da0073e9SAndroid Build Coastguard Worker
15893*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
15894*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
15895*da0073e9SAndroid Build Coastguard Worker                return torch.div(x, -0.)
15896*da0073e9SAndroid Build Coastguard Worker
15897*da0073e9SAndroid Build Coastguard Worker        inp = torch.ones(1)
15898*da0073e9SAndroid Build Coastguard Worker        self.checkModule(MyModule(), inp)
15899*da0073e9SAndroid Build Coastguard Worker
15900*da0073e9SAndroid Build Coastguard Worker    def test_index_with_tuple(self):
15901*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
15902*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
15903*da0073e9SAndroid Build Coastguard Worker                return x[(1,)]
15904*da0073e9SAndroid Build Coastguard Worker
15905*da0073e9SAndroid Build Coastguard Worker        self.checkModule(MyModule(), (torch.ones(2, 3),))
15906*da0073e9SAndroid Build Coastguard Worker
15907*da0073e9SAndroid Build Coastguard Worker    def test_context_manager(self):
15908*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
15909*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, y):
15910*da0073e9SAndroid Build Coastguard Worker                p = x + y
15911*da0073e9SAndroid Build Coastguard Worker                q = p + 2.0
15912*da0073e9SAndroid Build Coastguard Worker                return q
15913*da0073e9SAndroid Build Coastguard Worker
15914*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 2, dtype=torch.float)
15915*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(3, 2, dtype=torch.float)
15916*da0073e9SAndroid Build Coastguard Worker        for fuser_name in ['fuser0', 'fuser1', 'none']:
15917*da0073e9SAndroid Build Coastguard Worker            with torch.jit.fuser(fuser_name):
15918*da0073e9SAndroid Build Coastguard Worker                self.checkModule(MyModule(), (x, y))
15919*da0073e9SAndroid Build Coastguard Worker
15920*da0073e9SAndroid Build Coastguard Worker    def test_zero_dimension_tensor_trace(self):
15921*da0073e9SAndroid Build Coastguard Worker        def f(x):
15922*da0073e9SAndroid Build Coastguard Worker            return x[x > 0]
15923*da0073e9SAndroid Build Coastguard Worker        jf = torch.jit.trace(f, torch.tensor(2., device="cpu"))
15924*da0073e9SAndroid Build Coastguard Worker
15925*da0073e9SAndroid Build Coastguard Worker# known to be failing in tracer
15926*da0073e9SAndroid Build Coastguard WorkerEXCLUDE_TRACED = {
15927*da0073e9SAndroid Build Coastguard Worker    # The following fail due to #12024.
15928*da0073e9SAndroid Build Coastguard Worker    # A prim::ListConstruct is involved and the indices get traced as TensorType,
15929*da0073e9SAndroid Build Coastguard Worker    # which always require_grad. This causes a crash in autodiff.
15930*da0073e9SAndroid Build Coastguard Worker    'test___getitem___adv_index',
15931*da0073e9SAndroid Build Coastguard Worker    'test___getitem___adv_index_beg',
15932*da0073e9SAndroid Build Coastguard Worker    'test___getitem___adv_index_comb',
15933*da0073e9SAndroid Build Coastguard Worker    'test___getitem___adv_index_dup',
15934*da0073e9SAndroid Build Coastguard Worker    'test___getitem___adv_index_sub',
15935*da0073e9SAndroid Build Coastguard Worker    'test___getitem___adv_index_sub_2',
15936*da0073e9SAndroid Build Coastguard Worker    'test___getitem___adv_index_sub_3',
15937*da0073e9SAndroid Build Coastguard Worker    'test___getitem___adv_index_var',
15938*da0073e9SAndroid Build Coastguard Worker
15939*da0073e9SAndroid Build Coastguard Worker    # jit doesn't support sparse tensors.
15940*da0073e9SAndroid Build Coastguard Worker    'test_to_sparse',
15941*da0073e9SAndroid Build Coastguard Worker    'test_to_sparse_dim',
15942*da0073e9SAndroid Build Coastguard Worker}
15943*da0073e9SAndroid Build Coastguard Worker
15944*da0073e9SAndroid Build Coastguard WorkerEXCLUDE_TYPE_CHECK = {
15945*da0073e9SAndroid Build Coastguard Worker    # slogdet tests use itemgetter to select its only differentiable output,
15946*da0073e9SAndroid Build Coastguard Worker    # but this happens outside of the graph we handle, so there are fewer
15947*da0073e9SAndroid Build Coastguard Worker    # reference outputs than graph outputs.
15948*da0073e9SAndroid Build Coastguard Worker    'test_slogdet_1x1_neg_det',
15949*da0073e9SAndroid Build Coastguard Worker    'test_slogdet_1x1_pos_det',
15950*da0073e9SAndroid Build Coastguard Worker    'test_slogdet_distinct_singular_values',
15951*da0073e9SAndroid Build Coastguard Worker    'test_slogdet_neg_det',
15952*da0073e9SAndroid Build Coastguard Worker    'test_slogdet_pos_det',
15953*da0073e9SAndroid Build Coastguard Worker    'test_slogdet_symmetric',
15954*da0073e9SAndroid Build Coastguard Worker    'test_slogdet_symmetric_pd',
15955*da0073e9SAndroid Build Coastguard Worker    'test_slogdet_batched_1x1_neg_det',
15956*da0073e9SAndroid Build Coastguard Worker    'test_slogdet_batched_pos_det',
15957*da0073e9SAndroid Build Coastguard Worker    'test_slogdet_batched_symmetric',
15958*da0073e9SAndroid Build Coastguard Worker    'test_slogdet_batched_symmetric_pd',
15959*da0073e9SAndroid Build Coastguard Worker    'test_slogdet_batched_distinct_singular_values'
15960*da0073e9SAndroid Build Coastguard Worker}
15961*da0073e9SAndroid Build Coastguard Worker
15962*da0073e9SAndroid Build Coastguard Worker# chunk returns a list in scripting and we don't unpack the list,
15963*da0073e9SAndroid Build Coastguard Worker# Thus it won't be replaced by ConstantChunk and run AD.
15964*da0073e9SAndroid Build Coastguard Worker# It's explicitly checked in test_chunk_constant_script_ad
15965*da0073e9SAndroid Build Coastguard Worker# Similary for split, it's replaced by split_with_sizes in tracing,
15966*da0073e9SAndroid Build Coastguard Worker# but we don't have AD formula for aten::split(Tensor, int[], int),
15967*da0073e9SAndroid Build Coastguard Worker# an op registered in JIT so AD is not triggered in scripting.
15968*da0073e9SAndroid Build Coastguard WorkerEXCLUDE_SCRIPT_AD_CHECK = {
15969*da0073e9SAndroid Build Coastguard Worker    'test_chunk',
15970*da0073e9SAndroid Build Coastguard Worker    'test_chunk_dim',
15971*da0073e9SAndroid Build Coastguard Worker    'test_chunk_dim_neg0',
15972*da0073e9SAndroid Build Coastguard Worker    'test_split_size_list',
15973*da0073e9SAndroid Build Coastguard Worker    'test_split_size_list_dim',
15974*da0073e9SAndroid Build Coastguard Worker    'test_split_size_list_dim_neg0',
15975*da0073e9SAndroid Build Coastguard Worker    'test_tensor_indices_sections',
15976*da0073e9SAndroid Build Coastguard Worker    'test_tensor_indices_sections_dim',
15977*da0073e9SAndroid Build Coastguard Worker    'test_tensor_indices_sections_dim_neg0',
15978*da0073e9SAndroid Build Coastguard Worker    'test_tensor_split_sections',
15979*da0073e9SAndroid Build Coastguard Worker    'test_tensor_split_sections_dim',
15980*da0073e9SAndroid Build Coastguard Worker    'test_tensor_split_sections_dim_neg0'
15981*da0073e9SAndroid Build Coastguard Worker}
15982*da0073e9SAndroid Build Coastguard Worker
15983*da0073e9SAndroid Build Coastguard WorkerEXCLUDE_PYTHON_PRINT = {
15984*da0073e9SAndroid Build Coastguard Worker    # no support for BroadcastingList in python printer
15985*da0073e9SAndroid Build Coastguard Worker    'test_nn_max_unpool1d',
15986*da0073e9SAndroid Build Coastguard Worker    'test_nn_max_unpool2d',
15987*da0073e9SAndroid Build Coastguard Worker    'test_nn_max_unpool3d',
15988*da0073e9SAndroid Build Coastguard Worker    'test_nn_max_pool1d',
15989*da0073e9SAndroid Build Coastguard Worker    'test_nn_max_pool2d',
15990*da0073e9SAndroid Build Coastguard Worker    'test_nn_max_pool3d',
15991*da0073e9SAndroid Build Coastguard Worker    'test_nn_max_pool1d_with_indices',
15992*da0073e9SAndroid Build Coastguard Worker}
15993*da0073e9SAndroid Build Coastguard Worker
15994*da0073e9SAndroid Build Coastguard WorkerEXCLUDE_ALIAS = {
15995*da0073e9SAndroid Build Coastguard Worker    # aliases, which may appear in method_tests but are tested elsewhere
15996*da0073e9SAndroid Build Coastguard Worker    'true_divide',
15997*da0073e9SAndroid Build Coastguard Worker
15998*da0073e9SAndroid Build Coastguard Worker    # Disable tests for lu from common_methods_invocations.py
15999*da0073e9SAndroid Build Coastguard Worker    # TODO(@nikitaved) Enable jit tests once autograd.Function does support scripting
16000*da0073e9SAndroid Build Coastguard Worker    'lu'
16001*da0073e9SAndroid Build Coastguard Worker}
16002*da0073e9SAndroid Build Coastguard Worker
16003*da0073e9SAndroid Build Coastguard Worker
16004*da0073e9SAndroid Build Coastguard Workerclass TestJitGeneratedModule(JitTestCase):
16005*da0073e9SAndroid Build Coastguard Worker    pass
16006*da0073e9SAndroid Build Coastguard Worker
16007*da0073e9SAndroid Build Coastguard Worker
16008*da0073e9SAndroid Build Coastguard Workerclass TestJitGeneratedFunctional(JitTestCase):
16009*da0073e9SAndroid Build Coastguard Worker    pass
16010*da0073e9SAndroid Build Coastguard Worker
16011*da0073e9SAndroid Build Coastguard Worker# UBSAN per-function exclusions don't seem to work with OpenMP pragmas,
16012*da0073e9SAndroid Build Coastguard Worker# and we have to disable the failing tests here instead.
16013*da0073e9SAndroid Build Coastguard WorkerUBSAN_DISABLED_TESTS = [
16014*da0073e9SAndroid Build Coastguard Worker    "test___rdiv___constant",
16015*da0073e9SAndroid Build Coastguard Worker    "test___rdiv___scalar_constant",
16016*da0073e9SAndroid Build Coastguard Worker    "test_addcdiv",
16017*da0073e9SAndroid Build Coastguard Worker    "test_addcdiv_broadcast_all",
16018*da0073e9SAndroid Build Coastguard Worker    "test_addcdiv_broadcast_rhs",
16019*da0073e9SAndroid Build Coastguard Worker    "test_addcdiv_scalar",
16020*da0073e9SAndroid Build Coastguard Worker    "test_addcdiv_scalar_broadcast_lhs",
16021*da0073e9SAndroid Build Coastguard Worker    "test_addcdiv_scalar_broadcast_rhs",
16022*da0073e9SAndroid Build Coastguard Worker    "test_addcdiv_scalar_scale",
16023*da0073e9SAndroid Build Coastguard Worker    "test_addcdiv_scalar_scale_broadcast_lhs",
16024*da0073e9SAndroid Build Coastguard Worker    "test_addcdiv_scalar_scale_broadcast_rhs",
16025*da0073e9SAndroid Build Coastguard Worker    "test_addcdiv_scale",
16026*da0073e9SAndroid Build Coastguard Worker    "test_addcdiv_scale_broadcast_all",
16027*da0073e9SAndroid Build Coastguard Worker    "test_addcdiv_scale_broadcast_rhs",
16028*da0073e9SAndroid Build Coastguard Worker    "test_add_broadcast_all",
16029*da0073e9SAndroid Build Coastguard Worker    "test_add_broadcast_lhs",
16030*da0073e9SAndroid Build Coastguard Worker    "test_add_broadcast_rhs",
16031*da0073e9SAndroid Build Coastguard Worker    "test_add_constant",
16032*da0073e9SAndroid Build Coastguard Worker    "test_add_scalar",
16033*da0073e9SAndroid Build Coastguard Worker    "test_add_scalar_broadcast_lhs",
16034*da0073e9SAndroid Build Coastguard Worker    "test_add_scalar_broadcast_rhs",
16035*da0073e9SAndroid Build Coastguard Worker    "test_div",
16036*da0073e9SAndroid Build Coastguard Worker    "test_div_broadcast_all",
16037*da0073e9SAndroid Build Coastguard Worker    "test_div_broadcast_lhs",
16038*da0073e9SAndroid Build Coastguard Worker    "test_div_broadcast_rhs",
16039*da0073e9SAndroid Build Coastguard Worker    "test_div_scalar",
16040*da0073e9SAndroid Build Coastguard Worker    "test_div_scalar_broadcast_lhs",
16041*da0073e9SAndroid Build Coastguard Worker    "test_div_scalar_broadcast_rhs",
16042*da0073e9SAndroid Build Coastguard Worker    "test_rsqrt",
16043*da0073e9SAndroid Build Coastguard Worker    "test_rsqrt_scalar",
16044*da0073e9SAndroid Build Coastguard Worker    "test_add",
16045*da0073e9SAndroid Build Coastguard Worker    "test_reciprocal",
16046*da0073e9SAndroid Build Coastguard Worker    "test_reciprocal_scalar",
16047*da0073e9SAndroid Build Coastguard Worker]
16048*da0073e9SAndroid Build Coastguard Worker
16049*da0073e9SAndroid Build Coastguard WorkerL = 20
16050*da0073e9SAndroid Build Coastguard WorkerM = 10
16051*da0073e9SAndroid Build Coastguard WorkerS = 5
16052*da0073e9SAndroid Build Coastguard Worker
16053*da0073e9SAndroid Build Coastguard Workerdef add_nn_module_test(*args, **kwargs):
16054*da0073e9SAndroid Build Coastguard Worker    no_grad = False if 'no_grad' not in kwargs else kwargs['no_grad']
16055*da0073e9SAndroid Build Coastguard Worker
16056*da0073e9SAndroid Build Coastguard Worker    if 'desc' in kwargs and 'eval' in kwargs['desc']:
16057*da0073e9SAndroid Build Coastguard Worker        # eval() is not supported, so skip these tests
16058*da0073e9SAndroid Build Coastguard Worker        return
16059*da0073e9SAndroid Build Coastguard Worker
16060*da0073e9SAndroid Build Coastguard Worker    test_name = get_nn_mod_test_name(**kwargs)
16061*da0073e9SAndroid Build Coastguard Worker
16062*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
16063*da0073e9SAndroid Build Coastguard Worker    def do_test(self):
16064*da0073e9SAndroid Build Coastguard Worker        if test_name in EXCLUDE_SCRIPT_MODULES:
16065*da0073e9SAndroid Build Coastguard Worker            return
16066*da0073e9SAndroid Build Coastguard Worker        if not kwargs.get('check_jit', True):
16067*da0073e9SAndroid Build Coastguard Worker            raise unittest.SkipTest('module test skipped on JIT')
16068*da0073e9SAndroid Build Coastguard Worker
16069*da0073e9SAndroid Build Coastguard Worker        default_dtype = torch.get_default_dtype()
16070*da0073e9SAndroid Build Coastguard Worker        if 'default_dtype' in kwargs and kwargs['default_dtype'] is not None:
16071*da0073e9SAndroid Build Coastguard Worker            default_dtype = kwargs['default_dtype']
16072*da0073e9SAndroid Build Coastguard Worker
16073*da0073e9SAndroid Build Coastguard Worker        module_name = get_nn_module_name_from_kwargs(**kwargs)
16074*da0073e9SAndroid Build Coastguard Worker
16075*da0073e9SAndroid Build Coastguard Worker        if 'constructor' in kwargs:
16076*da0073e9SAndroid Build Coastguard Worker            nn_module = kwargs['constructor']
16077*da0073e9SAndroid Build Coastguard Worker        else:
16078*da0073e9SAndroid Build Coastguard Worker            nn_module = getattr(torch.nn, module_name)
16079*da0073e9SAndroid Build Coastguard Worker
16080*da0073e9SAndroid Build Coastguard Worker        if "FunctionalModule" in str(nn_module):
16081*da0073e9SAndroid Build Coastguard Worker            return
16082*da0073e9SAndroid Build Coastguard Worker
16083*da0073e9SAndroid Build Coastguard Worker        with set_default_dtype(default_dtype):
16084*da0073e9SAndroid Build Coastguard Worker            if 'constructor_args_fn' in kwargs:
16085*da0073e9SAndroid Build Coastguard Worker                constructor_args = kwargs['constructor_args_fn']()
16086*da0073e9SAndroid Build Coastguard Worker            else:
16087*da0073e9SAndroid Build Coastguard Worker                constructor_args = kwargs.get('constructor_args', ())
16088*da0073e9SAndroid Build Coastguard Worker
16089*da0073e9SAndroid Build Coastguard Worker            def create_script_module(*args, **kwargs):
16090*da0073e9SAndroid Build Coastguard Worker                """Construct a script module that passes arguments through to self.submodule"""
16091*da0073e9SAndroid Build Coastguard Worker                formals, tensors, actuals = get_script_args(args)
16092*da0073e9SAndroid Build Coastguard Worker
16093*da0073e9SAndroid Build Coastguard Worker                method_args = ', '.join(['self'] + actuals)
16094*da0073e9SAndroid Build Coastguard Worker                call_args_str = ', '.join(actuals)
16095*da0073e9SAndroid Build Coastguard Worker                call = f"self.submodule({call_args_str})"
16096*da0073e9SAndroid Build Coastguard Worker                script = script_method_template.format(method_args, call)
16097*da0073e9SAndroid Build Coastguard Worker
16098*da0073e9SAndroid Build Coastguard Worker                submodule_constants = []
16099*da0073e9SAndroid Build Coastguard Worker                if kwargs.get('is_constant'):
16100*da0073e9SAndroid Build Coastguard Worker                    submodule_constants = ['submodule']
16101*da0073e9SAndroid Build Coastguard Worker
16102*da0073e9SAndroid Build Coastguard Worker                # Create module to use the script method
16103*da0073e9SAndroid Build Coastguard Worker                class TheModule(torch.jit.ScriptModule):
16104*da0073e9SAndroid Build Coastguard Worker                    __constants__ = submodule_constants
16105*da0073e9SAndroid Build Coastguard Worker
16106*da0073e9SAndroid Build Coastguard Worker                    def __init__(self) -> None:
16107*da0073e9SAndroid Build Coastguard Worker                        super().__init__()
16108*da0073e9SAndroid Build Coastguard Worker                        self.submodule = nn_module(*constructor_args)
16109*da0073e9SAndroid Build Coastguard Worker
16110*da0073e9SAndroid Build Coastguard Worker                def make_module(script):
16111*da0073e9SAndroid Build Coastguard Worker                    module = TheModule()
16112*da0073e9SAndroid Build Coastguard Worker                    # check __repr__
16113*da0073e9SAndroid Build Coastguard Worker                    str(module)
16114*da0073e9SAndroid Build Coastguard Worker                    module.define(script)
16115*da0073e9SAndroid Build Coastguard Worker                    return module
16116*da0073e9SAndroid Build Coastguard Worker
16117*da0073e9SAndroid Build Coastguard Worker                module = make_module(script)
16118*da0073e9SAndroid Build Coastguard Worker                self.assertExportImportModule(module, tensors)
16119*da0073e9SAndroid Build Coastguard Worker                create_script_module.last_graph = module.graph
16120*da0073e9SAndroid Build Coastguard Worker                mod = module(*args)
16121*da0073e9SAndroid Build Coastguard Worker                return mod
16122*da0073e9SAndroid Build Coastguard Worker
16123*da0073e9SAndroid Build Coastguard Worker            # Construct a normal nn module to stay consistent with create_script_module
16124*da0073e9SAndroid Build Coastguard Worker            # and make use of a single global rng_state in module initialization
16125*da0073e9SAndroid Build Coastguard Worker            def create_nn_module(*args, **kwargs):
16126*da0073e9SAndroid Build Coastguard Worker                module = nn_module(*constructor_args)
16127*da0073e9SAndroid Build Coastguard Worker                return module(*args)
16128*da0073e9SAndroid Build Coastguard Worker
16129*da0073e9SAndroid Build Coastguard Worker            # Set up inputs from tuple of sizes or constructor fn
16130*da0073e9SAndroid Build Coastguard Worker            dtype = torch.get_default_dtype()
16131*da0073e9SAndroid Build Coastguard Worker            if 'input_fn' in kwargs:
16132*da0073e9SAndroid Build Coastguard Worker                input = kwargs['input_fn']()
16133*da0073e9SAndroid Build Coastguard Worker                if isinstance(input, Tensor):
16134*da0073e9SAndroid Build Coastguard Worker                    input = (input,)
16135*da0073e9SAndroid Build Coastguard Worker
16136*da0073e9SAndroid Build Coastguard Worker                if all(tensor.is_complex() for tensor in input):
16137*da0073e9SAndroid Build Coastguard Worker                    if dtype == torch.float:
16138*da0073e9SAndroid Build Coastguard Worker                        dtype = torch.cfloat
16139*da0073e9SAndroid Build Coastguard Worker                    elif dtype == torch.double:
16140*da0073e9SAndroid Build Coastguard Worker                        dtype = torch.cdouble
16141*da0073e9SAndroid Build Coastguard Worker                    else:
16142*da0073e9SAndroid Build Coastguard Worker                        raise AssertionError(f"default_dtype {default_dtype} is not supported")
16143*da0073e9SAndroid Build Coastguard Worker
16144*da0073e9SAndroid Build Coastguard Worker            else:
16145*da0073e9SAndroid Build Coastguard Worker                input = (kwargs['input_size'],)
16146*da0073e9SAndroid Build Coastguard Worker
16147*da0073e9SAndroid Build Coastguard Worker            if 'target_size' in kwargs:
16148*da0073e9SAndroid Build Coastguard Worker                input = input + (kwargs['target_size'],)
16149*da0073e9SAndroid Build Coastguard Worker            elif 'target_fn' in kwargs:
16150*da0073e9SAndroid Build Coastguard Worker                if torch.is_tensor(input):
16151*da0073e9SAndroid Build Coastguard Worker                    input = (input,)
16152*da0073e9SAndroid Build Coastguard Worker                input = input + (kwargs['target_fn'](),)
16153*da0073e9SAndroid Build Coastguard Worker            elif 'target' in kwargs:
16154*da0073e9SAndroid Build Coastguard Worker                input = input + (kwargs['target'],)
16155*da0073e9SAndroid Build Coastguard Worker
16156*da0073e9SAndroid Build Coastguard Worker            # Extra parameters to forward()
16157*da0073e9SAndroid Build Coastguard Worker            if 'extra_args' in kwargs:
16158*da0073e9SAndroid Build Coastguard Worker                input = input + kwargs['extra_args']
16159*da0073e9SAndroid Build Coastguard Worker
16160*da0073e9SAndroid Build Coastguard Worker            args_variable, kwargs_variable = create_input(input, dtype=dtype)
16161*da0073e9SAndroid Build Coastguard Worker            f_args_variable = deepcopy(unpack_variables(args_variable))
16162*da0073e9SAndroid Build Coastguard Worker
16163*da0073e9SAndroid Build Coastguard Worker            # TODO(issue#52052) Neither this nor no_grad should be required
16164*da0073e9SAndroid Build Coastguard Worker            # if check_against_reference() is updated to check gradients
16165*da0073e9SAndroid Build Coastguard Worker            # w.r.t. weights and then only check w.r.t. inputs if any
16166*da0073e9SAndroid Build Coastguard Worker            # inputs require it.
16167*da0073e9SAndroid Build Coastguard Worker            any_requires_grad = any(input.requires_grad for input in f_args_variable)
16168*da0073e9SAndroid Build Coastguard Worker
16169*da0073e9SAndroid Build Coastguard Worker            # Check against Python module as reference
16170*da0073e9SAndroid Build Coastguard Worker            check_against_reference(self, create_script_module, create_nn_module,
16171*da0073e9SAndroid Build Coastguard Worker                                    lambda x: x, f_args_variable,
16172*da0073e9SAndroid Build Coastguard Worker                                    no_grad=no_grad or not any_requires_grad)
16173*da0073e9SAndroid Build Coastguard Worker
16174*da0073e9SAndroid Build Coastguard Worker    if 'slowTest' in kwargs:
16175*da0073e9SAndroid Build Coastguard Worker        do_test = slowTest(do_test)
16176*da0073e9SAndroid Build Coastguard Worker
16177*da0073e9SAndroid Build Coastguard Worker    post_add_test(test_name, (), do_test, TestJitGeneratedModule)
16178*da0073e9SAndroid Build Coastguard Worker
16179*da0073e9SAndroid Build Coastguard Worker
16180*da0073e9SAndroid Build Coastguard Workerdef post_add_test(test_name, skipTestIf, do_test, test_class):
16181*da0073e9SAndroid Build Coastguard Worker    assert not hasattr(test_class, test_name), 'Two tests have the same name: ' + test_name
16182*da0073e9SAndroid Build Coastguard Worker
16183*da0073e9SAndroid Build Coastguard Worker    for skip in skipTestIf:
16184*da0073e9SAndroid Build Coastguard Worker        do_test = skip(do_test)
16185*da0073e9SAndroid Build Coastguard Worker
16186*da0073e9SAndroid Build Coastguard Worker    if not (TEST_WITH_UBSAN and test_name in UBSAN_DISABLED_TESTS):
16187*da0073e9SAndroid Build Coastguard Worker        setattr(test_class, test_name, do_test)
16188*da0073e9SAndroid Build Coastguard Worker
16189*da0073e9SAndroid Build Coastguard Worker
16190*da0073e9SAndroid Build Coastguard Workerdef normalize_check_ad(check_ad, name):
16191*da0073e9SAndroid Build Coastguard Worker    # normalized check_ad is 3-element tuple: (bool, List[str], List[str])
16192*da0073e9SAndroid Build Coastguard Worker    if len(check_ad) == 0:
16193*da0073e9SAndroid Build Coastguard Worker        check_ad = [False, ['aten::' + name], []]
16194*da0073e9SAndroid Build Coastguard Worker    elif len(check_ad) == 1:
16195*da0073e9SAndroid Build Coastguard Worker        check_ad = [check_ad[0], ['aten::' + name], []]
16196*da0073e9SAndroid Build Coastguard Worker    elif len(check_ad) == 2:
16197*da0073e9SAndroid Build Coastguard Worker        check_ad = [check_ad[0], check_ad[1], []]
16198*da0073e9SAndroid Build Coastguard Worker    elif len(check_ad) == 3:
16199*da0073e9SAndroid Build Coastguard Worker        check_ad = list(check_ad)
16200*da0073e9SAndroid Build Coastguard Worker    else:
16201*da0073e9SAndroid Build Coastguard Worker        raise Exception('Invalid check_ad, requires (bool, str|List[str], str|List[str])')  # noqa: TRY002
16202*da0073e9SAndroid Build Coastguard Worker
16203*da0073e9SAndroid Build Coastguard Worker    check_ad = [[t] if isinstance(t, str) else t for t in check_ad]
16204*da0073e9SAndroid Build Coastguard Worker
16205*da0073e9SAndroid Build Coastguard Worker    return check_ad
16206*da0073e9SAndroid Build Coastguard Worker
16207*da0073e9SAndroid Build Coastguard Worker
16208*da0073e9SAndroid Build Coastguard Workerclass TestProducerVersion(TestCase):
16209*da0073e9SAndroid Build Coastguard Worker
16210*da0073e9SAndroid Build Coastguard Worker    def test_version(self):
16211*da0073e9SAndroid Build Coastguard Worker        # issue gh-32561
16212*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.__version__.startswith(torch.onnx.producer_version))
16213*da0073e9SAndroid Build Coastguard Worker
16214*da0073e9SAndroid Build Coastguard Workerfor test in module_tests + new_module_tests + additional_module_tests:
16215*da0073e9SAndroid Build Coastguard Worker    add_nn_module_test(**test)
16216*da0073e9SAndroid Build Coastguard Worker
16217*da0073e9SAndroid Build Coastguard Workerfor test in criterion_tests:
16218*da0073e9SAndroid Build Coastguard Worker    test['no_grad'] = True
16219*da0073e9SAndroid Build Coastguard Worker    add_nn_module_test(**test)
16220*da0073e9SAndroid Build Coastguard Worker
16221*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__':
16222*da0073e9SAndroid Build Coastguard Worker    TestCase._default_dtype_check_enabled = True
16223*da0073e9SAndroid Build Coastguard Worker    run_tests()
16224*da0073e9SAndroid Build Coastguard Worker    import jit.test_module_interface
16225*da0073e9SAndroid Build Coastguard Worker    suite = unittest.findTestCases(jit.test_module_interface)
16226*da0073e9SAndroid Build Coastguard Worker    unittest.TextTestRunner().run(suite)
16227